pyRDDLGym-jax: JAX Compiler and Planner
In this tutorial, we discuss how a RDDL model can be compiled into a differentiable simulator using JAX. We also show how gradient ascent can be used to estimate optimal actions.
Requirements
This package requires Python 3.8+ and the following packages:
pyRDDLGym>=2.0
tqdm>=4.66
jax>=0.4.12
optax>=0.1.9
dm-haiku>=0.0.10
tensorflow>=2.13.0
tensorflow-probability>=0.21.0
To run the hyper-parameter tuning, you will also need
bayesian-optimization>=1.4.3
Installing via pip
You can install pyRDDLGym-jax and all of its requirements via pip:
pip install pyRDDLGym-jax
Installing the Pre-Release Version via git
pip install git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Changing the Simulation Backend to JAX
By default, pyRDDLGym simulates using Python and stores the outputs of intermediate expressions in NumPy arrays. However, if additional structure such as gradients is required, or if simulation is slow using the default backend, the environment can be compiled using JAX by changing the backend:
import pyRDDLGym
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
env = pyRDDLGym.make("Cartpole_Continuous_gym", "0", backend=JaxRDDLSimulator)
This is designed to be interchangeable with the default backend, so the numerical results should be the same.
Note
All RDDL syntax (both new and old) is supported in the RDDL-to-JAX compiler.
Differentiable Planning: Deterministic Domains
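The planning problem for a deterministic environment involves finding actions that maximize accumulated reward over a fixed horizon \(T\) (an open-loop plan):
\[\max_{a_1, \dots, a_T} \sum_{t=1}^{T} R(s_t, a_t), \quad \text{subject to } s_{t+1} = f(s_t, a_t),\]
where \(f\) denotes the state transition function and \(R\) denotes the reward function.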
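If the state and action spaces are continuous, and \(f\) and \(R\) are differentiable functions, it is possible to apply gradient ascent to optimize the actions directly as described in this paper. Given a learning rate \(\eta > 0\) and “guess” \(a_\tau\), gradient ascent produces an estimate of the optimal action \(a_\tau'\) at time \(\tau\) as
\[a_\tau' = a_\tau + \eta \sum_{t=\tau}^{T} \nabla_{a_\tau} R(s_t, a_t),\]
where \(T\) is the planning horizon, and the gradient of the reward at all times \(t \geq \tau\) can be computed using the chain rule:
\[\nabla_{a_\tau} R(s_t, a_t) = \frac{\mathrm{d}R(s_t, a_t)}{\mathrm{d}s_t} \frac{\mathrm{d}s_t}{\mathrm{d}a_\tau} + \frac{\mathrm{d}R(s_t, a_t)}{\mathrm{d}a_t} \frac{\mathrm{d}a_t}{\mathrm{d}a_\tau}.\]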
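The gradient ascent update can be illustrated on a toy problem. The sketch below is plain Python rather than JAX, and it assumes a hypothetical one-dimensional model with dynamics \(s_{t+1} = s_t + a_t\) and reward \(R(s_t, a_t) = -(s_t - g)^2 - c \, a_t^2\); the names rollout, plan_gradient and optimize_plan are made up for this illustration and are not part of pyRDDLGym-jax. The gradient is accumulated backwards in time with the chain rule, which is what automatic differentiation would do for a compiled model.

```python
def rollout(s0, actions, goal=1.0, cost=0.1):
    """Simulate s_{t+1} = s_t + a_t; return the visited states and total reward."""
    states, total = [s0], 0.0
    for a in actions:
        s = states[-1]
        total += -(s - goal) ** 2 - cost * a ** 2   # R(s_t, a_t)
        states.append(s + a)                         # f(s_t, a_t)
    return states, total


def plan_gradient(s0, actions, goal=1.0, cost=0.1):
    """Gradient of the total reward w.r.t. each action via the chain rule."""
    states, _ = rollout(s0, actions, goal, cost)
    grads = [0.0] * len(actions)
    adjoint = 0.0  # d(return)/d(s_{t+1}); zero beyond the horizon
    for t in reversed(range(len(actions))):
        # direct term dR/da_t plus the effect on later rewards (ds_{t+1}/da_t = 1)
        grads[t] = -2.0 * cost * actions[t] + adjoint
        # propagate to s_t: dR/ds_t plus ds_{t+1}/ds_t = 1 times the adjoint
        adjoint = -2.0 * (states[t] - goal) + adjoint
    return grads


def optimize_plan(s0, horizon=5, lr=0.05, iters=200):
    """Plain gradient ascent on an open-loop plan, starting from zero actions."""
    actions = [0.0] * horizon
    for _ in range(iters):
        grads = plan_gradient(s0, actions)
        actions = [a + lr * g for a, g in zip(actions, grads)]
    return actions


plan = optimize_plan(s0=0.0)
_, initial_return = rollout(0.0, [0.0] * 5)
_, final_return = rollout(0.0, plan)
print(f"return before: {initial_return:.3f}, after: {final_return:.3f}")
```

The learning rate of 0.05 was chosen small enough for this particular quadratic problem; other models would generally need their own tuning.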
In domains with stochastic transitions, an open-loop plan could be sub-optimal as it does not correct for deviations in the state from its expected course as anticipated during optimization. One solution is to recompute the plan periodically or after each decision epoch, which is often called “replanning”.
An alternative approach to replanning is to learn a policy network \(a_t \gets \pi_\theta(s_t)\) that maps the states to actions, such as a feed-forward neural network as explained in this paper.
Differentiable Planning: Stochastic Domains
A common problem of planning in stochastic domains is that the gradients are no longer well-defined. pyRDDLGym-jax works around this problem by using the reparameterization trick.
To illustrate, we can write \(s_{t+1} = \mathcal{N}(s_t, a_t^2)\) as \(s_{t+1} = s_t + a_t * \mathcal{N}(0, 1)\). The two forms are equivalent in distribution, but the latter is amenable to backpropagation while the former is not.
The reparameterization trick also works generally, assuming there exists a closed-form function \(f\) such that
\[s_{t+1} = f(s_t, a_t, \xi_t),\]
and \(\xi_t\) are random variables drawn from some distribution independent of states or actions. For a detailed discussion of reparameterization in the context of planning, please see this paper or this one.
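To make the Gaussian example concrete, the following sketch estimates a gradient through the reparameterized sample \(s_{t+1} = s_t + a_t \xi\) with \(\xi \sim \mathcal{N}(0, 1)\). It is plain Python, the function name pathwise_gradient is hypothetical, and the objective \(\mathbb{E}[s_{t+1}^2]\) is chosen only because its exact gradient with respect to \(a_t\), namely \(2 a_t\), is known in closed form.

```python
import random


def pathwise_gradient(s, a, num_samples=20000, seed=0):
    """Estimate d/da E[(s')^2] where s' = s + a * xi and xi ~ N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        xi = rng.gauss(0.0, 1.0)      # noise drawn independently of s and a
        s_next = s + a * xi           # reparameterized sample of N(s, a^2)
        total += 2.0 * s_next * xi    # chain rule: d(s'^2)/da = 2 * s' * ds'/da
    return total / num_samples


estimate = pathwise_gradient(s=1.0, a=0.5)
exact = 2.0 * 0.5   # since E[(s')^2] = s^2 + a^2
print(f"estimate: {estimate:.3f}, exact: {exact:.3f}")
```

Because the noise is sampled independently of the action, the derivative passes through the sample itself, so the Monte Carlo average approaches the exact gradient as the number of samples grows.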
pyRDDLGym-jax automatically performs reparameterization whenever possible. For some special cases, such as the Bernoulli and Discrete distribution, it applies the Gumbel-softmax trick as described in this paper. Defining K independent samples from a standard Gumbel distribution \(g_1, \dots, g_K\), we reparameterize the random variable \(X\) with probability mass function \(p_1, \dots, p_K\) as
\[X = \arg\max_{i = 1, \dots, K} \left(g_i + \log p_i\right),\]
where the argmax is approximated using the softmax function.
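The following plain-Python sketch illustrates the idea. The function names and the sharpness parameter w are assumptions made for this example and do not mirror the pyRDDLGym-jax implementation: gumbel_max draws an exact categorical sample, while gumbel_softmax replaces the argmax with a softmax so that the output varies smoothly with the probabilities.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one standard Gumbel sample by inverting its CDF."""
    u = max(rng.random(), 1e-12)    # guard against log(0)
    return -math.log(-math.log(u))


def gumbel_max(probs, rng):
    """Exact (non-differentiable) categorical sample via the Gumbel-max trick."""
    scores = [sample_gumbel(rng) + math.log(p) for p in probs]
    return max(range(len(probs)), key=lambda i: scores[i])


def gumbel_softmax(probs, rng, w=10.0):
    """Relaxed sample: a softmax with sharpness w in place of the argmax."""
    scores = [w * (sample_gumbel(rng) + math.log(p)) for p in probs]
    top = max(scores)                             # shift for numerical stability
    exps = [math.exp(s - top) for s in scores]
    norm = sum(exps)
    return [e / norm for e in exps]


rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
counts = [0, 0, 0]
for _ in range(20000):
    counts[gumbel_max(probs, rng)] += 1
print("empirical frequencies:", [c / 20000 for c in counts])
print("one relaxed sample:", gumbel_softmax(probs, rng))
```

The empirical frequencies of the exact sampler should be close to the given probabilities, and the relaxed sample is a probability vector that concentrates on a single index as w increases.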
Warning
For general non-reparameterizable distributions, the result of the gradient calculation is fully dependent on the JAX implementation: it could return a zero or NaN gradient, or raise an exception.
Running the Basic Example
A basic run script is provided to run the JAX planner on any domain in rddlrepository, provided a config file of hyper-parameters is available (currently, only a limited subset of configs is provided). The example can be run as follows in a standard shell, from the install directory of pyRDDLGym-jax:
python -m pyRDDLGym_jax.examples.run_plan <domain> <instance> <method> <episodes>
where:
- <domain> is the domain identifier as specified in rddlrepository, or a path pointing to a valid domain.rddl file
- <instance> is the instance identifier in rddlrepository, or a path pointing to a valid instance.rddl file
- <method> is the planning method to use (see below)
- <episodes> is the (optional) number of episodes to evaluate the learned policy.
The <method> parameter warrants further explanation. Currently we support three possible modes:
- slp is the straight-line open-loop planner described in this paper
- drp is the deep reactive policy network described in this paper
- replan is the same as slp except it uses periodic replanning as described above.
For example, copying and pasting the following will train the JAX planner on the Quadcopter domain with 4 drones:
python -m pyRDDLGym_jax.examples.run_plan Quadcopter 1 slp
Running from the Python API
pyRDDLGym-jax provides convenient tools to automatically compile a RDDL description of a problem to the above optimization problem:
import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController
# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)
# create the planning algorithm
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
# evaluate the planner
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()
Here, we have used an open-loop (offline) controller. To use periodic replanning, simply change the controller type to online:
controller = JaxOnlineController(planner, **train_args)
Note
All controllers are instances of pyRDDLGym’s BaseAgent and support the evaluate() function.
The **planner_args and **train_args are keyword arguments passed during initialization, but we strongly recommend creating and loading a configuration file as discussed next.
Writing Configuration Files for Custom Problems
The recommended way to load planner and training arguments is to write a configuration file with all the necessary hyper-parameters. The basic structure of a configuration file is provided below for open-loop planning or replanning:
[Model]
logic='FuzzyLogic'
logic_kwargs={'weight': 20}
tnorm='ProductTNorm'
tnorm_kwargs={}
[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}
batch_size_train=1
batch_size_test=1
rollout_horizon=5
[Training]
key=42
epochs=5000
train_seconds=30
The configuration file contains three sections:
- the [Model] section dictates how non-differentiable expressions are handled (as discussed later in the tutorial)
- the [Optimizer] section contains a method argument to indicate the type of plan/policy, its hyper-parameters, the optax SGD optimizer and its hyper-parameters, etc.
- the [Training] section indicates budget on iterations or time, hyper-parameters for the policy, etc.
The configuration file can then be parsed and passed to the planner as follows:
from pyRDDLGym_jax.core.planner import load_config
planner_args, _, train_args = load_config("/path/to/config.cfg")
# continue as described above
planner = ...
controller = ...
Note
The rollout_horizon in the configuration file is optional, and defaults to the horizon specified in the RDDL description. For replanning methods, we recommend setting this parameter manually, and tuning it to get the best result.
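Each value in the configuration file is written as a Python literal (a quoted string, a number or a dictionary). The sketch below shows one way such a file could be read using only the standard library. It is an illustration of the file format and not the implementation of load_config; the helper name read_config is hypothetical.

```python
import ast
import configparser

CONFIG_TEXT = """
[Model]
logic='FuzzyLogic'
logic_kwargs={'weight': 20}
tnorm='ProductTNorm'
tnorm_kwargs={}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}
batch_size_train=1
batch_size_test=1
rollout_horizon=5

[Training]
key=42
epochs=5000
train_seconds=30
"""


def read_config(text):
    """Parse each section into a dict of Python values."""
    # only '=' separates names from values, so ':' inside dictionaries is safe
    parser = configparser.ConfigParser(delimiters=("=",), interpolation=None)
    parser.optionxform = str    # keep option names case-sensitive
    parser.read_string(text)
    return {
        section: {k: ast.literal_eval(v) for k, v in parser[section].items()}
        for section in parser.sections()
    }


config = read_config(CONFIG_TEXT)
print(config["Optimizer"]["optimizer_kwargs"])
```

Using ast.literal_eval rather than eval restricts the values to plain literals, which is a reasonable precaution when reading files written by hand.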
Writing Configuration Files for Policy Networks
To use a policy network instead of an open-loop plan or replanning, change the method in the [Optimizer] section of the config file:
...
[Optimizer]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 64]}
...
This creates a neural network policy with the default ReLU activations, and two hidden layers with 128 and 64 neurons.
Note
JaxStraightLinePlan and JaxDeepReactivePolicy are subclasses of the abstract class JaxPlan. Other policy representations could be defined by subclassing this class and overriding its abstract methods.
Boolean Actions
By default, boolean actions are wrapped using the sigmoid function:
\[a = \frac{1}{1 + e^{-w \theta}},\]
where \(\theta\) denotes the trainable action parameters, and \(w\) denotes a hyper-parameter that controls the sharpness of the approximation.
Warning
If the sigmoid wrapping is used, then the weights w must be specified in policy_hyperparams for each boolean action fluent when interfacing with the planner.
At test time, the action is aliased by evaluating the expression \(a > 0.5\), or equivalently \(\theta > 0\).
The use of sigmoid for boolean actions can be disabled by setting wrap_sigmoid = False, but this is not recommended.
Constraints on Action Fluents
Currently, the JAX planner supports two kinds of action constraints: box constraints and concurrency constraints.
Box constraints are useful for bounding each action fluent independently into some range. Box constraints typically do not need to be specified manually, since they are automatically parsed from the action_preconditions as defined in the RDDL domain description file. However, if the user wishes, it is possible to override these bounds by passing a dictionary of bounds for each action fluent into the action_bounds argument.
The syntax for specifying optional box constraints in the [Optimizer] section of the config file is:
[Optimizer]
...
action_bounds={ <action_name1>: (lower1, upper1), <action_name2>: (lower2, upper2), ... }
where lower# and upper# can be any list or nested list.
By default, the box constraints on actions are enforced using the projected gradient method. An alternative approach is to map the actions to the box via a differentiable transformation, as described by equation 6 in this paper. In the JAX planner, it is possible to switch to the transformation method by setting wrap_non_bool = True.
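The two approaches can be contrasted on a single scalar action. In the sketch below, the function names are hypothetical, and the sigmoid-based squashing is only one common choice of differentiable transformation; the exact transformation in the cited paper may differ.

```python
import math


def projected_step(action, grad, lr, lower, upper):
    """Gradient ascent step followed by projection (clipping) onto [lower, upper]."""
    return min(max(action + lr * grad, lower), upper)


def squash_to_box(param, lower, upper):
    """Differentiable map from an unconstrained parameter into (lower, upper).

    Assumption: a sigmoid squashing, used here purely for illustration.
    """
    return lower + (upper - lower) / (1.0 + math.exp(-param))


# projection: the raw update 0.9 + 0.1 * 5.0 = 1.4 is clipped back to the bound
print(projected_step(0.9, grad=5.0, lr=0.1, lower=-1.0, upper=1.0))
# transformation: any real-valued parameter lands strictly inside the box
print(squash_to_box(-3.0, lower=-1.0, upper=1.0))
```

With projection, the optimizer works directly on the action and corrects infeasible updates afterwards; with a transformation, the optimizer works on an unconstrained parameter and feasibility holds by construction.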
The JAX planner also supports concurrency constraints on actions of the form \(\sum_i a_i \leq B\) for some constant \(B\). If the max-nondef-actions property in the RDDL instance is less than the total number of boolean action fluents, then JaxBackpropPlanner will automatically apply a projected gradient step to ensure this constraint is satisfied at each optimization step, as described in this paper.
Note
Concurrency constraints on action-fluents are applied to boolean actions only, i.e., real and int actions are ignored.
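To give an intuition for what a projection onto a concurrency constraint involves, the sketch below computes the Euclidean projection of relaxed boolean actions onto the set \(\{a : 0 \leq a_i \leq 1, \sum_i a_i \leq B\}\) by bisection on a uniform shift. This is a generic construction written for illustration, with a hypothetical function name; the projection used inside the planner follows the cited paper and may be implemented differently.

```python
def project_concurrency(actions, budget, iters=60):
    """Project onto {a : 0 <= a_i <= 1, sum_i a_i <= budget} (budget >= 0)."""

    def clip(x):
        return min(max(x, 0.0), 1.0)

    clipped = [clip(a) for a in actions]
    if sum(clipped) <= budget:
        return clipped                  # clipping alone is already feasible
    # otherwise subtract a common shift; at shift = max(actions) the sum is zero
    lo, hi = 0.0, max(actions)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if sum(clip(a - mid) for a in actions) > budget:
            lo = mid                    # still infeasible: shift further
        else:
            hi = mid                    # feasible: try a smaller shift
    return [clip(a - hi) for a in actions]


print(project_concurrency([0.9, 0.8, 0.7], budget=1.0))
print(project_concurrency([0.2, 0.3], budget=1.0))
```

In the first call every action is reduced by the same amount until the total equals the budget, which preserves their ordering; in the second call the constraint is inactive and the actions are returned unchanged.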
Reward Normalization
Some domains yield rewards that vary significantly in magnitude between time steps, making optimization difficult without some form of normalization. Following this paper, pyRDDLGym-jax can apply a symlog transform to the sampled rewards during backprop:
\[\mathrm{symlog}(x) = \mathrm{sign}(x) \ln(|x| + 1),\]
which compresses the magnitudes of large positive and negative outcomes.
The use of symlog can be enabled by setting use_symlog_reward = True in JaxBackpropPlanner.
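For reference, the symlog transform and its inverse can be written in a few lines of plain Python. This sketch is independent of the planner's own implementation, and the function names are chosen for this example.

```python
import math


def symlog(x):
    """Symmetric log: sign(x) * ln(|x| + 1)."""
    return math.copysign(math.log1p(abs(x)), x)


def symexp(y):
    """Inverse of symlog: sign(y) * (exp(|y|) - 1)."""
    return math.copysign(math.expm1(abs(y)), y)


for reward in (-1e6, -10.0, 0.0, 0.5, 1e6):
    print(f"{reward:>12.1f} -> {symlog(reward):8.3f}")
```

Rewards near zero are left almost unchanged, while a reward of one million in magnitude is compressed to a value below 14.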
Utility Optimization
By default, the JAX planner will optimize the expected sum of future reward, which may not be desirable for risk-sensitive applications. Following the framework in this paper, it is possible to optimize a non-linear utility of the return instead.
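For example, the entropic utility with risk-aversion parameter \(\beta\) is
\[U(a_1, \dots, a_T) = -\frac{1}{\beta} \log \mathbb{E}\left[\exp\left(-\beta \sum_{t=1}^{T} R(s_t, a_t)\right)\right].\]
This can be passed to the planner as follows: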
import jax.numpy as jnp
def entropic(x, beta=0.00001):
    # x holds the sampled returns; the small constant guards against log(0)
    return (-1.0 / beta) * jnp.log(jnp.mean(jnp.exp(-beta * x)) + 1e-12)

planner = JaxBackpropPlanner(..., utility=entropic)
...
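The effect of the risk-aversion parameter can be seen on a handful of sampled returns. The sketch below is plain Python, assumes \(\beta > 0\), uses made-up return values, and applies a log-sum-exp shift for numerical stability, which is a choice made for this example rather than something the planner requires.

```python
import math


def entropic_utility(returns, beta):
    """-(1 / beta) * log(mean(exp(-beta * r))), computed with a stabilizing shift."""
    shift = max(-beta * r for r in returns)
    mean_exp = sum(math.exp(-beta * r - shift) for r in returns) / len(returns)
    return -(shift + math.log(mean_exp)) / beta


returns = [0.0, 10.0]    # hypothetical sampled returns from two rollouts
mean_return = sum(returns) / len(returns)
for beta in (1e-6, 0.1, 1.0):
    print(f"beta={beta:g}: utility={entropic_utility(returns, beta):.4f} "
          f"(mean={mean_return:.1f})")
```

As \(\beta\) approaches zero the utility approaches the mean return, and as \(\beta\) grows the utility moves towards the worst sampled return, which is the sense in which the objective becomes risk-averse.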
Changing the Planning Algorithm
In the introductory example, you may have noticed that we defined the planning algorithm separately from the controller.
Therefore, it is possible to incorporate new planning algorithms simply by extending the JaxBackpropPlanner class.
pyRDDLGym-jax currently provides one such extension based on backtracking line-search, which adaptively selects a learning rate at each iteration whose gradient update provides the greatest improvement in the return objective.
This optimizer can be used as a drop-in replacement for JaxBackpropPlanner as follows:
from pyRDDLGym_jax.core.planner import JaxRDDLArmijoLineSearchPlanner, JaxOfflineController
planner = JaxRDDLArmijoLineSearchPlanner(env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
Like the default planner, the line-search planner is compatible with offline and online controllers, and straight-line plans and deep reactive policies.
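For readers unfamiliar with backtracking line search, the following plain-Python sketch shows a generic version based on the Armijo sufficient-increase condition, applied to a toy concave objective. The function names, the objective and the constants are assumptions for this example, and the sketch is not taken from the line-search planner's source code.

```python
def backtracking_step(objective, grad, x, lr_init=1.0, shrink=0.5, c=1e-4,
                      max_tries=30):
    """One gradient ascent step whose learning rate satisfies the Armijo condition."""
    value, g = objective(x), grad(x)
    slope = sum(gi * gi for gi in g)            # squared gradient norm
    lr = lr_init
    for _ in range(max_tries):
        candidate = [xi + lr * gi for xi, gi in zip(x, g)]
        if objective(candidate) >= value + c * lr * slope:
            return candidate, lr                # sufficient increase achieved
        lr *= shrink                            # otherwise try a smaller step
    return x, 0.0                               # no acceptable step was found


def objective(x):
    return -(x[0] - 1.0) ** 2 - 10.0 * (x[1] + 2.0) ** 2


def gradient(x):
    return [-2.0 * (x[0] - 1.0), -20.0 * (x[1] + 2.0)]


x, history, rates = [0.0, 0.0], [], []
for _ in range(50):
    x, lr = backtracking_step(objective, gradient, x)
    history.append(objective(x))
    rates.append(lr)
print(f"first learning rate: {rates[0]}, final objective: {history[-1]:.2e}")
```

The learning rate is shrunk only as far as needed at each iteration, so the objective never decreases even though no learning rate was tuned by hand.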
Automatically Tuning Hyper-Parameters
pyRDDLGym-jax provides a Bayesian optimization algorithm for automatically tuning key hyper-parameters of the planner. It:
- supports multi-processing by evaluating multiple hyper-parameter settings in parallel
- leverages Bayesian optimization to perform more efficient search than random or grid search
- supports straight-line planning and deep reactive policies
The key hyper-parameters can be tuned as follows:
import pyRDDLGym
from pyRDDLGym_jax.core.tuning import JaxParameterTuningSLP
# set up the environment
env = pyRDDLGym.make(domain, instance, vectorized=True)
# set up the tuning instance
tuning = JaxParameterTuningSLP(env=env,
                               train_epochs=epochs,
                               timeout_training=timeout,
                               eval_trials=trials,
                               planner_kwargs=planner_args,
                               plan_kwargs=plan_args,
                               num_workers=workers,
                               gp_iters=iters)
# tune and report the best hyper-parameters found
best = tuning.tune(key=key, filename="/path/to/log.csv")
print(f'best parameters found: {best}')
The __init__ method requires the num_workers parameter to specify the number of parallel processes and the gp_iters parameter to specify the number of iterations of Bayesian optimization.
Upon executing this code, a dictionary of the best hyper-parameters (e.g. learning rate, policy network architecture, model hyper-parameters, etc.) is returned. A log of the previous sets of hyper-parameters suggested by the algorithm is also recorded in the specified output file.
Policy networks and replanning can be tuned by replacing JaxParameterTuningSLP with JaxParameterTuningDRP and JaxParameterTuningSLPReplan, respectively. This will also tune the architecture (number of neurons, layers) of the policy network and the rollout_horizon for replanning.
Dealing with Non-Differentiable Expressions
Many RDDL programs contain expressions that do not support derivatives. A common technique to deal with this is to rewrite non-differentiable operations as similar differentiable ones. For instance, consider the following problem of classifying points (x, y) in 2D-space as +1 if they lie in the top-right or bottom-left quadrants, and -1 otherwise:
def classify(x, y):
    if x > 0 and y > 0 or not x > 0 and not y > 0:
        return +1
    else:
        return -1
Relational expressions such as x > 0 and y > 0, and logical expressions such as and and or, do not have obvious derivatives. To complicate matters further, the if statement depends on both x and y, so it does not have partial derivatives with respect to x or y.
pyRDDLGym-jax works around these limitations by approximating such operations with JAX expressions that support derivatives.
For instance, the classify function above could be implemented as follows:
from pyRDDLGym_jax.core.logic import FuzzyLogic
logic = FuzzyLogic()
And, _ = logic.And()
Not, _ = logic.Not()
Gre, _ = logic.greater()
Or, _ = logic.Or()
If, _ = logic.If()
def approximate_classify(x1, x2, w):
    q1 = And(Gre(x1, 0, w), Gre(x2, 0, w), w)
    q2 = And(Not(Gre(x1, 0, w), w), Not(Gre(x2, 0, w), w), w)
    cond = Or(q1, q2, w)
    return If(cond, +1, -1, w)
Calling approximate_classify with x1=0.5, x2=1.5 and w=10 returns 0.98661363, which is very close to 1.
The FuzzyLogic instance can be passed to a planner through the config file, or directly as follows:
from pyRDDLGym_jax.core.logic import FuzzyLogic
planner = JaxBackpropPlanner(model, ..., logic=FuzzyLogic())
By default, FuzzyLogic uses the product t-norm fuzzy logic to approximate the logical operations, the standard complement \(\sim a \approx 1 - a\), and sigmoid approximations for other relational and functional operations.
The latter introduces model hyper-parameters \(w\), which control the “sharpness” of the operation. Higher values mean the approximation approaches its exact counterpart, at the cost of sparse and possibly numerically unstable gradients.
These can be retrieved and modified at run-time, such as during optimization, as follows:
model_params = planner.compiled.model_params
model_params[key] = ...
planner.optimize(..., model_params=model_params)
The following table summarizes the default rules used in FuzzyLogic.
Exact RDDL Operation | Approximate Operation |
---|---|
\(a \text{ ^ } b\) | \(a * b\) |
\(\sim a\) | \(1 - a\) |
forall_{?p : type} x(?p) | \(\prod_{?p} x(?p)\) |
if (c) then a else b | \(c * a + (1 - c) * b\) |
\(a == b\) | \(\frac{\mathrm{sigmoid}(w * (a - b + 0.5)) - \mathrm{sigmoid}(w * (a - b - 0.5))}{\tanh(0.25 * w)}\) |
\(a > b\), \(a >= b\) | \(\mathrm{sigmoid}(w * (a - b))\) |
\(\mathrm{signum}(a)\) | \(\tanh(w * a)\) |
argmax_{?p : type} x(?p) | \(\sum_{i = 1, 2, \dots, \lvert\mathrm{type}\rvert} i * \mathrm{softmax}(w * x)[i]\) |
Bernoulli(p) | Gumbel-Softmax trick |
Discrete(type, {cases …} ) | Gumbel-Softmax trick |
It is possible to control these rules by subclassing FuzzyLogic, or by passing different values to the tnorm or complement arguments to replace the product t-norm logic and standard complement, respectively.
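Two of the rules in the table are less obvious than the others, namely the relaxations of equality and argmax. The sketch below evaluates them directly in plain Python; the function names are hypothetical and the code follows the formulas in the table rather than the library's source.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def soft_equal(a, b, w):
    """Relaxed a == b: a smooth window of width one centred at a - b = 0."""
    window = sigmoid(w * (a - b + 0.5)) - sigmoid(w * (a - b - 0.5))
    return window / math.tanh(0.25 * w)     # normalizes the value at a == b to 1


def soft_argmax(values, w):
    """Relaxed argmax: softmax-weighted average of the indices 1..n."""
    top = max(values)                        # shift for numerical stability
    exps = [math.exp(w * (v - top)) for v in values]
    norm = sum(exps)
    return sum((i + 1) * e / norm for i, e in enumerate(exps))


print(soft_equal(2.0, 2.0, w=10.0), soft_equal(2.0, 5.0, w=10.0))
print(soft_argmax([0.1, 0.2, 2.0], w=20.0), soft_argmax([0.1, 0.2, 2.0], w=0.0))
```

The division by \(\tanh(0.25 w)\) makes the equality relaxation return one when the arguments coincide, and the soft argmax degrades to the average index when w is zero, which illustrates why the choice of w matters.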
Computing the Gradients Manually
The API also supports manual gradient calculation for custom applications. Please see the worked example here for how to calculate the gradient of the return with respect to the policy parameters.
Limitations
We cite several limitations of the current JAX planner:
- Not all operations have natural differentiable relaxations. Currently, the following are not supported:
    - nested fluents such as fluent1(fluent2(?p))
    - distributions that are not naturally reparameterizable such as Poisson, Gamma and Beta
- Some relaxations can accumulate high error:
    - this is particularly problematic when stacking CPFs for long roll-out horizons, so we recommend reducing or tuning the rollout-horizon for best results
- Some relaxations may not be mathematically consistent with one another:
    - no guarantees are provided about trichotomy of the relational operators, e.g. a == b, a > b and a < b do not necessarily “sum” to one, but in many cases should be close
    - if this is a concern, it is recommended to override some operations in FuzzyLogic to suit the user’s needs
- Termination conditions and state/action constraints are not considered in the optimization (but can be checked at test-time).
- The optimizer can fail to make progress when the structure of the problem is largely discrete:
    - to diagnose this, compare the training loss to the test loss over time, and at the time of convergence
    - a low, or drastically improving, training loss with a similar test loss indicates that the continuous model relaxation is likely accurate around the optimum
    - on the other hand, a low training loss and a high test loss indicates that the continuous model relaxation is poor, in which case the optimality of the solution should be questioned.
The goal of the JAX planner was not to replicate the state-of-the-art, but to provide a simple baseline that can be easily built on. However, we welcome any suggestions or modifications about how to improve this algorithm on a broader subset of RDDL.