Tuning policy hyper-parameters in JaxPlan

In this advanced notebook, we illustrate how to tune hyper-parameters of the policy and planner efficiently using Bayesian optimization, such as the topology of the policy network, the learning rate, and the model relaxations.

Start by installing the required packages. Notice that we install pyRDDLGym-jax with the [extra] option so that the required bayesian-optimization package is included:

%pip install --quiet --upgrade pip
%pip install --quiet pandas pyRDDLGym rddlrepository pyRDDLGym-jax[extra]
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import os
import pandas as pd
from IPython.display import Image
from IPython.utils import io
import matplotlib.pyplot as plt
import math

import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy, JaxBackpropPlanner, JaxOfflineController, load_config_from_string
from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
C:\Python\envs\rddl\Lib\site-packages\pyRDDLGym\core\debug\exception.py:28: UserWarning: cv2 is not installed: save_as_mp4 option will be disabled.
  warnings.warn(message)

We will use the MountainCar control problem as our example:

env = pyRDDLGym.make('MountainCar_Continuous_gym', '0', vectorized=True)

In order to instruct the tuner to tune specific parameters, we need to create a config file with placeholders for the tunable parameters. In this case, we let TUNE_LR denote the learning rate, TUNE_L1 and TUNE_L2 denote the layer sizes of the policy network, and TUNE_WEIGHT denote the model relaxation parameter:

config = """
[Model]
comparison_kwargs={'weight': TUNE_WEIGHT}
rounding_kwargs={'weight': TUNE_WEIGHT}
control_kwargs={'weight': TUNE_WEIGHT}
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [TUNE_L1, TUNE_L2]}
optimizer_kwargs={'learning_rate': TUNE_LR}
batch_size_train=1
batch_size_test=1
[Training]
train_seconds=15
print_summary=False
print_progress=False
"""

Next, we need to connect each selected hyper-parameter to the range the Bayesian optimizer searches over, together with a transformation that maps the raw searched value to the value substituted into the config:

hyperparams = [
    Hyperparameter('TUNE_WEIGHT', 0., 15., math.exp2),
    Hyperparameter('TUNE_LR', -15., 0., math.exp2),
    Hyperparameter('TUNE_L1', 2, 64, math.floor),
    Hyperparameter('TUNE_L2', 2, 64, math.floor)
]    
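The last argument of each Hyperparameter is the transformation applied to the raw value suggested by the optimizer before it is substituted into the config: the weight and learning rate are searched in log2 space, and the layer sizes are rounded down to integers. A quick sanity check of the effective ranges (not part of the tuning API, just standard math calls):

print(math.exp2(0.), math.exp2(15.))   # TUNE_WEIGHT spans 1.0 ... 32768.0
print(math.exp2(-15.), math.exp2(0.))  # TUNE_LR spans ~3.1e-05 ... 1.0
print(math.floor(2), math.floor(64))   # TUNE_L1, TUNE_L2 span 2 ... 64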

Next, we instantiate the tuning instance for the deep reactive policy (DRP). This will tune the topology of the policy network, the learning rate, and the model relaxation weight. We specify one independent evaluation run per trial, 4 parallel workers, and 10 Bayesian optimization iterations, for a total of 40 hyper-parameter configurations tested:

tuning = JaxParameterTuning(env=env, config_template=config, hyperparams=hyperparams, online=False,
                            eval_trials=1, num_workers=4, gp_iters=10)

Finally, we launch the tuning instance, specifying the RNG key and the CSV file where the tuning log will be saved. Warning: this will take a while:

with io.capture_output():
    best_params = tuning.tune(key=42, log_file='mountaincar_tuning.csv')
print(best_params)
{'TUNE_WEIGHT': 10.104708040704434, 'TUNE_LR': 0.00035041832766298753, 'TUNE_L1': 61, 'TUNE_L2': 12}

The outputs of the Bayesian optimization have now been saved to a CSV file. Let's use pandas to read this file:

gp_data = pd.read_csv('mountaincar_tuning.csv')
gp_data.head(5)
pid worker iteration target best_target acq_params TUNE_WEIGHT TUNE_LR TUNE_L1 TUNE_L2
0 31252 0 0 -20.000000 91.871464 {'random_state': RandomState(MT19937) at 0x233... 504.908188 0.061636 25 60
1 19792 1 0 -0.000060 91.871464 {'random_state': RandomState(MT19937) at 0x233... 8150.326768 0.000056 11 11
2 30408 2 0 -0.000059 91.871464 {'random_state': RandomState(MT19937) at 0x233... 23965.107358 0.000038 39 45
3 28576 3 0 91.871464 91.871464 {'random_state': RandomState(MT19937) at 0x233... 6.732152 0.000202 53 15
4 13564 0 1 88.120588 91.871464 {'random_state': RandomState(MT19937) at 0x233... 7.676418 0.001000 62 11
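Each row records one trial: the worker process, the tuning iteration, the objective value achieved (target), the running best objective (best_target), and the transformed hyper-parameter values. For example, the single best trial can be pulled directly out of the log with pandas (using the column names shown above):

best_row = gp_data.loc[gp_data['target'].idxmax()]
print(best_row[['target', 'TUNE_WEIGHT', 'TUNE_LR', 'TUNE_L1', 'TUNE_L2']])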

Let’s plot the target values across iterations:

%matplotlib inline
plt.plot(gp_data['target'])
plt.show()
(Plot: objective value of each trial across tuning iterations.)
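It can also be informative to overlay the running best value from the best_target column, which shows how quickly the optimizer converges (plain pandas/matplotlib on the same log):

plt.plot(gp_data['target'], label='trial objective')
plt.plot(gp_data['best_target'], label='running best')
plt.xlabel('trial')
plt.ylabel('objective')
plt.legend()
plt.show()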

The config file corresponding to the best hyper-parameters is:

print(tuning.best_config)
[Model]
comparison_kwargs={'weight': 10.104708040704434}
rounding_kwargs={'weight': 10.104708040704434}
control_kwargs={'weight': 10.104708040704434}
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [61, 12]}
optimizer_kwargs={'learning_rate': 0.00035041832766298753}
batch_size_train=1
batch_size_test=1
[Training]
train_seconds=15
print_summary=False
print_progress=False
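Since tuning.best_config is just a config string, it can be written to disk and reloaded later with load_config_from_string, without rerunning the tuner (plain file I/O; the file name here is arbitrary):

with open('mountaincar_best.cfg', 'w') as f:
    f.write(tuning.best_config)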

Let’s load this config into a planner and evaluate it:

planner_args, _, train_args = load_config_from_string(tuning.best_config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, **train_args)

if not os.path.exists('frames'):
    os.makedirs('frames')
recorder = MovieGenerator("frames", "mountaincar", max_frames=env.horizon)
env.set_visualizer(viz=None, movie_gen=recorder)
agent.evaluate(env, episodes=1, render=True)
env.close()
Image(filename='frames/mountaincar_0.gif') 
(Animation: rollout of the tuned policy in the MountainCar environment.)
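A single rendered rollout only gives a qualitative picture. To obtain a quantitative estimate of performance, the controller can also be evaluated over several episodes without rendering (this assumes, as in pyRDDLGym, that evaluate returns summary statistics of the return; we recreate the environment since it was closed above):

env = pyRDDLGym.make('MountainCar_Continuous_gym', '0', vectorized=True)
stats = agent.evaluate(env, episodes=20)
print(stats)
env.close()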