Open-loop planning with straight-line plans in JaxPlan

This basic example illustrates how to set up an offline JaxPlan agent that first optimizes a straight-line plan and then evaluates it in the environment.

First install and import the required packages:

%pip install --quiet --upgrade pip
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym.git
%pip install --quiet git+https://github.com/pyrddlgym-project/rddlrepository.git
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import os
from IPython.display import Image

import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import JaxStraightLinePlan, JaxBackpropPlanner, JaxOfflineController, load_config_from_string

We will optimize the Reacher multi-jointed arm control problem. Note that JaxPlan requires the environment's state and action fluents to be represented as vectorized tensors, so we must enable the vectorized option:

env = pyRDDLGym.make('Reacher_gym', '0', vectorized=True)
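
Optionally, we can inspect the environment to confirm the effect of vectorization: each fluent now appears as a single array-valued entry in the observation and action spaces rather than one scalar entry per object. This sanity check is only a sketch and assumes the standard Gymnasium-style space attributes exposed by pyRDDLGym:

# optional sanity check (assumes Gymnasium-style space attributes)
print(env.observation_space)   # vectorized state fluents
print(env.action_space)        # vectorized action fluents, e.g. <torque>
print(env.horizon)             # number of decision steps in the plan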

Now we will construct a straight-line offline planner as our policy, using the default parameters. This instantiates an independent trainable parameter vector for each action-fluent at each decision time step. Note that the controller is an instance of the pyRDDLGym BasePolicy, so it supports normal interaction with pyRDDLGym environments out of the box:

planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan())
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20)
agent.evaluate(env, episodes=1)
[INFO] JAX gradient compiler will cast p-vars {'TARGET-SEGMENT', 'POS'} to float.
[INFO] Bounds of action-fluent <torque> set to (array([-1., -1.], dtype=float32), array([1., 1.], dtype=float32)).
[WARN] policy_hyperparams is not set, setting 1.0 for all action-fluents which could be suboptimal.
      0 it /     -105.98688 train /     -242.29878 test /     -242.29878 best / 0 status /      1 pgpe:  10%|█         | 00:01 , 0.50it/s
[WARN] policy_hyperparams is not set, setting 1.0 for all action-fluents which could be suboptimal.
   2778 it /     -187.42670 train /     -200.35764 test /     -138.74423 best / 5 status /     76 pgpe: 100%|██████████| 00:19 , 138.94it/s


{'mean': np.float64(-138.60032065066403),
 'median': np.float64(-138.60032065066403),
 'min': np.float64(-138.60032065066403),
 'max': np.float64(-138.60032065066403),
 'std': np.float64(0.0)}
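
Because the controller implements the standard pyRDDLGym agent interface, we could also roll it out manually instead of calling evaluate(). The following is only a minimal sketch, assuming the usual sample_action()/reset() methods on the controller and the Gymnasium-style reset()/step() API of the environment:

# manual rollout sketch (assumes sample_action()/reset() on the controller
# and the Gymnasium-style reset()/step() API of pyRDDLGym)
state, _ = env.reset()
agent.reset()                        # rewind the plan to decision step 0
total_reward = 0.0
for _ in range(env.horizon):
    action = agent.sample_action(state)
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break
print('episode return:', total_reward)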

We can do better! Since JaxPlan can be sensitive to hyper-parameter choices, let's repeat the previous exercise with tweaked hyper-parameters. Specifically, let's decrease the learning rate, change the optimizer to Adam, and remove mini-batching since this environment is deterministic. These settings are easiest to provide through a configuration file, as follows:

config = """
[Model]
[Optimizer]
method='JaxStraightLinePlan'
optimizer='adam'
optimizer_kwargs={'learning_rate': 0.0001}
batch_size_train=1
batch_size_test=1
pgpe=None
[Training]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20, **train_args)
agent.evaluate(env, episodes=1)
[INFO] JAX gradient compiler will cast p-vars {'TARGET-SEGMENT', 'POS'} to float.
[INFO] Bounds of action-fluent <torque> set to (array([-1., -1.], dtype=float32), array([1., 1.], dtype=float32)).
[WARN] policy_hyperparams is not set, setting 1.0 for all action-fluents which could be suboptimal.
      0 it /     -184.75497 train /     -164.18906 test /     -164.18906 best / 0 status /      0 pgpe:   6%|▌         | 00:00 , 0.85it/s
[WARN] policy_hyperparams is not set, setting 1.0 for all action-fluents which could be suboptimal.
  11281 it /       -7.24862 train /       -7.31594 test /       -7.19512 best / 5 status /      0 pgpe: 100%|██████████| 00:20 , 564.08it/s

{'mean': np.float64(-7.195374767031867),
 'median': np.float64(-7.195374767031867),
 'min': np.float64(-7.195374767031867),
 'max': np.float64(-7.195374767031867),
 'std': np.float64(0.0)}
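
For reference, the same hyper-parameters could be passed directly to the planner's constructor instead of through a configuration string. The sketch below is only illustrative and assumes that JaxBackpropPlanner exposes optimizer, optimizer_kwargs and batch-size keyword arguments mirroring the [Optimizer] fields (check the JaxPlan documentation for the exact signature):

# illustrative sketch only: assumes these keyword arguments mirror the
# [Optimizer] fields of the configuration file
import optax

planner = JaxBackpropPlanner(
    rddl=env.model,
    plan=JaxStraightLinePlan(),
    optimizer=optax.adam,                        # Adam instead of the default
    optimizer_kwargs={'learning_rate': 0.0001},  # smaller learning rate
    batch_size_train=1,                          # no mini-batching, since the
    batch_size_test=1)                           # environment is deterministic
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20)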

Let’s visualize our trained agent’s behavior:

if not os.path.exists('frames'):
    os.makedirs('frames')
recorder = MovieGenerator("frames", "reacher", max_frames=env.horizon)
env.set_visualizer(viz=None, movie_gen=recorder)
agent.evaluate(env, episodes=1, render=True)
env.close()
Image(filename='frames/reacher_0.gif') 
(Rendered animation: frames/reacher_0.gif, showing the trained Reacher agent.)