Saving and loading trained policies in JaxPlan
In this notebook, we illustrate how to save and load trained JaxPlan policies.
Start by installing the required packages:
%pip install --quiet --upgrade pip
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym.git
%pip install --quiet git+https://github.com/pyrddlgym-project/rddlrepository.git
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import pickle
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy, JaxBackpropPlanner, JaxOfflineController, load_config_from_string
We will load the Wildfire example to illustrate the process:
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)
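Before training, it can help to sanity-check the environment object we just created. A quick, illustrative sketch using standard pyRDDLGym/gymnasium attributes (the exact contents vary by domain and instance):
# Illustrative sanity check of the newly created environment:
print(env.horizon)            # number of decision epochs per episode
print(env.observation_space)  # state (observation) space of the instance
print(env.action_space)       # action space of the instance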
Let’s now train a fresh policy network to solve this problem:
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy(), optimizer_kwargs={'learning_rate': 0.01})
agent = JaxOfflineController(planner, print_summary=False, train_seconds=10)
agent.evaluate(env, episodes=100)
282 it / -223.691406 train / -425.093750 test / -343.765625 best / 0 status: : 282it [00:09, 30.72it/s]
{'mean': -319.1,
'median': -40.0,
'min': -7525.0,
'max': -35.0,
'std': 1042.2375880767302}
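The object we are about to save, agent.params, is a nested container (JAX pytree) of policy-network weights. A quick way to inspect its structure before pickling (a sketch, assuming the leaves are arrays):
import jax

# Print the shape of every weight array in the trained policy (illustrative):
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, agent.params))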
To save the model, we will just pickle the final parameters of the policy network:
with open('wildfire_drp.pickle', 'wb') as file:
    pickle.dump(agent.params, file)
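If the policy will be reloaded later in a separate script, it can also help to store a little metadata alongside the weights, so that the matching environment can be recreated. The checkpoint dictionary below is purely illustrative (its keys are not part of any JaxPlan API):
# Illustrative checkpoint recording which domain/instance the policy was trained on:
checkpoint = {'domain': 'Wildfire_MDP_ippc2014', 'instance': '1', 'params': agent.params}
with open('wildfire_drp_checkpoint.pickle', 'wb') as file:
    pickle.dump(checkpoint, file)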
Now, let’s load the pickled parameters and pass them to a newly-instantiated controller:
with open('wildfire_drp.pickle', 'rb') as file:
    params = pickle.load(file)
new_planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy())
new_agent = JaxOfflineController(new_planner, params=params, print_summary=False)
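As a sanity check, we can confirm that the reloaded parameters are identical to the ones we saved (a sketch using standard JAX and NumPy utilities):
import numpy as np
import jax

# Every weight array loaded from disk should equal the corresponding trained array:
original = jax.tree_util.tree_leaves(agent.params)
reloaded = jax.tree_util.tree_leaves(params)
assert all(np.array_equal(a, b) for a, b in zip(original, reloaded))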
Note that in this case no training takes place: the controller simply uses the supplied parameters. Let’s evaluate the agent to check that it performs comparably to the original trained one:
new_agent.evaluate(env, episodes=100)
{'mean': -412.9,
'median': -35.0,
'min': -4050.0,
'max': -35.0,
'std': 1089.5086461336598}
Indeed, the performance statistics are comparable to those of the original agent; the remaining differences simply reflect the stochasticity of the domain across the 100 evaluation episodes.