Risk-aware planning with RAPTOR in JaxPlan.
This variation of the closed-loop planning notebook optimizes a nonlinear risk-aware utility of the return distribution, rather than the ordinary expected return.
%pip install --quiet --upgrade pip
%pip install --quiet seaborn
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym.git
%pip install --quiet git+https://github.com/pyrddlgym-project/rddlrepository.git
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
import seaborn as sns
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxDeepReactivePolicy, JaxBackpropPlanner, JaxOfflineController, load_config_from_string)
Let’s optimize the power generation domain from the IPPC 2023:
env = pyRDDLGym.make('PowerGen_Continuous', '0', vectorized=True)
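Optionally, we can peek at the basic dimensions of the problem before planning. This is a small sketch that assumes the standard pyRDDLGym environment attributes (horizon, action_space, observation_space):
# inspect the compiled environment (assumes standard pyRDDLGym env attributes)
print('horizon:', env.horizon)
print('action space:', env.action_space)
print('observation space:', env.observation_space)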
Let’s generate a risk-neutral baseline from the policy network:
config = """
[Model]
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [64, 64]}
optimizer_kwargs={'learning_rate': 0.0002}
batch_size_train=256
[Training]
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=45, **train_args)
drp_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(100)]
1191 it / -19847.048828 train / 24.565781 test / 50.516998 best / 5 status: : 1191it [00:44, 26.68it/s]
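The list drp_returns now holds 100 independent single-episode returns. As a quick sanity check we can summarize them with numpy (a small sketch; the exact numbers will vary from run to run):
import numpy as np

drp_array = np.asarray(drp_returns)
print('mean return:           ', drp_array.mean())
print('std of return:         ', drp_array.std())
print('10th percentile return:', np.percentile(drp_array, 10))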
Next, let’s repeat the example, but this time we will use the conditional value at risk (CVaR) at level alpha = 0.1, which optimizes the mean of the worst 10 percent of returns. This should produce a policy that is more robust against power shortages:
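To make the objective concrete, the empirical CVaR of a batch of sampled returns is just the average of the worst alpha-fraction of them. The helper below is an illustrative sketch (not the planner's internal implementation) using jax.numpy:
import jax.numpy as jnp

def empirical_cvar(returns, alpha=0.1):
    # average of the worst alpha-fraction of sampled returns
    sorted_returns = jnp.sort(returns)
    worst_k = max(1, int(alpha * sorted_returns.shape[0]))
    return jnp.mean(sorted_returns[:worst_k])

# e.g. empirical_cvar(jnp.array([-10., -2., 0., 1., 2., 3., 4., 5., 6., 7.]), alpha=0.2) gives -6.0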
config = """
[Model]
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [64, 64]}
optimizer_kwargs={'learning_rate': 0.0002}
utility='cvar'
utility_kwargs={'alpha': 0.1}
batch_size_train=256
[Training]
"""
planner_args, _, train_args = load_config_from_string(config)
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
agent = JaxOfflineController(planner, print_summary=False, train_seconds=45, **train_args)
risk_returns = [agent.evaluate(env, episodes=1)['mean'] for _ in range(100)]
1155 it / -20441.250000 train / -1248.467163 test / -996.323242 best / 5 status: : 1155it [00:44, 25.78it/s]
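Before plotting, we can compare the lower tails of the two return distributions directly (again a sketch using numpy percentiles, not the planner's own utility computation):
import numpy as np

for name, returns in [('risk-neutral', drp_returns), ('risk-averse', risk_returns)]:
    returns = np.asarray(returns)
    lower_tail = returns[returns <= np.percentile(returns, 10)]
    print(f'{name}: mean = {returns.mean():.1f}, mean of worst 10% = {lower_tail.mean():.1f}')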
Finally, let’s plot the two return distributions side by side as violin plots to compare the risk-neutral and risk-averse policies:
%matplotlib inline
ax = sns.violinplot(data=[drp_returns, risk_returns], orient='h')
ax.set_yticklabels(['Risk-Neutral', 'Risk-Averse'])
plt.show()
As you can see, the returns are noticeably more stable under the risk-aware CVaR utility than under the risk-neutral objective, consistent with a policy that is more robust against power shortages.