Building Custom Policies with JaxPlan

Building Custom Policies with JaxPlan

This advanced notebook illustrates how a custom policy implementation can be designed with JaxPlan.

First, install the required packages:

# Upgrade pip, then install pyRDDLGym, the RDDL problem repository, and the JAX planner.
%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import os
from IPython.display import Image
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import jax
import jax.numpy as jnp

import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController, JaxPlan, get_action_info

We will optimize the continuous reservoir control problem as an example.

# Create instance '0' of the Reservoir_Continuous domain.
# NOTE(review): vectorized=True presumably exposes each fluent as one array
# indexed by object (here one entry per reservoir) instead of one fluent per
# object -- confirm against the pyRDDLGym documentation.
env = pyRDDLGym.make('Reservoir_Continuous', '0', vectorized=True)

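We will now build a custom piecewise policy in JaxPlan. That is, we will define a factored policy per reservoir whose release is a piecewise-constant (step) function of the water level rlevel(?r), switching between two learned values at a learned threshold T(?r). This is a special case of a piecewise-linear policy, hence the class name PWLinearPolicy: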

release(?r) = if (rlevel(?r) >= T(?r)) then V1(?r) else V2(?r)

During training, we will relax this step function into a smooth sigmoid approximation, so that gradients can be backpropagated to update the parameters (T, V1, V2):

release(?r) = sigmoid(w * (rlevel(?r) - T(?r))) * (V1(?r) - V2(?r)) + V2(?r)

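where w > 0 is a hyperparameter (the action_sigmoid argument of the policy, 10 by default) that controls how sharp the relaxation is. As w grows, the soft policy approaches the hard threshold policy used at evaluation time.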

class PWLinearPolicy(JaxPlan):
    """Per-reservoir threshold (step) policy on the water level ``rlevel``.

    Hard form, used for evaluation and applied elementwise per reservoir::

        release = V1 if rlevel >= T else V2

    Soft form, used for training and differentiable in ``T``, ``V1`` and ``V2``::

        c = sigmoid(w * (rlevel - T))
        release = c * V1 + (1 - c) * V2

    where ``w`` is ``action_sigmoid``. The policy is hard-coded to the
    ``release`` action fluent and the ``rlevel`` state fluent of the
    Reservoir domain.
    """

    # Flag read by the planner. False here because the action depends only on
    # the current fluents (presumably this disables history tracking -- confirm).
    history_dependent = False

    def __init__(self, action_sigmoid: float = 10.0) -> None:
        """Create the policy.

        Args:
            action_sigmoid: sigmoid weight ``w`` of the soft training policy.
                Larger values give a sharper approximation of the hard step.

        Raises:
            ValueError: if ``action_sigmoid`` is not strictly positive. A zero
                weight makes the soft policy ignore ``T``; a negative weight
                inverts it relative to the hard test-time policy.
        """
        # Let the JaxPlan base class set up its own state first.
        super().__init__()
        # Written as ``not > 0`` so that NaN is rejected as well.
        if not action_sigmoid > 0:
            raise ValueError(
                f'action_sigmoid must be positive, got {action_sigmoid}.')
        self.action_sigmoid = action_sigmoid

    def compile(self, compiled, test_compiled, _bounds, horizon, preprocessor=None) -> None:
        """Build the initializer, projection, and train/test policy functions.

        The arguments follow the ``JaxPlan.compile`` interface. Here they are
        only used to look up the shape and bounds of the ``release`` action
        via ``get_action_info``; ``test_compiled`` and ``preprocessor`` are
        not used.
        """
        # Shapes and (lower, upper) bounds of the action fluents. Only the
        # 'release' entries are used below.
        shapes, self.bounds, *_ = get_action_info(compiled, _bounds, horizon)

        # Drop the leading axis of the action shape (presumably the time axis
        # -- confirm), so each parameter holds one value per reservoir.
        param_shape = shapes['release'][1:]

        def clip_fn(params):
            # NOTE(review): T is compared against rlevel, yet it is clipped to
            # the *release* bounds like V1/V2. This assumes both fluents share
            # a comparable range -- confirm before reusing on other instances.
            return {name: jnp.clip(value, *self.bounds['release'])
                    for name, value in params.items()}

        def project_fn(planner_state):
            # The second element is a flag the planner expects (presumably
            # "projection succeeded" -- confirm). Clipping is exact, so the
            # flag is always True.
            return clip_fn(planner_state.policy_params), True
        self.projection = project_fn

        def init_fn(sim_state):
            # Small random values around zero, clipped into the valid range.
            # Each parameter gets its own independent PRNG key.
            keys = jax.random.split(sim_state.key, num=3)
            params = {name: 0.01 * jax.random.normal(key, param_shape)
                      for name, key in zip(('T', 'V1', 'V2'), keys)}
            # The sigmoid weight is returned alongside the parameters and read
            # back from planner_state.hyperparams by the train policy.
            hyperparams = {'sigmoid': self.action_sigmoid}
            return clip_fn(params), hyperparams
        self.initializer = init_fn

        def apply_train_fn(sim_state, planner_state):
            fls = sim_state.fls
            params = planner_state.policy_params
            weight = planner_state.hyperparams['sigmoid']
            # Soft indicator of rlevel >= T, so gradients reach T, V1 and V2.
            cond = jax.nn.sigmoid(weight * (fls['rlevel'] - params['T']))
            return {'release': cond * params['V1'] + (1 - cond) * params['V2']}
        self.train_policy = apply_train_fn

        def apply_test_fn(sim_state, planner_state):
            fls, params = sim_state.fls, planner_state.policy_params
            # Exact threshold, used for evaluation.
            cond = jnp.greater_equal(fls['rlevel'], params['T'])
            return {'release': jnp.where(cond, params['V1'], params['V2'])}
        self.test_policy = apply_test_fn

    def guess_next_epoch(self, params):
        """Return ``params`` unchanged as the guess for the next epoch.

        This presumably warm-starts the next optimization from the current
        solution -- confirm against the planner.
        """
        return params

We can now initialize and train this policy as usual:

# Gradient-based planner that optimizes the custom policy on the environment's RDDL model.
planner = JaxBackpropPlanner(rddl=env.model, plan=PWLinearPolicy())
# Offline controller wrapping the planner, with a 20-second training budget.
agent = JaxOfflineController(planner, print_summary=False, train_seconds=20)
# Evaluate over 10 episodes. The returned summary dict (mean/median/min/max/std
# of the returns) is displayed as the cell's output.
agent.evaluate(env, episodes=10)
[INFO] Compiler will cast pvars {'RES_CONNECT', 'CONNECTED_TO_SEA'} to float.
[INFO] Bounds of action-fluent <release> set to (array([0., 0., 0.], dtype=float32), array([100., 100., 100.], dtype=float32)).
 761 it |   -3125.80518 train |   -1566.51245 test |     -46.67517 [44 it] pgpe |      -2.20372 best | 5 status: 100%|██████████| 00:19 , 38.08it/s
{'mean': np.float64(-25.51491077008982),
 'median': np.float64(0.0),
 'min': np.float64(-201.65455700170355),
 'max': np.float64(0.0),
 'std': np.float64(59.96213937665172)}

This policy performs near optimally. Let’s print the T, V1 and V2 parameters:

# Show the learned parameters: threshold T and release levels V1/V2, one value per reservoir.
print(agent.params)
{'T': Array([0.50757444, 1.3407402 , 0.00769584], dtype=float32), 'V1': Array([1.6996087, 1.7419431, 5.2087555], dtype=float32), 'V2': Array([1.7276638, 0.6471594, 1.8261414], dtype=float32)}

Finally, let’s visualize the policy:

# Record one rendered episode into an animated GIF and display it inline.
# exist_ok=True replaces the exists()/makedirs() pair and avoids a race between them.
os.makedirs('frames', exist_ok=True)
recorder = MovieGenerator("frames", "reservoir", max_frames=env.horizon)
env.set_visualizer(viz=None, movie_gen=recorder)
try:
    agent.evaluate(env, episodes=1, render=True)
finally:
    # Close the environment even if the rollout raises.
    # NOTE(review): close() presumably also finalizes the recorded movie, so
    # it must run before the GIF is displayed -- confirm.
    env.close()
# NOTE(review): the file name presumably combines the "reservoir" prefix passed
# to MovieGenerator with a movie index -- update it if the prefix changes.
Image(filename='frames/reservoir_0.gif')
../_images/ced4ea6e1e5a4eba4b11930712fe74d9161f0a9a8acee2c470aef03194391d78.gif