Risk-aware planning with RAPTOR in JaxPlan.


This variation of the closed-loop planning notebook optimizes a nonlinear, risk-aware utility of the return instead of the expected return.
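To make "nonlinear utility" concrete, the sketch below compares the risk-neutral objective (the sample mean of the returns) with one nonlinear alternative on a plain list of sampled returns. The entropic (exponential) utility used here is an assumed example chosen for illustration. This notebook itself only uses CVaR, and JaxPlan's internal formulas are not shown here.

```python
import math
import statistics


def risk_neutral_objective(returns):
    """Risk-neutral objective: the sample mean of the returns."""
    return statistics.fmean(returns)


def entropic_objective(returns, beta):
    """Entropic utility (1 / beta) * log E[exp(beta * R)] (illustrative example).

    For beta < 0 the value is at most the mean, so low returns are penalized
    more heavily; as beta approaches 0 it approaches the risk-neutral mean.
    """
    if beta == 0:
        raise ValueError("beta must be nonzero")
    scaled = [beta * r for r in returns]
    # log-mean-exp, shifted by the largest term for numerical stability
    m = max(scaled)
    log_mean_exp = m + math.log(sum(math.exp(s - m) for s in scaled) / len(scaled))
    return log_mean_exp / beta
```

With a risk-averse setting such as `beta=-1.0`, a spread-out list of returns scores below its mean, while a constant list scores exactly its mean. A nonlinear utility therefore prefers policies whose returns are less spread out.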

%pip install --quiet --upgrade pip
%pip install --quiet seaborn
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
import seaborn as sns

import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController, load_config_from_string

Let’s optimize the continuous power generation domain from IPPC 2023:

env = pyRDDLGym.make('PowerGen_Continuous', '0', vectorized=True)

First, let’s train a risk-neutral baseline with a deep reactive policy (a policy network), which maximizes the expected return:

config = """
[Compiler]
[Planner]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128]}
optimizer_kwargs={'learning_rate': 0.001}
batch_size_train=512
batch_size_test=512
pgpe=None
[Optimize]
key=42
epochs=1000
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, **train_args)
drp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(200)]
[INFO] Bounds of action-fluent <curProd> set to (array([0., 0., 0.], dtype=float32), array([10., 10., 10.], dtype=float32)).
 999 it |  -10032.40137 train |   -1192.09583 test |     -27.26541 best | 6 status: 100%|██████████| 00:31 , 31.10it/s
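The config string above is INI-style: it has `[Compiler]`, `[Planner]` and `[Optimize]` sections, and its values are written as Python literals. The sketch below reads the same text into nested dictionaries using only the standard library, which shows that structure. `parse_config` is a hypothetical helper for illustration only. It is not how `load_config_from_string` is implemented.

```python
import ast
import configparser

CONFIG = """
[Compiler]
[Planner]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128]}
optimizer_kwargs={'learning_rate': 0.001}
batch_size_train=512
batch_size_test=512
pgpe=None
[Optimize]
key=42
epochs=1000
"""


def parse_config(text):
    """Hypothetical helper: read INI-style text whose values are Python literals."""
    parser = configparser.ConfigParser(interpolation=None)
    parser.optionxform = str  # keep key names exactly as written
    parser.read_string(text)
    # Evaluate each raw string value as a Python literal (str, int, dict, None, ...)
    return {
        section: {key: ast.literal_eval(value) for key, value in parser[section].items()}
        for section in parser.sections()
    }


sections = parse_config(CONFIG)
print(sections['Planner']['method_kwargs'])  # {'topology': [128]}
```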

Next, let’s repeat the experiment, but this time optimize the conditional value at risk (CVaR) of the return, that is, the average of the lowest 15 percent of returns (`alpha=0.15`). This should produce a policy that is more robust against power shortages:
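For reference, here is the standard definition of the conditional value at risk (CVaR) at level $\alpha$ for a return $R$ where higher is better. The first form holds for continuous distributions. The second (Rockafellar–Uryasev) form holds in general. Whether JaxPlan uses exactly this form internally is an assumption not confirmed by this notebook.

```latex
\begin{aligned}
\mathrm{VaR}_\alpha(R) &= \inf\{\, r : P(R \le r) \ge \alpha \,\}, \\
\mathrm{CVaR}_\alpha(R) &= \mathbb{E}\left[\, R \mid R \le \mathrm{VaR}_\alpha(R) \,\right]
  = \sup_{t \in \mathbb{R}} \left\{ t - \frac{1}{\alpha}\, \mathbb{E}\left[ (t - R)^{+} \right] \right\}.
\end{aligned}
```

With $\alpha = 0.15$, maximizing $\mathrm{CVaR}_\alpha$ means maximizing the mean of the worst 15 percent of returns rather than the mean of all returns.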

config = """
[Compiler]
[Planner]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128]}
optimizer_kwargs={'learning_rate': 0.001}
utility='cvar'
utility_kwargs={'alpha': 0.15}
batch_size_train=512
batch_size_test=512
pgpe=None
[Optimize]
key=42
epochs=1000
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, **train_args)
risk_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(200)]
[INFO] Bounds of action-fluent <curProd> set to (array([0., 0., 0.], dtype=float32), array([10., 10., 10.], dtype=float32)).
 999 it |    -986.52704 train |   -1658.43726 test |    -687.04211 best | 6 status: 100%|██████████| 00:32 , 30.14it/s
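On a finite batch of sampled returns, CVaR can be estimated by sorting the returns and averaging the lowest fraction. The sketch below is a simple plug-in estimator, assumed here for illustration. It is not taken from JaxPlan's source, which may handle the tail boundary differently (for example, by interpolating).

```python
import math
import statistics


def empirical_cvar(returns, alpha):
    """Average of the worst ceil(alpha * n) sampled returns (lower tail).

    Illustrative plug-in estimator; not JaxPlan's implementation.
    """
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must be in (0, 1]")
    ordered = sorted(returns)
    k = max(1, math.ceil(alpha * len(ordered)))  # number of tail samples kept
    return statistics.fmean(ordered[:k])
```

If the training objective were estimated this way with `batch_size_train=512` and `alpha=0.15`, each gradient step would average the 77 lowest of the 512 sampled returns. The other samples would not contribute to the objective.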

Finally, let’s compare the distributions of returns of the two policies with violin plots:

%matplotlib inline
ax = sns.violinplot(data=[drp_returns, risk_returns], orient='h')
ax.set_yticklabels(['Risk-Neutral', 'Risk-Averse'])
plt.show()
[Figure: violin plots of the returns of the risk-neutral and risk-averse policies]

As the plot shows, the returns are more stable (less spread out) under the CVaR utility objective than under the risk-neutral objective.
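The violin plot compares the two distributions visually. To put numbers on the comparison, one could summarize each list of returns. Below is a minimal sketch using only the standard library. `summarize` is a hypothetical helper, and the choice of statistics (spread and tail percentiles) is an assumption about what "more stable" should mean.

```python
import statistics


def summarize(returns):
    """Hypothetical helper: numeric summary of a list of episode returns."""
    # 19 cut points at 5%, 10%, ..., 95%
    q = statistics.quantiles(returns, n=20, method='inclusive')
    return {
        'mean': statistics.fmean(returns),
        'stdev': statistics.stdev(returns),
        'p5': q[0],
        'median': statistics.median(returns),
        'p95': q[-1],
    }

# In the notebook, after both lists have been filled:
# for name, returns in [('Risk-Neutral', drp_returns), ('Risk-Averse', risk_returns)]:
#     print(name, summarize(returns))
```

A smaller `stdev` and a higher `p5` for the risk-averse policy would be the numeric counterpart of the narrower lower tail in the plot.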