Monitoring progress with callbacks in JaxPlan

In many applications, it is desirable to call an optimizer iteratively, monitoring its performance in real time and adjusting as needed. In this notebook, we illustrate how to do this with JaxPlan by showing how to monitor and plot the train and test return curves across iterations.

Start by installing the required packages:

%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import numpy as np
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
from IPython.display import Image

import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController, load_config_from_string

We will use the Quadcopter control problem as an example here:

env = pyRDDLGym.make('Quadcopter', '0', vectorized=True)

Specify a configuration string with the desired hyper-parameters for the planner as usual, load it, and instantiate the planner:

config = """
[Model]
[Optimizer]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': 0.06}
batch_size_train=1
batch_size_test=1
pgpe=None
[Training]
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
[INFO] JAX gradient compiler will cast p-vars {'ID'} to float.
[INFO] Bounds of action-fluent <power> set to (array([[-10000., -10000., -10000., -10000.]], dtype=float32), array([[10000., 10000., 10000., 10000.]], dtype=float32)).

In this example, we will not hand the optimization over to a controller agent right away, because we want to drive it ourselves. For this, the planner API provides the optimize_generator() function, which builds a generator that we can iterate to unroll the optimization one step at a time:

sequence_of_steps = planner.optimize_generator(epochs=8000, print_summary=False, **train_args)

This does not actually begin the optimization, because the generator is lazy: nothing runs until we advance it. Each call to next() performs one iteration and yields a dictionary of information summarizing that iteration. Let's advance the generator by one step and see what kind of information we can extract:

next(sequence_of_steps).keys()
[WARN] policy_hyperparams is not set, setting 1.0 for all action-fluents which could be suboptimal.
      0 it /   -19874.05078 train /   -19867.47266 test /   -19867.47266 best / 0 status /      0 pgpe:   5%| | 00:05 ,
[FAIL] Compiler encountered the following error(s) in the training model:
    Casting occurred that could result in loss of precision.
dict_keys(['status', 'iteration', 'train_return', 'test_return', 'best_return', 'pgpe_return', 'params', 'best_params', 'pgpe_params', 'last_iteration_improved', 'pgpe_improved', 'grad', 'best_grad', 'updates', 'elapsed_time', 'key', 'model_params', 'progress', 'train_log', 'error', 'fluents', 'invariant', 'precondition', 'reward', 'termination'])

As you can see, we can extract a great deal of information about the optimization, including its convergence status, return statistics, parameters of the plan, gradient information, and even rollouts of the fluents as JAX arrays.
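For example, we can peek at a few of these fields directly. The short sketch below only reads keys that appear in the dict_keys output above; note that each call to next() advances the optimization by one more iteration:

# peek at a few fields of the next yielded dictionary
# (every key used here appears in the dict_keys output above)
step = next(sequence_of_steps)
print('iteration:', step['iteration'])
print('status:', step['status'])
print('best return so far:', step['best_return'])
print('elapsed time:', step['elapsed_time'])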

We are interested in monitoring the train_return and test_return at each iteration, so let's exhaust the generator and cache these values as we go:

train_returns, test_returns = [], []
for iteration_result in sequence_of_steps:
    train_returns.append(float(iteration_result['train_return'].item()))
    test_returns.append(float(iteration_result['test_return'].item()))
   7999 it /    -6860.17676 train /    -6862.88281 test /    -6862.88281 best / 6 status /      0 pgpe: 100%|█| 01:29 ,
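Because we control the loop, we can also intervene while the optimization runs. The sketch below shows one possible way to stop early once the test return has stopped improving; the patience of 500 iterations is an arbitrary choice for illustration:

# a minimal early-stopping sketch: stop once the test return has not
# improved for `patience` consecutive iterations (500 is arbitrary)
patience, best_so_far, stale = 500, -np.inf, 0
for iteration_result in planner.optimize_generator(epochs=8000, print_summary=False, **train_args):
    test_return = float(iteration_result['test_return'].item())
    if test_return > best_so_far:
        best_so_far, stale = test_return, 0
    else:
        stale += 1
    if stale >= patience:
        break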


Finally, let’s plot these time series to assess convergence of the planner:

%matplotlib inline
plt.plot(range(len(train_returns)), np.clip(train_returns, -20000, np.inf), label='train')
plt.plot(range(len(test_returns)), np.clip(test_returns, -20000, np.inf), label='test')
plt.legend()
plt.show()
[Figure: train and test return curves across iterations]
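If the raw curves look noisy, a simple moving average can make the overall trend easier to read. The following sketch smooths both series before plotting; the window size of 100 is an arbitrary choice:

# smooth the clipped return curves with a moving average before plotting
window = 100
kernel = np.ones(window) / window
smooth_train = np.convolve(np.clip(train_returns, -20000, np.inf), kernel, mode='valid')
smooth_test = np.convolve(np.clip(test_returns, -20000, np.inf), kernel, mode='valid')
plt.plot(smooth_train, label='train (smoothed)')
plt.plot(smooth_test, label='test (smoothed)')
plt.legend()
plt.show()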

Both curves flatten out toward the end, indicating that the planner has converged to a good solution. Let's render the behavior of the plan to make sure:

if not os.path.exists('frames'):
    os.makedirs('frames')
recorder = MovieGenerator("frames", "quadcopter", max_frames=env.horizon)
env.set_visualizer(viz=None, movie_gen=recorder)

agent = JaxOfflineController(planner, params=iteration_result['best_params'], **train_args)
agent.evaluate(env, episodes=1, render=True)
env.close()
Image(filename='frames/quadcopter_0.gif') 
[Animation: rendered rollout of the optimized quadcopter plan]