Risk-aware planning with RAPTOR in JaxPlan.

This variation of the closed-loop planning notebook optimizes a nonlinear risk-aware utility of the return instead of the usual risk-neutral expected return.
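One such utility, used later in this notebook, is the conditional value at risk (CVaR). As a rough sketch, for a random planning return $R$ and risk level $\alpha \in (0, 1]$, CVaR averages the worst $\alpha$-fraction of outcomes (this is the standard textbook definition; the exact estimator used inside JaxPlan may differ):

$$\mathrm{CVaR}_\alpha[R] = \mathbb{E}\left[R \mid R \le \mathrm{VaR}_\alpha(R)\right], \qquad \mathrm{VaR}_\alpha(R) = \inf\{r : \Pr(R \le r) \ge \alpha\}.$$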

%pip install --quiet --upgrade pip
%pip install --quiet seaborn
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym.git
%pip install --quiet git+https://github.com/pyrddlgym-project/rddlrepository.git
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
import seaborn as sns

import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy, JaxBackpropPlanner, JaxOfflineController, load_config_from_string

Let’s optimize the power generation domain from the IPPC 2023:

env = pyRDDLGym.make('PowerGen_Continuous', '0', vectorized=True)
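As a quick sanity check, we can inspect the instance before planning. This sketch assumes the standard Gym-style attributes exposed by pyRDDLGym's RDDLEnv (horizon, observation_space, action_space):

# Inspect the environment; attribute names assume pyRDDLGym's Gym-style RDDLEnv.
print(env.horizon)            # number of decision epochs per episode
print(env.observation_space)  # spaces of the state fluents
print(env.action_space)       # spaces of the action fluents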

First, let’s train a deep reactive policy network with the default risk-neutral objective to serve as a baseline:

config = """
[Model]
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [64, 64]}
optimizer_kwargs={'learning_rate': 0.0002}
batch_size_train=256
[Training]
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=45, **train_args)
drp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(100)]
   1191 it /  -19847.048828 train /      24.565781 test /      50.516998 best / 5 status: : 1191it [00:44, 26.68it/s]                                  
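Before moving on, it is useful to summarize the baseline return distribution, especially its lower tail. A minimal sketch with numpy, using the drp_returns list collected above:

import numpy as np

drp = np.asarray(drp_returns)
print(f'mean return:     {drp.mean():.2f}')
print(f'std of returns:  {drp.std():.2f}')
print(f'10th percentile: {np.percentile(drp, 10):.2f}')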

Next, let’s repeat the example, but this time we will use the conditional value at risk (CVaR) to optimize the lower 10 percent of returns. This should produce a policy that is more robust against power shortages:

config = """
[Model]
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [64, 64]}
optimizer_kwargs={'learning_rate': 0.0002}
utility='cvar'
utility_kwargs={'alpha': 0.1}
batch_size_train=256
[Training]
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=45, **train_args)
risk_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(100)]
   1155 it /  -20441.250000 train /   -1248.467163 test /    -996.323242 best / 5 status: : 1155it [00:44, 25.78it/s]                                  
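To quantify the difference in tail risk, we can compare an empirical CVaR estimate at alpha = 0.1 on the two samples of returns. This is a rough Monte Carlo sketch that mirrors, but is not identical to, the objective optimized inside JaxPlan:

import numpy as np

def empirical_cvar(returns, alpha=0.1):
    # Average of the worst alpha-fraction of the sampled returns.
    sorted_returns = np.sort(np.asarray(returns))
    k = max(1, int(np.ceil(alpha * len(sorted_returns))))
    return sorted_returns[:k].mean()

print('CVaR(0.1) risk-neutral:', empirical_cvar(drp_returns))
print('CVaR(0.1) risk-averse: ', empirical_cvar(risk_returns))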

Finally, let’s compare the distributions of returns for the two policies with side-by-side violin plots:

%matplotlib inline
ax = sns.violinplot(data=[drp_returns, risk_returns], orient='h')
ax.set_yticklabels(['Risk-Neutral', 'Risk-Averse'])
plt.show()
[Figure: violin plots of the return distributions for the risk-neutral and risk-averse policies.]
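Overlapping histograms offer another view of the same comparison; here is a minimal sketch with seaborn's histplot, assuming the drp_returns and risk_returns lists from above:

sns.histplot(drp_returns, label='Risk-Neutral', kde=True, alpha=0.5)
sns.histplot(risk_returns, label='Risk-Averse', kde=True, alpha=0.5)
plt.xlabel('return')
plt.legend()
plt.show()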

As the plot shows, the returns are considerably more stable under the CVaR utility objective, consistent with a policy that hedges against worst-case outcomes such as power shortages.