Saving and loading trained policies in JaxPlan.
This notebook shows how to save and load trained JaxPlan policies.
Start by installing the required packages:
%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy, JaxBackpropPlanner, JaxOfflineController
We will load the Wildfire example to illustrate the process:
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)
Next, train a fresh policy network on this problem, and save the trained weights to disk by specifying a save_path:
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy(),
pgpe=None, optimizer_kwargs={'learning_rate': 0.001})
agent = JaxOfflineController(planner, print_summary=False, train_seconds=30, save_path='wildfire_drp.pickle')
agent.evaluate(env, episodes=100)
[INFO] Compiler will cast pvars {'NEIGHBOR', 'TARGET', 'cut-out', 'out-of-fuel', "burning'", "out-of-fuel'", 'put-out', 'burning'} to float.
[INFO] Compiler will cast CPFs {"out-of-fuel'", "burning'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
10 it | -7721.44385 train | -5475.62500 test | -5475.62500 best | 0 status: 13%|█▎ | 00:02 , 2.90it/s
[FAIL] Training model error: Casting occurred that could result in loss of precision.
2163 it | -2049.79468 train | -386.90625 test | -194.15625 best | 5 status: 100%|██████████| 00:29 , 72.12it/s
{'mean': np.float64(-350.65),
'median': np.float64(-40.0),
'min': np.float64(-6760.0),
'max': np.float64(-40.0),
'std': np.float64(1248.6233729591966)}
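Before loading the weights back through the controller, it can be useful to look at what a saved file contains. The sketch below assumes the file written via save_path is an ordinary Python pickle holding a nested container of weights; the .pickle extension suggests this, but the notebook does not confirm the layout. The helper describe_params and the stand-in dictionary with its layer names are hypothetical and exist only so the example runs on its own.

```python
import os
import pickle
import tempfile


def describe_params(obj, prefix=""):
    """Return 'path: type' strings for every leaf of a nested dict/list structure."""
    lines = []
    if isinstance(obj, dict):
        for key in obj:
            lines.extend(describe_params(obj[key], f"{prefix}/{key}"))
    elif isinstance(obj, (list, tuple)):
        for i, item in enumerate(obj):
            lines.extend(describe_params(item, f"{prefix}[{i}]"))
    else:
        lines.append(f"{prefix}: {type(obj).__name__}")
    return lines


# Stand-in for a saved policy file. The keys and values are hypothetical
# and do NOT reflect the real contents of wildfire_drp.pickle.
stand_in = {"layer1": {"w": [[0.1, 0.2]], "b": [0.0]}}
path = os.path.join(tempfile.mkdtemp(), "stand_in.pickle")
with open(path, "wb") as f:
    pickle.dump(stand_in, f)

# Read the file back and list its leaves.
with open(path, "rb") as f:
    restored = pickle.load(f)

summary = describe_params(restored)
print("\n".join(summary))
```

To try this on the real file, replace path with 'wildfire_drp.pickle' in the same environment that produced it, since unpickling needs the libraries whose objects were stored. Only unpickle files from a trusted source, because pickle can execute arbitrary code while loading.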
To load the weights automatically, create a controller with params pointing to the pickle file:
new_planner = JaxBackpropPlanner(rddl=env.model, plan=JaxDeepReactivePolicy())
new_agent = JaxOfflineController(new_planner, print_summary=False, params='wildfire_drp.pickle')
[INFO] Compiler will cast pvars {'NEIGHBOR', 'TARGET', 'cut-out', 'out-of-fuel', "burning'", "out-of-fuel'", 'put-out', 'burning'} to float.
[INFO] Compiler will cast CPFs {"out-of-fuel'", "burning'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
Note that in this case the policy is not trained again: the controller uses the weights loaded from the file. Let’s evaluate the new agent to check that it performs about as well as the one we just trained:
new_agent.evaluate(env, episodes=100)
{'mean': np.float64(-338.7),
'median': np.float64(-40.0),
'min': np.float64(-7150.0),
'max': np.float64(-40.0),
'std': np.float64(1218.7261013041445)}
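Indeed, the performance is quite similar: the mean return is -338.7 for the loaded agent versus -350.65 for the freshly trained one, and the median and maximum are -40.0 in both runs. The remaining gap is small relative to the standard deviation of roughly 1200, which is consistent with the two evaluations running different random episodes.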
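The similarity of the two evaluations can be checked with a rough back-of-the-envelope calculation. The sketch below takes the mean and std values printed by the two evaluate calls and compares the difference in means against its standard error. It assumes the two sets of 100 episodes are independent and treats the reported std as an estimate of the per-episode spread; it is a sanity check, not a formal test.

```python
import math

episodes = 100  # number of evaluation episodes used in both runs

# Summary statistics copied from the two evaluate() outputs in this notebook.
trained = {"mean": -350.65, "std": 1248.6233729591966}
loaded = {"mean": -338.7, "std": 1218.7261013041445}


def standard_error(stats, n):
    """Standard error of the mean for n episodes with the given std."""
    return stats["std"] / math.sqrt(n)


# Difference in mean return and its standard error (independence assumed).
diff = loaded["mean"] - trained["mean"]
se_diff = math.sqrt(
    standard_error(trained, episodes) ** 2 + standard_error(loaded, episodes) ** 2
)
z = diff / se_diff

print(f"difference in mean return: {diff:.2f}")
print(f"standard error of the difference: {se_diff:.2f}")
print(f"z-score: {z:.3f}")
```

A z-score well below 1 in magnitude means the gap between the two means is much smaller than the episode-to-episode noise, so the numbers give no indication that the loaded policy behaves differently from the trained one.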