Closed-loop planning with deep reactive policies in JaxPlan

This rudimentary example compares the performance of a closed-loop and an open-loop controller in JaxPlan on a stochastic domain. The closed-loop controller learns a policy network that takes the current state of the system as input and produces action-fluents as output, whereas the open-loop controller optimizes a fixed sequence of actions in advance.
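To make the distinction concrete, here is a purely illustrative sketch in plain Python (the functions open_loop_controller and closed_loop_controller and the stand-in plan and policy are hypothetical, not part of the JaxPlan API):

# Purely illustrative sketch of the two controller types -- not the JaxPlan API.
def open_loop_controller(plan):
    # A straight-line plan: a fixed list of actions chosen before execution.
    def act(step, state):
        return plan[step]          # ignores the observed state
    return act

def closed_loop_controller(policy_network):
    # A deep reactive policy: actions are recomputed from the observed state.
    def act(step, state):
        return policy_network(state)
    return act

# Example with stand-in objects (a trained network would replace the lambda):
slp = open_loop_controller(plan=[0, 1, 0])
drp = closed_loop_controller(policy_network=lambda state: int(state > 0.5))
print(slp(0, state=0.9), drp(0, state=0.9))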

%pip install --quiet --upgrade pip
%pip install --quiet seaborn
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym.git
%pip install --quiet git+https://github.com/pyrddlgym-project/rddlrepository.git
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
import seaborn as sns

import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import JaxStraightLinePlan, JaxDeepReactivePolicy, JaxBackpropPlanner, JaxOfflineController, load_config_from_string

We will optimize the stochastic Wildfire problem from IPPC 2014. Note again the use of the vectorized option:

env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)
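Before planning, it can be useful to check what the vectorized environment exposes (a quick inspection sketch; these are the standard Gym-style attributes of pyRDDLGym environments):

# Optional sanity check of the environment handed to the planner.
print(env.observation_space)   # one entry per state-fluent, with vectorized shapes
print(env.action_space)        # one entry per action-fluent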

Let’s generate a baseline using the straight-line planner, and see if we can improve upon it:

config = """
[Model]
comparison_kwargs={'weight': 100}
rounding_kwargs={'weight': 100}
control_kwargs={'weight': 100}
[Optimizer]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': 0.1}
[Training]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20, **train_args)
slp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(100)]
   1605 it /    -642.354492 train /    -749.515625 test /    -527.468750 best / 5 status: : 1605it [00:19, 80.88it/s]                                  
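Each call to evaluate above performs a single stochastic rollout, so it helps to summarize the baseline distribution numerically (a quick sketch using numpy, which is installed as a dependency of the packages above):

import numpy as np
print(f'SLP baseline over {len(slp_returns)} rollouts: '
      f'mean {np.mean(slp_returns):.1f}, std {np.std(slp_returns):.1f}')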

Let’s now generate the policy network. Note the slight difference in the config file arguments:

config = """
[Model]
comparison_kwargs={'weight': 100}
rounding_kwargs={'weight': 100}
control_kwargs={'weight': 100}
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64]}
optimizer_kwargs={'learning_rate': 0.001}
[Training]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20, **train_args)
drp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(100)]
    915 it /    -561.330505 train /    -467.140625 test /    -291.250000 best / 5 status: : 915it [00:19, 46.25it/s]                                   
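The same summary for the policy network, side by side with the baseline (exact numbers will differ from run to run, since both training and evaluation are stochastic):

for name, returns in [('SLP', slp_returns), ('DRP', drp_returns)]:
    print(f'{name}: mean {np.mean(returns):.1f}, std {np.std(returns):.1f}, '
          f'best {np.max(returns):.1f}')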

Finally, let’s draw violin plots comparing the distributions of returns for the two plans:

%matplotlib inline
ax = sns.violinplot(data=[slp_returns, drp_returns], orient='h')
ax.set_yticklabels(['SLP', 'DRP'])
plt.show()
[Figure: horizontal violin plots of the return distributions for the SLP and DRP controllers]
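If you prefer an overlapping-histogram view of the same data, here is a minimal seaborn sketch (the bin count and colors are arbitrary choices):

plt.figure()
sns.histplot(slp_returns, bins=20, color='C0', alpha=0.5, label='SLP')
sns.histplot(drp_returns, bins=20, color='C1', alpha=0.5, label='DRP')
plt.xlabel('return')
plt.legend()
plt.show()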

As the violin plots show, the reactive policy network has a lower spread in returns and a much higher probability of achieving the optimal return of zero than the straight-line planner.