Monitoring progress with callbacks in JaxPlan.

In many applications, it is desirable to call an optimizer iteratively, monitoring its performance in real time and adjusting as needed. In this notebook, we illustrate how to do this with JaxPlan by monitoring and plotting the training and test return curves across iterations.

Start by installing the required packages:

%pip install --quiet --upgrade pip
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym.git
%pip install --quiet git+https://github.com/pyrddlgym-project/rddlrepository.git
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
from IPython.display import Image

import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import JaxStraightLinePlan, JaxBackpropPlanner, JaxOfflineController, load_config_from_string

We will use the Quadcopter control problem as an example here:

env = pyRDDLGym.make('Quadcopter', '0', vectorized=True)

Define a configuration string with the desired hyper-parameters for the planner as usual, and instantiate the planner:

config = """
[Model]
[Optimizer]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': 0.08}
batch_size_train=1
batch_size_test=1
[Training]
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)

In this example, we will not instantiate a controller agent right away, because we will drive the optimization ourselves. To do this, we use the optimize_generator() function in the planner API, which builds a generator that we can iterate through to unroll each step of the optimization sequentially:

sequence_of_steps = planner.optimize_generator(epochs=8000, print_summary=False, **train_args)

This does not actually begin the optimization, because we have not yet iterated over the resulting generator. If we advance it one step and inspect the result it yields, we get a dictionary of information summarizing that iteration. Let's do this now for one iteration and see what kind of information we can extract:

next(sequence_of_steps).keys()
      0 it /  -19874.021484 train /  -19865.257812 test /  -19865.257812 best / 0 status:   5%|█▏                      | 5/100 [00:05<01:53,  1.19s/it]
dict_keys(['status', 'iteration', 'train_return', 'test_return', 'best_return', 'params', 'best_params', 'last_iteration_improved', 'grad', 'best_grad', 'updates', 'elapsed_time', 'key', 'model_params', 'progress', 'train_log', 'error', 'fluents', 'invariant', 'precondition', 'reward', 'termination'])

As you can see, we can extract a lot of information about the optimization, including convergence status, return information, parameters of the plan, gradient information, and even rollouts of the fluents as JAX arrays.
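
For instance, we can peek at a few of these fields for a single step. The sketch below uses only the keys printed above; note that it consumes one more iteration from the generator:

step = next(sequence_of_steps)
print('iteration:', step['iteration'])
print('train return:', float(step['train_return'].item()))
print('test return:', float(step['test_return'].item()))
print('best return so far:', float(step['best_return'].item()))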

We are interested in monitoring the train_return and test_return of each iteration, so let’s exhaust the iterator and cache the values at each iteration:

train_returns, test_returns = [], []
for iteration_result in sequence_of_steps:
    train_returns.append(float(iteration_result['train_return'].item()))
    test_returns.append(float(iteration_result['test_return'].item()))
   7999 it /   -5938.086426 train /   -5942.592285 test /   -5798.791016 best / 6 status: : 7999it [01:38, 81.01it/s]                                  
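
Because each step is yielded as soon as it completes, the same loop can also react to the optimization in real time. Below is a minimal sketch of early stopping over a fresh run when the test return stops improving; the patience window of 500 iterations is an arbitrary choice for illustration, not part of the planner API:

best_so_far, stale = -float('inf'), 0
for step in planner.optimize_generator(epochs=8000, print_summary=False, **train_args):
    test_return = float(step['test_return'].item())
    if test_return > best_so_far:
        best_so_far, stale = test_return, 0
    else:
        stale += 1
    if stale >= 500:  # stop if no improvement for 500 consecutive iterations
        break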

Finally, let’s plot these time series to assess convergence of the planner:

%matplotlib inline
import numpy as np

# clip very negative early train returns so the plot scale stays readable
plt.plot(range(len(train_returns)), np.clip(train_returns, -20000, np.inf), label='train')
plt.plot(range(len(test_returns)), test_returns, label='test')
plt.legend()
plt.show()
[Figure: train and test return curves across iterations]

As you can see, the planner has converged to a good solution. Let's visualize the behavior of the plan to make sure:

if not os.path.exists('frames'):
    os.makedirs('frames')

# record frames of the rollout so they can be compiled into a GIF
recorder = MovieGenerator("frames", "quadcopter", max_frames=env.horizon)
env.set_visualizer(viz=None, movie_gen=recorder)

# build an offline controller from the best parameters found and roll it out
agent = JaxOfflineController(planner, params=iteration_result['best_params'], **train_args)
agent.evaluate(env, episodes=1, render=True)
env.close()
Image(filename='frames/quadcopter_0.gif')
[Animation: quadcopter rollout under the optimized plan]