Risk-aware planning with RAPTOR in JaxPlan.


This variation of the closed-loop planning notebook optimizes a nonlinear, risk-aware utility of the return instead of the expected return.
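To make the distinction concrete, here is one way to state it, with notation introduced only for illustration (θ denotes the policy parameters and R(θ) the random total return of an episode). A risk-neutral planner maximizes the expected return. A risk-aware planner instead maximizes a nonlinear utility U of the return distribution, for example the conditional value at risk at level α that is used later in this notebook:

```latex
\text{risk-neutral:}\quad \max_{\theta}\ \mathbb{E}\big[R(\theta)\big]
\qquad
\text{risk-aware:}\quad \max_{\theta}\ U\big(R(\theta)\big),
\qquad
U(R) = \mathrm{CVaR}_{\alpha}(R) = \mathbb{E}\big[R \,\big|\, R \le \mathrm{VaR}_{\alpha}(R)\big]
```

Here VaR_α(R) is the α-quantile of R. The conditional-expectation form of CVaR shown above holds for continuous return distributions.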

%pip install --quiet --upgrade pip
%pip install --quiet seaborn
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
import seaborn as sns

import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController, load_config_from_string

Let’s optimize the continuous power generation domain (PowerGen_Continuous) from the 2023 International Probabilistic Planning Competition (IPPC 2023):

env = pyRDDLGym.make('PowerGen_Continuous', '0', vectorized=True)

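First, let’s train a risk-neutral baseline: a deep reactive policy (a policy network with two hidden layers of 64 units, as set by `'topology': [64, 64]`) that maximizes the expected return. We then evaluate it on 200 independent episodes: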

config = """
[Compiler]
[Planner]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [64, 64]}
optimizer_kwargs={'learning_rate': 0.0002}
batch_size_train=128
pgpe=None
[Optimize]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=40, **train_args)
drp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(200)]
[INFO] Bounds of action-fluent <curProd> set to (array([0., 0., 0.], dtype=float32), array([10., 10., 10.], dtype=float32)).
 3117 it |     269.57965 train |     179.82907 test |     277.70148 best | 5 status: 100%|██████████| 00:39 , 77.94it/s
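For intuition about what `method_kwargs={'topology': [64, 64]}` describes, here is a minimal NumPy sketch of a policy network with two hidden layers of 64 units that maps a state vector to actions inside the bounds reported in the log above ([0, 10] for each of the three `curProd` entries). The helper names `init_policy` and `policy_action`, the tanh activation, the sigmoid rescaling into the bounds and the 5-dimensional dummy state are all assumptions for illustration; JaxPlan's own network may be built differently.

```python
import numpy as np

def init_policy(seed, sizes):
    """Hypothetical helper: random weights for a fully connected net with layer widths `sizes`."""
    rng = np.random.default_rng(seed)
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        weights = rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out))
        params.append((weights, np.zeros(n_out)))
    return params

def policy_action(params, state, lower, upper):
    """Hypothetical forward pass: tanh hidden layers, then a sigmoid rescaled to [lower, upper]."""
    h = np.asarray(state, dtype=float)
    for weights, bias in params[:-1]:
        h = np.tanh(h @ weights + bias)
    weights, bias = params[-1]
    logits = h @ weights + bias
    return lower + (upper - lower) / (1.0 + np.exp(-logits))

# assumed sizes: 5 dummy state features -> 64 -> 64 -> 3 actions (one per curProd entry)
params = init_policy(42, [5, 64, 64, 3])
action = policy_action(params, np.ones(5), lower=np.zeros(3), upper=np.full(3, 10.0))
print(action)
```

Because the sigmoid output lies strictly between 0 and 1, every action stays inside its bounds by construction, which is one simple way to respect box constraints without clipping.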

Next, let’s repeat the experiment, but this time optimize the conditional value at risk (CVaR) of the return with `utility='cvar'` and `alpha=0.1`, that is, the expected return over the worst 10 percent of outcomes. This should produce a policy that is more robust against power shortages:

config = """
[Compiler]
[Planner]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [64, 64]}
optimizer_kwargs={'learning_rate': 0.0002}
utility='cvar'
utility_kwargs={'alpha': 0.1}
batch_size_train=128
pgpe=None
[Optimize]
key=42
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=40, **train_args)
risk_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(200)]
[INFO] Bounds of action-fluent <curProd> set to (array([0., 0., 0.], dtype=float32), array([10., 10., 10.], dtype=float32)).
 3172 it |     164.15268 train |     148.94827 test |     148.94827 best | 5 status: 100%|██████████| 00:39 , 79.31it/s
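The CVaR objective can be illustrated on a finite sample of returns. The NumPy sketch below estimates CVaR at level `alpha` as the mean of the worst `ceil(alpha * n)` returns. The helper name `empirical_cvar` and the two return samples are hypothetical, and this is not JaxPlan's implementation; it only shows that two samples with the same mean can have very different CVaR.

```python
import numpy as np

def empirical_cvar(returns, alpha=0.1):
    """Mean of the worst ceil(alpha * n) returns (the lower alpha-tail)."""
    r = np.sort(np.asarray(returns, dtype=float))
    k = max(1, int(np.ceil(alpha * r.size)))
    return float(r[:k].mean())

# two hypothetical samples of 100 episode returns, both with mean 5.0
safe = np.full(100, 5.0)
risky = np.concatenate([np.full(10, -40.0), np.full(90, 10.0)])

for name, sample in [("safe", safe), ("risky", risky)]:
    print(name, "mean:", sample.mean(), "CVaR_0.1:", empirical_cvar(sample, alpha=0.1))
```

A risk-neutral objective scores both samples equally, while CVaR at `alpha=0.1` strongly prefers the safe one. This is the preference that pushes the planner toward policies that avoid rare, very costly outcomes such as power shortages.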

Finally, let’s compare the distributions of returns of the two policies using violin plots:

%matplotlib inline
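# one horizontal violin per policy, each summarizing 200 episode returns
ax = sns.violinplot(data=[drp_returns, risk_returns], orient='h')
ax.set_yticklabels(['Risk-Neutral', 'Risk-Averse'])
ax.set_xlabel('Total return per episode')
plt.show()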
[Figure: violin plots of episode returns for the risk-neutral and risk-averse policies]

As the plot shows, the returns of the risk-averse policy trained with the CVaR utility are more stable, that is, less spread out, than those of the risk-neutral policy.
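To back the visual comparison with numbers, one can compute spread and tail statistics for each list of returns. The sketch below uses a hypothetical helper, `summarize_returns` (not part of pyRDDLGym-jax), that reports the mean, the standard deviation, the 10th percentile and the interquartile range. A smaller standard deviation and interquartile range indicate more stable returns, and a higher 10th percentile indicates a better worst case. It is demonstrated on a dummy list; in the notebook you would call it on `drp_returns` and `risk_returns`.

```python
import numpy as np

def summarize_returns(returns):
    """Mean plus spread and lower-tail statistics for a list of episode returns."""
    r = np.asarray(returns, dtype=float)
    p10, p25, p75 = np.percentile(r, [10, 25, 75])
    return {
        "mean": float(r.mean()),
        "std": float(r.std()),      # population standard deviation (ddof=0)
        "p10": float(p10),          # 10th percentile, linear interpolation
        "iqr": float(p75 - p25),    # interquartile range
    }

# dummy data; in the notebook: summarize_returns(drp_returns), summarize_returns(risk_returns)
stats = summarize_returns([1.0, 2.0, 3.0, 4.0, 5.0])
print(stats)
```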