Closed-loop planning with deep reactive policies in JaxPlan.

This rudimentary example compares the performance of a closed-loop controller and an open-loop controller in JaxPlan on a stochastic domain. The closed-loop controller learns a policy network that takes the state of the system as input and produces the action-fluent values as output.

%pip install --quiet --upgrade pip
%pip install --quiet seaborn
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
import seaborn as sns

import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController, load_config_from_string

We will optimize the stochastic Wildfire problem from IPPC 2014. Note again the use of the vectorized option:

env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)

Let’s generate a baseline using the straight-line planner, and see if we can improve upon it:

config = """
[Compiler]
sigmoid_weight=100.0
[Planner]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': 0.1}
pgpe=None
[Optimize]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20, **train_args)
slp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(200)]
[INFO] Compiler will cast pvars {"out-of-fuel'", 'TARGET', 'burning', 'NEIGHBOR', "burning'", 'out-of-fuel', 'cut-out', 'put-out'} to float.
[INFO] Compiler will cast CPFs {"out-of-fuel'", "burning'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
[INFO] Number of boolean actions 18 > cardinality 1: enabling projected gradient to satisfy constraints on action-fluents.
 18 it |   -9112.12500 train |   -6684.57812 test |   -6671.89062 best | 0 status:  18%|█▊        | 00:02 , 5.29it/s
[FAIL] Training model error: Casting occurred that could result in loss of precision.
 2303 it |   -2905.71069 train |    -950.93750 test |    -753.51562 best | 5 status: 100%|██████████| 00:19 , 115.18it/s
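
The straight-line plan trained above is open-loop: it commits to a sequence of actions indexed by time and cannot react to random events during an episode. The following toy sketch illustrates why that can hurt on a stochastic problem. It does not use the Wildfire domain or the JaxPlan API: the single-cell fire model, its costs, the ignition probability, and all function names are hypothetical and chosen only for illustration.

```python
import random

HORIZON = 20

def step(burning, put_out, rng, ignite_prob=0.3):
    """One transition of a toy single-cell fire model (hypothetical, not Wildfire)."""
    # Burning costs 10 per step, applying the put-out action costs 1.
    reward = -10.0 * burning - 1.0 * put_out
    if burning and put_out:
        burning = False                       # the fire is extinguished
    elif not burning:
        burning = rng.random() < ignite_prob  # a new fire may start at random
    return burning, reward

def rollout(controller, seed):
    """Total reward of one episode; controller(t, burning) returns True to put out."""
    rng = random.Random(seed)
    burning, total = False, 0.0
    for t in range(HORIZON):
        burning, reward = step(burning, controller(t, burning), rng)
        total += reward
    return total

def mean_return(controller, episodes=1000):
    """Average total reward over episodes with seeds 0 .. episodes - 1."""
    return sum(rollout(controller, seed) for seed in range(episodes)) / episodes

# Open-loop plans: the action depends only on the time step t.
never_plan = [False] * HORIZON
always_plan = [True] * HORIZON
open_never = mean_return(lambda t, burning: never_plan[t])
open_always = mean_return(lambda t, burning: always_plan[t])

# Closed-loop policy: the action depends on the observed state.
closed = mean_return(lambda t, burning: burning)

print(open_never, open_always, closed)
```

In this toy model the closed-loop rule pays for the action only when a fire is actually burning, whereas the two open-loop plans shown either pay for it at every step or never extinguish a fire at all.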

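Let’s now generate the policy network. Note the differences in the config file arguments: the method is now JaxDeepReactivePolicy, method_kwargs sets the network topology, and the learning rate is lowered from 0.1 to 0.001: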

config = """
[Compiler]
sigmoid_weight=100.0
[Planner]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64]}
optimizer_kwargs={'learning_rate': 0.001}
pgpe=None
[Optimize]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20, **train_args)
drp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(200)]
[INFO] Compiler will cast pvars {"out-of-fuel'", 'TARGET', 'burning', 'NEIGHBOR', "burning'", 'out-of-fuel', 'cut-out', 'put-out'} to float.
[INFO] Compiler will cast CPFs {"out-of-fuel'", "burning'"} to float.
[INFO] Bounds of action-fluent <put-out> set to (None, None).
[INFO] Bounds of action-fluent <cut-out> set to (None, None).
 8 it |   -8476.03711 train |   -5925.85059 test |   -5925.85059 best | 0 status:  18%|█▊        | 00:02 , 2.46it/s
[FAIL] Training model error: Casting occurred that could result in loss of precision.
 1104 it |   -1888.28052 train |    -570.95312 test |    -200.17188 best | 5 status: 100%|██████████| 00:19 , 55.22it/s
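
To make the `topology` setting above more concrete, here is a minimal, dependency-free sketch of a feed-forward policy network, assuming that `[128, 64]` denotes the widths of two hidden layers. This is not JaxPlan's implementation: the function names, the tanh and sigmoid activations, the weight initialization, and the 18-dimensional state vector are assumptions made for illustration. The 18 outputs match the number of boolean actions reported in the training log.

```python
import math
import random

def init_mlp(sizes, seed=0):
    """Create one (weights, biases) pair per layer with small random weights."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        scale = 1.0 / math.sqrt(n_in)
        weights = [[rng.uniform(-scale, scale) for _ in range(n_in)]
                   for _ in range(n_out)]
        biases = [0.0] * n_out
        layers.append((weights, biases))
    return layers

def policy(layers, state):
    """Map a state vector to one value in (0, 1) per boolean action."""
    x = list(state)
    for i, (weights, biases) in enumerate(layers):
        # Affine map: one dot product per output unit, plus its bias.
        x = [sum(w * v for w, v in zip(row, x)) + b
             for row, b in zip(weights, biases)]
        if i < len(layers) - 1:
            x = [math.tanh(v) for v in x]                # hidden activation (assumed)
        else:
            x = [1.0 / (1.0 + math.exp(-v)) for v in x]  # squash outputs to (0, 1)
    return x

STATE_DIM, ACTION_DIM = 18, 18   # assumed sizes, for illustration only
layers = init_mlp([STATE_DIM, 128, 64, ACTION_DIM])

state = [0.0] * STATE_DIM
state[4] = 1.0                   # hypothetical state: a single fluent is true
action_values = policy(layers, state)
num_params = sum(len(w) * len(w[0]) + len(b) for w, b in layers)
```

Because the action values are recomputed from the current state at every decision step, the same parameters can produce different actions in different states, which is what makes the controller closed-loop.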

Finally, let’s draw violin plots comparing the distribution of returns for the two planners:

%matplotlib inline
ax = sns.violinplot(data=[slp_returns, drp_returns], orient='h')
ax.set_yticklabels(['SLP', 'DRP'])
plt.show()
[Figure: violin plots of the return distributions for SLP and DRP]

As you can see, the reactive policy network has a lower spread in return and a much higher probability of achieving the optimal return of zero than the straight-line planner.
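
The visual comparison can be backed up with a few summary statistics. The helper below is a small sketch, not part of pyRDDLGym or JaxPlan: the name `summarize` and the example lists are hypothetical placeholders, and in the notebook you would pass `slp_returns` and `drp_returns` instead.

```python
import statistics

def summarize(returns, optimum=0.0):
    """Mean, spread, and fraction of episodes that reach the optimal return."""
    return {
        'mean': statistics.fmean(returns),
        'std': statistics.pstdev(returns),
        # Exact comparison; a tolerance may be needed for non-integer rewards.
        'frac_optimal': sum(r == optimum for r in returns) / len(returns),
    }

# Placeholder data for illustration only (not results from this notebook).
example_a = [0.0, -100.0, -400.0, -1500.0]
example_b = [0.0, 0.0, -50.0, -150.0]
summary_a = summarize(example_a)
summary_b = summarize(example_b)
print(summary_a)
print(summary_b)
```

A smaller `std` together with a larger `frac_optimal` for the policy network would correspond to the narrower violin concentrated near zero.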