Closed-loop planning with deep reactive policies in JaxPlan.
This basic example compares the performance of a closed-loop and an open-loop controller in JaxPlan on a stochastic domain. The closed-loop controller learns a policy network that takes the state of the system as input and produces values for the action-fluents as output.
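To make the distinction concrete, here is a minimal sketch on a toy one-dimensional system. It does not use JaxPlan, and the dynamics, reward, and both controllers are hypothetical, chosen only for illustration. The open-loop controller commits to a fixed action sequence indexed by time, while the closed-loop controller chooses each action from the state it observes.

```python
import numpy as np

rng = np.random.default_rng(0)

def step(state, action):
    """Toy stochastic dynamics (hypothetical): the action shifts the state,
    then zero-mean Gaussian noise is added."""
    return state + action + rng.normal(scale=1.0)

def rollout(controller, horizon=20):
    """Run one episode and return the total reward, where each step pays -state**2."""
    state, total = 0.0, 0.0
    for t in range(horizon):
        action = controller(t, state)
        state = step(state, action)
        total += -state ** 2
    return total

# Open-loop: a fixed plan indexed by time only. It cannot react to the noise.
plan = np.zeros(20)
def open_loop(t, state):
    return plan[t]

# Closed-loop: the action is a function of the observed state.
def closed_loop(t, state):
    return -state

open_returns = [rollout(open_loop) for _ in range(200)]
closed_returns = [rollout(closed_loop) for _ in range(200)]
print(np.mean(open_returns), np.mean(closed_returns))
```

In this toy setting the closed-loop controller cancels the accumulated noise at every step, so its returns stay close to zero, whereas under the fixed plan the state drifts like a random walk. The rest of the example looks at the same kind of comparison on a real domain.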
%pip install --quiet --upgrade pip
%pip install --quiet seaborn
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
import seaborn as sns
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController, load_config_from_string
We will optimize the stochastic Wildfire problem from IPPC 2014. Note again the use of the vectorized option:
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)
Let’s first generate a baseline with the straight-line planner, and then see whether we can improve upon it:
config = """
[Compiler]
sigmoid_weight=100.0
[Planner]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': 0.1}
pgpe=None
[Optimize]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20, **train_args)
slp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(200)]
[INFO] Compiler will cast pvars {"out-of-fuel'", 'TARGET', 'burning', 'NEIGHBOR', "burning'", 'out-of-fuel', 'cut-out', 'put-out'} to float.
[INFO] Compiler will cast CPFs {"out-of-fuel'", "burning'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
[INFO] Number of boolean actions 18 > cardinality 1: enabling projected gradient to satisfy constraints on action-fluents.
18 it | -9112.12500 train | -6684.57812 test | -6671.89062 best | 0 status: 18%|█▊ | 00:02 , 5.29it/s
[FAIL] Training model error: Casting occurred that could result in loss of precision.
2303 it | -2905.71069 train | -950.93750 test | -753.51562 best | 5 status: 100%|██████████| 00:19 , 115.18it/s
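The config string used above is INI-style text with three sections, and each value is written as a Python literal. The sketch below reads the same text with the standard library to show that structure. It is only an illustration and is not how `load_config_from_string` is implemented; the `parsed` dictionary is a name introduced here.

```python
import ast
import configparser

config = """
[Compiler]
sigmoid_weight=100.0
[Planner]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': 0.1}
pgpe=None
[Optimize]
key=42
"""

# Split only on '=' so the ':' inside the dict value is left alone,
# and disable interpolation so values are read exactly as written.
parser = configparser.ConfigParser(delimiters=('=',), interpolation=None)
parser.optionxform = str  # keep option names exactly as written
parser.read_string(config)

# Evaluate each value as a Python literal (string, number, dict, or None).
parsed = {
    section: {name: ast.literal_eval(value) for name, value in parser[section].items()}
    for section in parser.sections()
}
for section, options in parsed.items():
    print(section, options)
```

Reading the config this way makes it easy to see which settings differ between two runs before handing the string to the planner.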
Let’s now train the policy network. Compared with the baseline config, the method changes to JaxDeepReactivePolicy, method_kwargs sets the network topology, and the learning rate is smaller:
config = """
[Compiler]
sigmoid_weight=100.0
[Planner]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64]}
optimizer_kwargs={'learning_rate': 0.001}
pgpe=None
[Optimize]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20, **train_args)
drp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(200)]
[INFO] Compiler will cast pvars {"out-of-fuel'", 'TARGET', 'burning', 'NEIGHBOR', "burning'", 'out-of-fuel', 'cut-out', 'put-out'} to float.
[INFO] Compiler will cast CPFs {"out-of-fuel'", "burning'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
8 it | -8476.03711 train | -5925.85059 test | -5925.85059 best | 0 status: 18%|█▊ | 00:02 , 2.46it/s
[FAIL] Training model error: Casting occurred that could result in loss of precision.
1104 it | -1888.28052 train | -570.95312 test | -200.17188 best | 5 status: 100%|██████████| 00:19 , 55.22it/s
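To give a feel for what a setting like `'topology': [128, 64]` describes, here is a minimal NumPy sketch of a fully connected network with two hidden layers of 128 and 64 units that maps a state vector to one score per action-fluent. This is an assumption-laden illustration, not JaxPlan's network: the tanh activations, the sigmoid output, the initialization, and the input size of 18 are all choices made here (the output size of 18 matches the number of boolean actions reported in the log above).

```python
import numpy as np

def init_mlp(rng, n_state, topology, n_action):
    """Create a (weights, bias) pair for each layer of a fully connected network."""
    sizes = [n_state, *topology, n_action]
    return [
        (rng.normal(scale=np.sqrt(1.0 / n_in), size=(n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]

def policy(params, state):
    """Map a state vector to one value in (0, 1) per action-fluent."""
    hidden = state
    for weights, bias in params[:-1]:
        hidden = np.tanh(hidden @ weights + bias)  # hidden layers (assumed tanh)
    weights, bias = params[-1]
    logits = hidden @ weights + bias
    return 1.0 / (1.0 + np.exp(-logits))  # sigmoid output (assumed)

rng = np.random.default_rng(42)
# Illustrative sizes: 18 state entries is an assumption, not read from the domain.
params = init_mlp(rng, n_state=18, topology=[128, 64], n_action=18)
state = rng.integers(0, 2, size=18).astype(float)
action_scores = policy(params, state)
print(action_scores.shape)  # (18,)
```

Because the action is recomputed from the current state at every step, the same parameters can respond differently to different random outcomes, which a fixed action sequence cannot do.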
Finally, let’s draw two violin plots comparing the distribution of returns for the two plans:
%matplotlib inline
ax = sns.violinplot(data=[slp_returns, drp_returns], orient='h')
ax.set_yticklabels(['SLP', 'DRP'])
plt.show()
As you can see, the returns of the reactive policy network are less spread out, and the policy has a much higher probability of achieving the optimal return of zero than the straight-line planner.
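The visual comparison can be backed by a few summary numbers. The helper below is a small sketch using only the standard library; `summarize` and `example_returns` are names introduced here, and the example values are placeholders rather than results from the runs above.

```python
import statistics

def summarize(returns, best=0.0):
    """Return the mean, standard deviation, and share of episodes that reach `best`."""
    return {
        'mean': statistics.fmean(returns),
        'std': statistics.pstdev(returns),
        'share_best': sum(r == best for r in returns) / len(returns),
    }

# Placeholder numbers for illustration only; in the notebook, pass
# slp_returns and drp_returns instead.
example_returns = [0.0, 0.0, -5.0, -25.0]
summary = summarize(example_returns)
print(summary)
```

Calling `summarize(slp_returns)` and `summarize(drp_returns)` in the notebook would give the spread and the share of zero-return episodes for each controller. The exact comparison `r == best` assumes that an optimal episode returns exactly zero.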