Saving and loading trained policies in JaxPlan

In this notebook, we illustrate how to save and load trained JaxPlan policies.

Start by installing the required packages:

%pip install --quiet --upgrade pip
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym.git
%pip install --quiet git+https://github.com/pyrddlgym-project/rddlrepository.git
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import pickle

import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxDeepReactivePolicy, JaxBackpropPlanner, JaxOfflineController,
    load_config_from_string)

We will load the Wildfire example to illustrate the process:

env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)

Let’s now train a fresh policy network to solve this problem:

planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy(), optimizer_kwargs={'learning_rate': 0.01})
agent = JaxOfflineController(planner, print_summary=False, train_seconds=10)
agent.evaluate(env, episodes=100)
    282 it /    -223.691406 train /    -425.093750 test /    -343.765625 best / 0 status: : 282it [00:09, 30.72it/s]   
{'mean': -319.1,
 'median': -40.0,
 'min': -7525.0,
 'max': -35.0,
 'std': 1042.2375880767302}

To save the trained policy, we simply pickle the final parameters of the policy network:

with open('wildfire_drp.pickle', 'wb') as file:
    pickle.dump(agent.params, file)
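
Nothing stops us from bundling extra bookkeeping with the weights. The sketch below is purely illustrative (the checkpoint dictionary and file name are our own): it pickles the parameters together with the domain and instance names, so the matching environment can be rebuilt at load time:

# illustrative only: store the parameters alongside the domain/instance
# identifiers so the matching environment can be recreated when loading
checkpoint = {
    'domain': 'Wildfire_MDP_ippc2014',
    'instance': '1',
    'params': agent.params,
}
with open('wildfire_drp_checkpoint.pickle', 'wb') as file:
    pickle.dump(checkpoint, file)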

Now, let’s load the pickled parameters and pass them to a newly instantiated controller:

with open('wildfire_drp.pickle', 'rb') as file:
    params = pickle.load(file)
    
new_planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy())
new_agent = JaxOfflineController(new_planner, params=params, print_summary=False)
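
As a quick sanity check (not required by the workflow), we can verify that the unpickled parameters are numerically identical to those held by the trained agent by comparing the parameter pytree leaf by leaf:

import jax
import numpy as np

# compare the reloaded parameter pytree against the trained agent's
# parameters leaf by leaf; this should print True
identical = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda a, b: bool(np.allclose(a, b)),
                           agent.params, params))
print(identical)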

Note that in this case the controller performs no training: it uses the loaded parameters directly. Let’s evaluate the new agent to confirm that it performs comparably to the trained one:

new_agent.evaluate(env, episodes=100)
{'mean': -412.9,
 'median': -35.0,
 'min': -4050.0,
 'max': -35.0,
 'std': 1089.5086461336598}

Indeed, the performance is quite similar.
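
Finally, the reloaded controller can also be stepped through the environment manually rather than via evaluate. Below is a minimal rollout sketch, assuming the controller exposes pyRDDLGym's standard sample_action interface, that the environment follows the gymnasium reset/step conventions, and that env.horizon gives the episode length:

# roll out a single episode with the reloaded policy (illustrative sketch)
state, _ = env.reset()
total_reward = 0.0
for _ in range(env.horizon):
    action = new_agent.sample_action(state)
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break
print(total_reward)
env.close()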