Open-loop planning with straightline plans in JaxPlan.
This basic example shows how to set up an offline JaxPlan agent that first optimizes a plan for the problem and then evaluates it.
First install the required packages:
%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import os
from IPython.display import Image
import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import JaxStraightLinePlan, JaxBackpropPlanner, JaxOfflineController, load_config_from_string
We will optimize the Reacher problem, which controls a multi-jointed arm. JaxPlan expects the environment's state and action fluents as vectorized tensors, so we must enable the vectorized option:
env = pyRDDLGym.make('Reacher_physics', '0', vectorized=True)
Now we will construct a straight-line offline planner as our policy, using the default parameters. This instantiates an independent trainable parameter vector for each action-fluent per decision time step. Note that the controller is an instance of the pyRDDLGym BasePolicy, so it will support normal interaction with pyRDDLGym environments out of the box:
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan())
agent = JaxOfflineController(planner, print_summary=False, train_seconds=30)
agent.evaluate(env, episodes=1)
[INFO] Compiler will cast pvars {'TARGET-SEGMENT', 'POS'} to float.
[INFO] Bounds of action-fluent <torque> set to (array([-0.1, -0.1], dtype=float32), array([0.1, 0.1], dtype=float32)).
8488 it | -199.55951 train | -195.31267 test | -180.96204 [31 it] pgpe | -99.08913 best | 5 status: 100%|██████████| 00:29 , 282.96it/s
{'mean': np.float64(-122.93819352552316),
'median': np.float64(-122.93819352552316),
'min': np.float64(-122.93819352552316),
'max': np.float64(-122.93819352552316),
'std': np.float64(0.0)}
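To make the straight-line idea concrete, here is a minimal sketch in plain Python. It is not JaxPlan's implementation: the one-dimensional dynamics, the reward, and the helper names (`rollout_return`, `plan_gradient`, `optimize_plan`) are all assumptions chosen for illustration, and the gradient is written out by hand rather than obtained by automatic differentiation. The plan is one independent trainable number per time step, improved by projected gradient ascent on the return while staying within the action bounds.

```python
def rollout_return(plan, x0, target):
    """Total reward of executing `plan` open loop from state x0."""
    x, total = x0, 0.0
    for a in plan:
        x = x + a                    # toy dynamics (assumed): x' = x + a
        total -= (x - target) ** 2   # toy reward (assumed): -(distance)^2
    return total


def plan_gradient(plan, x0, target):
    """Analytic gradient of the return w.r.t. each action in the plan."""
    states, x = [], x0
    for a in plan:
        x = x + a
        states.append(x)
    # Action k shifts every state from step k onward by the same amount,
    # so its gradient is the sum of the reward gradients of those states.
    grad, tail = [0.0] * len(plan), 0.0
    for k in reversed(range(len(plan))):
        tail += -2.0 * (states[k] - target)
        grad[k] = tail
    return grad


def optimize_plan(x0, target, horizon=10, lr=0.005, iters=5000, bound=0.1):
    """Projected gradient ascent on a straight-line (open-loop) plan."""
    plan = [0.0] * horizon           # one trainable parameter per time step
    for _ in range(iters):
        grad = plan_gradient(plan, x0, target)
        # Ascent step, then clip back into the action bounds.
        plan = [min(bound, max(-bound, a + lr * g))
                for a, g in zip(plan, grad)]
    return plan


plan = optimize_plan(x0=0.0, target=0.5)
print([round(a, 3) for a in plan])
print(rollout_return(plan, 0.0, 0.5))
```

Because the plan is a fixed action sequence, it never looks at the state during execution. That is what makes it open loop, and it is a reasonable fit for a deterministic problem. This toy is independent of the Reacher result above.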
We can do better! Since JaxPlan can be sensitive to hyper-parameter choices, let’s repeat the previous exercise with tweaked hyper-parameters. Specifically, let’s decrease the learning rate, change the optimizer to Adam, and remove mini-batching, since this environment is deterministic. These settings are much easier to provide in a configuration file, as follows:
config = """
[Compiler]
[Planner]
method='JaxStraightLinePlan'
optimizer='adam'
optimizer_kwargs={'learning_rate': 0.0001}
batch_size_train=1
batch_size_test=1
pgpe=None
[Optimize]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=30, **train_args)
agent.evaluate(env, episodes=1)
[INFO] Compiler will cast pvars {'TARGET-SEGMENT', 'POS'} to float.
[INFO] Bounds of action-fluent <torque> set to (array([-0.1, -0.1], dtype=float32), array([0.1, 0.1], dtype=float32)).
34067 it | -6.78525 train | -6.78528 test | -6.78528 best | 5 status: 100%|██████████| 00:30 , 1135.57it/s
{'mean': np.float64(-6.785245208773351),
'median': np.float64(-6.785245208773351),
'min': np.float64(-6.785245208773351),
'max': np.float64(-6.785245208773351),
'std': np.float64(0.0)}
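The configuration string above is INI-style, with section headers and values written as Python literals. As a rough illustration of how such text can be turned into keyword arguments, the sketch below reads it with the standard library only. It is an assumption about the format, not a description of what `load_config_from_string` does internally, and `parse_config` is a hypothetical helper.

```python
import ast
import configparser

CONFIG = """
[Compiler]
[Planner]
method='JaxStraightLinePlan'
optimizer='adam'
optimizer_kwargs={'learning_rate': 0.0001}
batch_size_train=1
batch_size_test=1
pgpe=None
[Optimize]
key=42
"""


def parse_config(text):
    """Return {section: {option: value}} with values parsed as Python literals."""
    parser = configparser.ConfigParser(interpolation=None)
    parser.read_string(text)
    return {
        section: {key: ast.literal_eval(value)
                  for key, value in parser[section].items()}
        for section in parser.sections()
    }


sections = parse_config(CONFIG)
print(sections["Planner"]["optimizer_kwargs"])
```

Parsing values with `ast.literal_eval` keeps types intact (the learning rate stays a float inside a dict and `pgpe` becomes `None`) without executing arbitrary code from the file.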
Let’s visualize our trained agent’s behavior:
if not os.path.exists('frames'):
    os.makedirs('frames')
recorder = MovieGenerator("frames", "reacher", max_frames=env.horizon)
env.set_visualizer(viz=None, movie_gen=recorder)
agent.evaluate(env, episodes=1, render=True)
env.close()
Image(filename='frames/reacher_0.gif')
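The `evaluate` calls used throughout hide the interaction loop. Since the controller follows the pyRDDLGym policy interface, an episode can also be stepped manually. The sketch below assumes Gymnasium-style `reset` and `step` signatures on the environment and `reset` and `sample_action` methods on the policy. `rollout` is a hypothetical helper that is not part of either library, and it sums rewards without discounting.

```python
def rollout(env, policy, horizon):
    """Run one episode and return the undiscounted total reward."""
    policy.reset()                   # restart the plan at its first step
    state, _ = env.reset()
    total_reward = 0.0
    for _ in range(horizon):
        action = policy.sample_action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        if terminated or truncated:
            break
    return total_reward
```

With the objects from this notebook, a call would look like `rollout(env, agent, env.horizon)`, made on an environment that has not yet been closed.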