Saving and loading trained policies in JaxPlan.

This notebook shows how to save a trained JaxPlan policy to disk and load it back later without retraining.

Start by installing the required packages:

%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy, JaxBackpropPlanner, JaxOfflineController

We use instance 1 of the Wildfire domain from IPPC 2014 to illustrate the process:

env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)

Next, we train a fresh policy network (a deep reactive policy) on this problem and save its weights by passing a save_path to the controller:

planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy(), 
                             pgpe=None, optimizer_kwargs={'learning_rate': 0.001})
agent = JaxOfflineController(planner, print_summary=False, train_seconds=30, save_path='wildfire_drp.pickle')
agent.evaluate(env, episodes=100)
[INFO] Compiler will cast pvars {'NEIGHBOR', 'TARGET', 'cut-out', 'out-of-fuel', "burning'", "out-of-fuel'", 'put-out', 'burning'} to float.
[INFO] Compiler will cast CPFs {"out-of-fuel'", "burning'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
 10 it |   -7721.44385 train |   -5475.62500 test |   -5475.62500 best | 0 status:  13%|█▎        | 00:02 , 2.90it/s
[FAIL] Training model error: Casting occurred that could result in loss of precision.
 2163 it |   -2049.79468 train |    -386.90625 test |    -194.15625 best | 5 status: 100%|██████████| 00:29 , 72.12it/s
{'mean': np.float64(-350.65),
 'median': np.float64(-40.0),
 'min': np.float64(-6760.0),
 'max': np.float64(-40.0),
 'std': np.float64(1248.6233729591966)}
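Passing save_path makes the controller write the trained weights to disk once training finishes. Assuming, as in recent releases of pyRDDLGym-jax, that the file is an ordinary Python pickle of the parameter pytree, the round trip amounts to the minimal sketch below. The layer names and values are hypothetical stand-ins for the policy network's JAX arrays, and plain lists replace the arrays so the sketch runs with the standard library alone.

```python
import os
import pickle
import tempfile

# Hypothetical parameter pytree: nested dicts of weights, standing in for
# the JAX arrays a deep reactive policy would actually hold.
params = {
    'linear_0': {'w': [[0.1, -0.2], [0.3, 0.4]], 'b': [0.0, 0.0]},
    'linear_1': {'w': [[0.5], [-0.6]], 'b': [0.1]},
}

def save_params(params, path):
    """Serialize the parameter pytree to a pickle file."""
    with open(path, 'wb') as file:
        pickle.dump(params, file)

def load_params(path):
    """Read a parameter pytree back from a pickle file."""
    with open(path, 'rb') as file:
        return pickle.load(file)

path = os.path.join(tempfile.mkdtemp(), 'wildfire_drp.pickle')
save_params(params, path)
restored = load_params(path)
print(restored == params)  # True: the round trip preserves every weight
```

Because unpickling can execute arbitrary code, only load parameter files that you created yourself or otherwise trust.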

To reuse the saved weights, create a new planner and controller, and set params to the path of the pickle file. The controller then loads the weights from that file:

new_planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy())
new_agent = JaxOfflineController(new_planner, print_summary=False, params='wildfire_drp.pickle')
[INFO] Compiler will cast pvars {'NEIGHBOR', 'TARGET', 'cut-out', 'out-of-fuel', "burning'", "out-of-fuel'", 'put-out', 'burning'} to float.
[INFO] Compiler will cast CPFs {"out-of-fuel'", "burning'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
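The controller therefore behaves differently depending on its arguments: when params is given, it skips optimization; otherwise it trains and, if save_path is set, writes the result. The function below is a hypothetical, simplified sketch of that dispatch, written under the assumption that a string params value is treated as the path of a pickle file. The names resolve_params and train_fn are illustrative only and are not part of the pyRDDLGym-jax API.

```python
import os
import pickle
import tempfile

def resolve_params(params=None, save_path=None, train_fn=None):
    """Hypothetical sketch: return (params, trained_flag), retraining only if needed."""
    if isinstance(params, str):
        # A string is treated as the path of a previously pickled parameter pytree.
        with open(params, 'rb') as file:
            return pickle.load(file), False
    if params is not None:
        # Parameters passed in directly are used as-is.
        return params, False
    # No parameters given: train from scratch, then optionally save the result.
    trained = train_fn()
    if save_path is not None:
        with open(save_path, 'wb') as file:
            pickle.dump(trained, file)
    return trained, True

path = os.path.join(tempfile.mkdtemp(), 'policy.pickle')
calls = []

def fake_train():
    # Stand-in for gradient-based optimization; records that it ran.
    calls.append(1)
    return {'w': [1.0, 2.0]}

first, trained_first = resolve_params(save_path=path, train_fn=fake_train)
second, trained_second = resolve_params(params=path, train_fn=fake_train)
print(trained_first, trained_second, len(calls))  # True False 1
```

Training runs only once: the second call reads the weights written by the first and never invokes the training function.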

Note that this controller does not train the policy; it uses the loaded weights directly. Let’s evaluate the agent to check that it performs about as well as the trained one:

new_agent.evaluate(env, episodes=100)
{'mean': np.float64(-338.7),
 'median': np.float64(-40.0),
 'min': np.float64(-7150.0),
 'max': np.float64(-40.0),
 'std': np.float64(1218.7261013041445)}

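Indeed, the performance is very close: the mean return is -338.7, compared with -350.65 for the trained agent, and the median and maximum are identical (-40.0). The remaining differences are most likely due to the randomness of the evaluation episodes.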
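One rough way to quantify "close" is to compare the gap between the two mean returns with its standard error, computed from the reported standard deviations over 100 episodes each. The sketch below treats the two evaluations as independent samples and uses a normal approximation. Given the heavy left tail visible in the minimum and median, it is only a sanity check, not a formal test.

```python
import math

n = 100  # episodes per evaluation
mean_trained, std_trained = -350.65, 1248.6233729591966
mean_loaded, std_loaded = -338.7, 1218.7261013041445

# Standard error of each mean, then of their difference (independent samples).
se_trained = std_trained / math.sqrt(n)
se_loaded = std_loaded / math.sqrt(n)
se_diff = math.sqrt(se_trained**2 + se_loaded**2)

diff = mean_loaded - mean_trained
z = diff / se_diff
print(f"difference = {diff:.2f}, standard error = {se_diff:.2f}, z = {z:.3f}")
```

The gap of about 12 is a small fraction of a standard error of roughly 174, consistent with both controllers running the same policy.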