pyRDDLGym-jax: Gradient-Based Simulation and Planning with JaxPlan#

In this tutorial, we discuss how an RDDL model can be automatically compiled into a differentiable JAX simulator. We also show how pyRDDLGym-jax (or JaxPlan) leverages gradient-based optimization for optimal control.

Installing#

To install the bare-bones version of JaxPlan with minimal installation requirements:

pip install pyRDDLGym-jax

To install JaxPlan with automatic hyper-parameter tuning and rddlrepository:

pip install pyRDDLGym-jax[extra]

(Since version 1.0) To install JaxPlan with the visualization dashboard:

pip install pyRDDLGym-jax[dashboard]

(Since version 1.0) To install JaxPlan with all options:

pip install pyRDDLGym-jax[extra,dashboard]

To install the pre-release version via git:

pip install git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git

Simulating Environments using JAX#

pyRDDLGym ordinarily simulates domains using NumPy. If you need gradients or faster simulation, switch to the JAX simulation backend:

import pyRDDLGym
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
env = pyRDDLGym.make("CartPole_Continuous_gym", "0", backend=JaxRDDLSimulator)

Note

All RDDL syntax (both new and old) is supported in the RDDL-to-JAX compiler. In almost all cases, the JAX backend should return numerical results identical to the default backend. However, not all operations can support gradients (see Limitations).

Related example (Jupyter notebook): Accelerating simulation with JAX.

Background on Differentiable Planning#

Open-Loop Planning#

The open-loop planning problem for a deterministic environment seeks a sequence of actions (a plan) that maximizes the accumulated reward over a fixed horizon T:

\[\max_{a_1, \dots, a_T} \sum_{t=1}^{T} R(s_t, a_t), \quad s_{t + 1} = f(s_t, a_t)\]

If the state and action spaces are continuous, and f and R are differentiable, the actions can be optimized by gradient ascent. Specifically, given a learning rate \(\eta\), gradient ascent updates the action \(a_\tau\) at decision epoch \(\tau\) to

\[a_{\tau}' = a_{\tau} + \eta \sum_{t=1}^{T} \nabla_{a_\tau} R(s_t, a_t),\]

where only the rewards at times \(t \geq \tau\) depend on \(a_\tau\) (through the dynamics f), and their gradients are computed by automatic differentiation in JAX.
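To make the update concrete, here is a minimal pure-Python sketch on a hypothetical 1D system \(s_{t+1} = s_t + a_t\) with reward \(R(s_t, a_t) = -(s_{t+1} - g)^2 - \lambda a_t^2\). It is not JaxPlan code: the gradient is derived by hand instead of by JAX's automatic differentiation, and the toy dynamics, goal g, penalty \(\lambda\), learning rate and iteration count are assumptions chosen for illustration (indices are 0-based in the code).

```python
# Hypothetical toy problem (not an RDDL domain):
#   dynamics  s_{t+1} = f(s_t, a_t) = s_t + a_t
#   reward    R(s_t, a_t) = -(s_{t+1} - goal)^2 - lam * a_t^2
def rollout(s0, actions):
    """Return the state trajectory [s_0, s_1, ..., s_T] for the given plan."""
    states = [s0]
    for a in actions:
        states.append(states[-1] + a)
    return states

def total_reward(s0, actions, goal, lam):
    states = rollout(s0, actions)
    return sum(-(states[t + 1] - goal) ** 2 - lam * a ** 2
               for t, a in enumerate(actions))

def reward_gradient(s0, actions, goal, lam):
    """Hand-derived gradient of total_reward with respect to each action."""
    states = rollout(s0, actions)
    horizon = len(actions)
    grads = []
    for tau in range(horizon):
        # s_{t+1} depends on a_tau only for t >= tau, with ds_{t+1}/da_tau = 1
        g = sum(-2.0 * (states[t + 1] - goal) for t in range(tau, horizon))
        grads.append(g - 2.0 * lam * actions[tau])
    return grads

def plan(s0=0.0, goal=1.0, horizon=5, lam=0.01, eta=0.02, iters=1000):
    actions = [0.0] * horizon
    for _ in range(iters):
        grads = reward_gradient(s0, actions, goal, lam)
        actions = [a + eta * g for a, g in zip(actions, grads)]  # ascent step
    return actions

actions = plan()
print([round(s, 3) for s in rollout(0.0, actions)])
```

Each action only receives gradient from rewards at or after its own epoch, which is the same structure that automatic differentiation exploits when it backpropagates through a rollout.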

Closed-Loop Planning#

An open-loop plan can be sub-optimal because it cannot correct for deviations of the state trajectory from its anticipated course. One solution is to “replan” periodically or at each decision epoch. Another solution is to compute a closed-loop deep reactive policy network \(a_t \gets \pi_\theta(s_t)\). JaxPlan supports both options.

Stochastic Reparameterization Trick#

For stochastic environments, a further problem is that gradients of random samples with respect to the parameters of their distribution are not well defined. JaxPlan works around this using the reparameterization trick, i.e. writing \(s_{t+1} \sim \mathcal{N}(s_t, a_t^2)\) as \(s_{t+1} = s_t + a_t \xi_t\) with \(\xi_t \sim \mathcal{N}(0, 1)\), where the latter is amenable to backpropagation while the former is not.
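The following plain-Python sketch (not JaxPlan code) shows why the reparameterized form is useful: with \(s' = s + a \xi\), the derivative of a sampled quantity g(s') with respect to a is simply \(g'(s') \xi\), so averaging it over samples estimates the gradient of the expectation. The test function \(g(x) = x^2\), for which the exact answer is 2a, and the sample count are assumptions chosen so the estimate can be checked.

```python
import random

def pathwise_gradient(s, a, dg, num_samples=100_000, seed=0):
    """Monte Carlo estimate of d/da E[g(s + a * xi)] = E[g'(s + a * xi) * xi].

    dg is the derivative g' of the function whose expectation is differentiated.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        xi = rng.gauss(0.0, 1.0)       # noise drawn independently of s and a
        total += dg(s + a * xi) * xi   # chain rule: d(s + a * xi)/da = xi
    return total / num_samples

# g(x) = x^2, so E[g(s')] = s^2 + a^2 and the exact derivative is 2a = 1.0
estimate = pathwise_gradient(s=1.0, a=0.5, dg=lambda x: 2.0 * x)
print(estimate)
```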

The reparameterization trick can be generalized, assuming there exists a closed-form function f such that

\[s_{t+1} = f(s_t, a_t, \xi_t)\]

and \(\xi_t\) are random variables drawn from some distribution independent of states and actions. For a detailed discussion of reparameterization in the context of planning, please see this paper or this paper.

JaxPlan automatically reparameterizes whenever possible. For Bernoulli, Discrete and related distributions with finite support, it applies the Gumbel-softmax trick. For other distributions without a natural reparameterization (e.g. Poisson, Binomial), JaxPlan applies various differentiable relaxations to approximate the gradients.
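A minimal sketch of the Gumbel-softmax idea for a Discrete distribution, in plain Python and not JaxPlan's implementation: adding independent Gumbel(0, 1) noise to the log-probabilities and taking the argmax yields an exact sample (the Gumbel-max trick), and replacing the argmax by a softmax with temperature tau gives a relaxed sample that is differentiable in the probabilities. The temperature value is an assumption.

```python
import math
import random

def gumbel_noise(rng):
    """Sample G ~ Gumbel(0, 1) as -log(-log(U)) with U ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_softmax(probs, tau, noise):
    """Relaxed one-hot sample: softmax((log p_i + G_i) / tau)."""
    logits = [(math.log(p) + g) / tau for p, g in zip(probs, noise)]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
noise = [gumbel_noise(rng) for _ in probs]
# exact (non-differentiable) sample via the Gumbel-max trick
hard = max(range(len(probs)), key=lambda i: math.log(probs[i]) + noise[i])
# relaxed sample sharing the same noise; approaches one-hot as tau -> 0
soft = gumbel_softmax(probs, tau=0.1, noise=noise)
print(hard, [round(x, 3) for x in soft])
```

Because the noise does not depend on the probabilities, the relaxed sample is a deterministic, smooth function of them, which is what makes backpropagation possible.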

Note

As of JaxPlan version 3.0, most discrete and continuous distributions support gradients (approximate ones where required). The notable exception is Multinomial, which does not yet support non-zero gradients.

Running JaxPlan#

From the Command Line#

A command line app is provided to run JaxPlan on a specific problem instance:

jaxplan plan <domain> <instance> <method> --episodes <episodes>

where:

  • <domain> is the domain identifier in rddlrepository, or a path pointing to a valid domain file

  • <instance> is the instance identifier in rddlrepository, or a path pointing to a valid instance file

  • <method> is the planning method to use (one of slp, drp, replan) or a path to a valid config file

  • <episodes> is the (optional) number of episodes to evaluate the final policy.

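The <method> parameter describes the type of planning representation:

  • slp: a straight-line (open-loop) plan, optimized offline

  • drp: a deep reactive (closed-loop) policy network, optimized offline

  • replan: a straight-line plan that is re-optimized during execution (replanning)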

For example, the following will execute an open-loop controller to fly 4 drones:

jaxplan plan Quadcopter 1 slp

From Python#

To initialize and run an open-loop controller in Python:

import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxStraightLinePlan, JaxBackpropPlanner, JaxOfflineController

# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)

# create the planning algorithm
plan = JaxStraightLinePlan(**plan_args)
planner = JaxBackpropPlanner(rddl=env.model, plan=plan, **planner_args)
controller = JaxOfflineController(planner, **train_args)

# evaluate the planner
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()

The **plan_args, **planner_args and **train_args are keyword arguments passed during initialization, but we strongly recommend using configuration files as discussed in the next section.

Note

All controllers are instances of pyRDDLGym’s BaseAgent and support the evaluate() function.

Related example (Jupyter notebook): Open-loop planning with straightline plans in JaxPlan.

To use periodic replanning, simply change the controller type to:

from pyRDDLGym_jax.core.planner import JaxOnlineController
controller = JaxOnlineController(planner, **train_args)

Related example (Jupyter notebook): Closed-loop replanning with JaxPlan.

Related example (Jupyter notebook): POMDP planning from images with JaxPlan.

To use a deep reactive policy, simply change the plan type to:

from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy
plan = JaxDeepReactivePolicy(**plan_args)

Related example (Jupyter notebook): Closed-loop planning with deep reactive policies in JaxPlan.

Note

JaxStraightLinePlan and JaxDeepReactivePolicy are subclasses of the abstract class JaxPlan. Other policy representations can be defined by subclassing JaxPlan and implementing its abstract methods.

Configuring JaxPlan#

The recommended way to manage planner settings is to write a configuration file with all required hyper-parameters.

Configuration Files#

As of JaxPlan version 3.0, the configuration file contains three sections:

  • [Compiler] dictates how RDDL expressions are translated to JAX

  • [Planner] specifies the type of plan or policy, its hyper-parameters, optimizer, etc.

  • [Optimize] specifies budget on iterations, time limit, stopping rule, etc.

For straight-line planning, below is an example of a working configuration file:

[Compiler]
method='DefaultJaxRDDLCompilerWithGrad'
sigmoid_weight=20

[Planner]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}

[Optimize]
key=42
epochs=5000
train_seconds=30

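The config format above looks like an INI file whose values are Python literals. As a hypothetical sanity check (this is not how load_config is implemented), such a file can be read with the standard library's configparser and ast.literal_eval; the function name parse_config is an illustrative assumption.

```python
import ast
import configparser

CONFIG = """
[Compiler]
method='DefaultJaxRDDLCompilerWithGrad'
sigmoid_weight=20

[Planner]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': 0.001}

[Optimize]
key=42
epochs=5000
train_seconds=30
"""

def parse_config(text):
    """Parse an INI-style config whose values are Python literals (sketch)."""
    parser = configparser.ConfigParser()
    parser.optionxform = str  # keep option names case-sensitive
    parser.read_string(text)
    return {section: {name: ast.literal_eval(value)
                      for name, value in parser.items(section)}
            for section in parser.sections()}

config = parse_config(CONFIG)
print(config["Planner"]["optimizer_kwargs"]["learning_rate"])
```

Abstract templates containing keywords such as TUNABLE_WEIGHT (see the tuning section) are not valid literals, so they would have to be instantiated before a check like this.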
To use a policy network with two hidden layers of size 128:

[Planner]
method='JaxDeepReactivePolicy'
method_kwargs={'topology': [128, 128]}

To use replanning with a rollout horizon of 5:

[Planner]
rollout_horizon=5

The following tables list the parameters that can be set in each section (as of version 3.0):

Possible config parameters under [Compiler]
[Compiler] settings for all JaxRDDLCompilerWithGrad instances#

Setting

Description

allow_synchronous_state

Whether next-state variables are allowed to depend on other next-state variables

cpfs_without_grad

Set of CPFs whose gradients are to be ignored (using the straight-through estimator, STE)

method

Type of core.logic.JaxRDDLCompilerWithGrad that defines the translation from RDDL to JAX

print_warnings

Whether to print compilation warnings

stochastic_is_fluent

Whether traced stochastic nodes are seen as fluent even if all arguments are not

use64bit

Whether to use 64 bit arithmetic

[Compiler] settings for DefaultJaxRDDLCompilerWithGrad#

Setting

Description

argmax_weight

Controls strength of softmax relaxation of argmax and argmin operators

bernoulli_sigmoid_weight

Controls strength of sigmoid relaxation of Bernoulli

binomial_eps

Underflow correction of Binomial

binomial_nbins

Maximum bins for Binomial relaxation before switching to Normal approximation

binomial_softmax_weight

Controls strength of softmax relaxation of Binomial

discrete_eps

Underflow correction of Discrete

discrete_softmax_weight

Controls strength of softmax relaxation of Discrete

floor_weight

Controls strength of tanh relaxation of floor and ceil operators

geometric_eps

Underflow correction of Geometric

geometric_floor_weight

Controls strength of tanh relaxation of floor operator in Geometric

poisson_comparison_weight

Controls strength of exponential approximation of Poisson

poisson_min_cdf

Controls when to use exponential or Normal approximation of Poisson

poisson_nbins

Maximum bins for Poisson relaxation before switching to Normal approximation

round_weight

Controls strength of tanh relaxation of round operators

sigmoid_weight

Controls strength of sigmoid/tanh relaxation of relational operators

sqrt_eps

Underflow correction of sqrt operators

switch_weight

Controls strength of softmax relaxation of switch operators

use_floor_ste

Whether to use STE for floor relaxation

use_if_else_ste

Whether to use STE for if-then-else relaxation

use_logic_ste

Whether to use STE for relaxation of logical operators

use_round_ste

Whether to use STE for round relaxation

use_sigmoid_ste

Whether to use STE for sigmoid-relaxed operators

use_tanh_ste

Whether to use STE for tanh-relaxed operators (e.g. sign)

Possible config parameters under [Planner]
[Planner]#

Setting

Description

action_bounds

Dict of (lower, upper) bound tensors for each action-fluent

batch_size_test

Batch size for evaluation

batch_size_train

Batch size for training

clip_grad

Bound on gradient magnitude

dashboard

Whether to show a dashboard with training progress

ema_decay

Decay rate of EMA of policy parameters

line_search_kwargs

Arguments for zoom line search

method

Type of core.planner.JaxPlan that specifies the policy class

method_kwargs

Arguments for policy constructor (see next tables for options)

noise_kwargs

Arguments for gradient noise

optimizer

Name of optimizer from optax

optimizer_kwargs

Arguments for optimizer constructor such as learning_rate

parallel_updates

Number of independent policies to optimize in parallel

pgpe

Type of core.planner.PGPE for parameter-exploring policy gradient update

pgpe_kwargs

Arguments for PGPE constructor (see table below for default choices)

preprocessor

Type of core.planner.Preprocessor for input preprocessing such as normalization

preprocessor_kwargs

Arguments for preprocessor constructor

rollout_horizon

Rollout horizon of the computation graph

use_symlog_reward

Whether to apply symlog transform to returns

utility

Utility function to optimize

utility_kwargs

Arguments for utility such as hyper-parameters

Possible method_kwargs arguments for JaxStraightLinePlan#

Setting

Description

initializer

Type of jax.nn.initializers

initializer_kwargs

Arguments for initializer constructor

max_constraint_iter

Maximum iterations of gradient projection

min_action_prob

Minimum bound on boolean action to avoid sigmoid saturation

use_new_projection

Whether to use new sorting gradient projection for boolean action preconditions

wrap_non_bool

Whether to wrap non-boolean actions with nonlinearity for box constraints

wrap_sigmoid

Whether to wrap boolean actions with sigmoid

wrap_softmax

Whether to wrap boolean actions with softmax to satisfy max-nondef-actions

Possible method_kwargs arguments for JaxDeepReactivePolicy#

Setting

Description

activation

Activation for hidden layers in jax.numpy or jax.nn

initializer

Type of haiku.initializers

initializer_kwargs

Arguments for initializer constructor

normalize

Whether to apply layer norm to inputs

normalize_per_layer

Whether to apply layer norm to each input individually

normalizer_kwargs

Arguments for haiku.LayerNorm constructor

softmax_output_weight

Weight for softmax projection in cardinality constraints

time_dependent

Whether the policy is time dependent

time_embedding_dim

Dimension of the time embedding when time_dependent

topology

List specifying number of neurons per hidden layer

wrap_non_bool

Whether to wrap non-boolean actions with nonlinearity for box constraints

Possible pgpe_kwargs arguments for GaussianPGPE#

Setting

Description

batch_size

Number of parameters to sample per gradient descent step

end_entropy_coeff

Ending entropy regularization coefficient

init_sigma

Initial standard deviation of meta policy

max_kl_update

Maximum bound on KL divergence between successive updates

min_reward_scale

Minimum scaling factor for scale_reward

optimizer

Name of optimizer from optax

optimizer_kwargs_mu

Arguments for optimizer constructor for mean such as learning_rate

optimizer_kwargs_sigma

Arguments for optimizer constructor for std such as learning_rate

scale_reward

Whether to apply reward scaling in parameter updates

sigma_range

Clipping bounds for standard deviation of meta policy

start_entropy_coeff

Starting entropy regularization coefficient

super_symmetric

Whether to use super-symmetric sampling for standard deviation

super_symmetric_accurate

Whether to use the accurate formula for super symmetric sampling in the paper

Possible config parameters under [Optimize]
[Optimize]#

Setting

Description

epochs

Maximum number of iterations

key

RNG seed for JAX

model_params

Dict of hyper-parameter values for the model relaxation

policy_hyperparams

Dict of hyper-parameter values for the policy

print_hyperparams

Whether to print the planner hyper-parameters

print_progress

Whether to show the progress bar

print_summary

Whether to print the planner summary

stopping_rule

Type of JaxPlannerStoppingRule for stopping the optimizer

stopping_rule_kwargs

Arguments for stopping rule constructor

test_rolling_window

Smoothing window for test return calculation

train_seconds

Maximum seconds to iterate

Using Configuration Files#

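Configuration files can be parsed and passed to the planner and controller as in the basic example (the parsed planner_args already include the plan specified in the [Planner] section):

import pyRDDLGym
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOfflineController, load_config
env = pyRDDLGym.make("domain", "instance", vectorized=True)
planner_args, _, train_args = load_config("/path/to/config")
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
controller = JaxOfflineController(planner, **train_args)
controller.evaluate(env, episodes=1, verbose=True, render=True)
env.close()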

Constraints on Action-Fluents#

Boolean Action-Fluents#

By default, boolean actions are wrapped using the sigmoid function:

\[a = \frac{1}{1 + e^{-w \theta}},\]

where \(\theta\) are the trainable action parameters and \(w\) is a hyper-parameter controlling the sharpness. At test time, the action is aliased by evaluating the expression \(a > 0.5\), or equivalently \(\theta > 0\). This setting can be controlled in JaxPlan by setting wrap_sigmoid.
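A small sketch of this wrapping in plain Python (illustrative, not JaxPlan's code): the relaxed action is \(\sigma(w\theta)\), its gradient \(w a (1 - a)\) shrinks as the sigmoid saturates, and the test-time aliasing \(a > 0.5\) agrees with \(\theta > 0\) for any \(w > 0\). The numerically stable branching and the function names are choices of the sketch.

```python
import math

def relaxed_bool_action(theta, w):
    """Training-time action a = sigmoid(w * theta), computed stably."""
    z = w * theta
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def relaxed_grad(theta, w):
    """da/dtheta = w * a * (1 - a); it vanishes as the sigmoid saturates."""
    a = relaxed_bool_action(theta, w)
    return w * a * (1.0 - a)

def boolean_action(theta, w):
    """Action used at test time: a > 0.5, i.e. theta > 0 when w > 0."""
    return relaxed_bool_action(theta, w) > 0.5

for w in (1.0, 20.0):
    print(w, relaxed_bool_action(0.5, w), relaxed_grad(0.5, w), boolean_action(0.5, w))
```

A larger w makes the relaxed action closer to the boolean one but leaves less gradient once \(\theta\) moves away from zero, which is why the sharpness is exposed as a hyper-parameter.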

Warning

If wrap_sigmoid=True, the weight w for each boolean action fluent should be specified in the policy_hyperparams dictionary.

Box Constraints#

Box constraints are useful for bounding each action fluent independently within some range. Box constraints typically do not need to be specified manually, since they are automatically parsed from the action_preconditions in the RDDL domain description.

However, it is possible to override these bounds by passing a dictionary of bounds for each action fluent into the action_bounds argument. The syntax for specifying optional box constraints in the config is:

[Planner]
action_bounds={'<action_fluent1>': (lower1, upper1), '<action_fluent2>': (lower2, upper2), ...}

where lower# and upper# can be any list, nested list or array.

Note

By default, box constraints are enforced using projected gradient. An alternative approach applies a differentiable transformation to action fluents. In JaxPlan, this can be controlled by setting wrap_non_bool.
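The two approaches can be sketched for a single real-valued action as follows. This is plain Python; the tanh rescaling is one common differentiable transformation and is an assumption, not necessarily the one JaxPlan uses.

```python
import math

def project(a, lower, upper):
    """Euclidean projection of a scalar onto the box [lower, upper]."""
    return min(max(a, lower), upper)

def projected_step(a, grad, lr, lower, upper):
    """Projected gradient ascent: unconstrained step, then clip back into the box."""
    return project(a + lr * grad, lower, upper)

def wrap(theta, lower, upper):
    """Map an unconstrained theta into [lower, upper] through a rescaled tanh."""
    return lower + (upper - lower) * (math.tanh(theta) + 1.0) / 2.0

print(projected_step(0.9, grad=5.0, lr=0.1, lower=-1.0, upper=1.0))  # clipped to 1.0
print(wrap(0.0, -1.0, 1.0), wrap(10.0, 0.0, 5.0))
```

Projection leaves the parameterization unchanged and only corrects infeasible steps, whereas wrapping makes every \(\theta\) feasible at the cost of vanishing gradients as the action approaches a bound.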

Concurrency#

Cardinality constraints are of the form \(\sum_i a_i \leq B\) where \(B\) is max-nondef-actions in the RDDL instance.

Note

For SLPs, JaxBackpropPlanner will automatically apply projected gradient to satisfy constraints at each optimization step. For DRPs, JaxBackpropPlanner will automatically use a differentiable top-k projection.
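For intuition, here is one standard way to project relaxed boolean actions in [0, 1] onto the cardinality constraint: shift all values down by a common amount \(\mu \geq 0\) and clip, choosing \(\mu\) by bisection. This plain-Python sketch is illustrative only; it is not claimed to be the sorting-based projection JaxPlan applies for SLPs, nor the top-k projection used for DRPs.

```python
def clip01(v):
    return min(max(v, 0.0), 1.0)

def project_cardinality(x, B, iters=100):
    """Euclidean projection of x onto {a in [0, 1]^n : sum(a) <= B}."""
    a = [clip01(v) for v in x]
    if sum(a) <= B:
        return a  # projecting onto the box alone already satisfies the constraint
    # Otherwise a_i = clip01(x_i - mu) for the mu >= 0 with sum(a) == B.
    # The sum is non-increasing in mu, so bisect on [0, max(x)].
    lo, hi = 0.0, max(x)
    for _ in range(iters):
        mu = (lo + hi) / 2.0
        if sum(clip01(v - mu) for v in x) > B:
            lo = mu
        else:
            hi = mu
    return [clip01(v - hi) for v in x]  # hi always keeps the sum <= B

a = project_cardinality([0.9, 0.8, 0.7, 0.1], B=1)
print([round(v, 3) for v in a], round(sum(a), 6))
```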

Automatically Tuning Hyper-Parameters#

JaxPlan provides a Bayesian optimization algorithm for automatically tuning hyper-parameters:

  • supports multi-processing by evaluating multiple hyper-parameter settings in parallel

  • leverages Bayesian optimization to search the hyper-parameter space more efficiently

  • supports all types of policies that use config files.

From the Command Line#

The command line app runs the automated tuning on several key hyper-parameters:

jaxplan tune <domain> <instance> <method> <trials> <iters> <workers> <dashboard>

where:

  • domain and instance specify the domain and instance names

  • method is the planning method (one of slp, drp, replan)

  • trials is the (optional) number of trials/episodes to average in evaluating each hyper-parameter setting

  • iters is the (optional) maximum number of iterations/evaluations of Bayesian optimization to perform

  • workers is the (optional) number of parallel evaluations to be done at each iteration, i.e. each iteration evaluates up to workers hyper-parameter settings, for a total of up to trials * workers episodes

  • dashboard is whether the optimizations are tracked and displayed in a dashboard application.

From Python#

To customize the hyper-parameter tuning algorithm in detail, first create an abstract config file, where concrete hyper-parameters to tune are replaced by keywords. To tune the sigmoid relaxation in the compiler and the optimizer learning rate, for example:

[Compiler]
method='DefaultJaxRDDLCompilerWithGrad'
sigmoid_weight=TUNABLE_WEIGHT

[Planner]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
...

Warning

During tuning, keywords are replaced by concrete values via simple string matching. Therefore, you must select keywords not appearing (as substrings) in any other parts of the config file.

Next, for each config variable, specify its search range and transformation to apply:

from pyRDDLGym_jax.core.tuning import JaxParameterTuning, Hyperparameter
from pyRDDLGym_jax.core.planner import load_config_from_string

# load env as usual
...

# load the abstract config file with planner settings
with open('path/to/config', 'r') as file:
    config_template = file.read()

# map parameters in the config that will be tuned
def power_10(x):
    return 10.0 ** x
hyperparams = [Hyperparameter("TUNABLE_WEIGHT", -1., 5., power_10),
               Hyperparameter("TUNABLE_LEARNING_RATE", -5., 1., power_10)]

# build the tuner and tune (online=False tunes the planner without replanning)
tuning = JaxParameterTuning(env=env, config_template=config_template, hyperparams=hyperparams,
                            online=False, eval_trials=trials, num_workers=workers, gp_iters=iters)
tuning.tune(key=42, log_file="path/to/logfile")

# parse the concrete config file with the best tuned values, and evaluate as usual
planner_args, _, train_args = load_config_from_string(tuning.best_config)
...

JaxPlan supports tuning most numeric parameters in the config file. To tune hyper-parameters for replanning, set online=True.
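The following plain-Python sketch mimics what a single tuning trial might do with the abstract template, under the assumption that each Hyperparameter(name, lower, upper, transform) describes a search range and a transformation applied before substitution: draw a point in the range, map it through power_10, and replace the keyword by plain string matching. SEARCH_SPACE and instantiate are hypothetical names, not part of JaxPlan.

```python
import random

TEMPLATE = """[Compiler]
method='DefaultJaxRDDLCompilerWithGrad'
sigmoid_weight=TUNABLE_WEIGHT

[Planner]
method='JaxStraightLinePlan'
optimizer='rmsprop'
optimizer_kwargs={'learning_rate': TUNABLE_LEARNING_RATE}
"""

def power_10(x):
    return 10.0 ** x

# hypothetical search space: keyword -> (lower, upper, transformation)
SEARCH_SPACE = {
    "TUNABLE_WEIGHT": (-1.0, 5.0, power_10),
    "TUNABLE_LEARNING_RATE": (-5.0, 1.0, power_10),
}

def instantiate(template, point):
    """Substitute transformed values for keywords by plain string replacement."""
    config = template
    for keyword, x in point.items():
        lower, upper, transform = SEARCH_SPACE[keyword]
        if not lower <= x <= upper:
            raise ValueError(f"{keyword}={x} outside [{lower}, {upper}]")
        config = config.replace(keyword, repr(transform(x)))
    return config

rng = random.Random(0)
point = {k: rng.uniform(lo, hi) for k, (lo, hi, _) in SEARCH_SPACE.items()}
print(instantiate(TEMPLATE, point))
```

Because substitution is plain string replacement, a keyword such as TUNABLE_WEIGHT would also overwrite part of a hypothetical TUNABLE_WEIGHT_2, which is the pitfall described in the warning above.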

Possible settings for JaxParameterTuning
JaxParameterTuning constructor arguments#

Setting

Description

acquisition

AcquisitionFunction object for the Gaussian process

config_template

Config file content with abstract parameters to tune as described above

env

The RDDLEnv instance

eval_trials

Number of independent trials/rollouts to evaluate each hyper-parameter combination

gp_init_kwargs

Optional keyword arguments to pass to the Gaussian process constructor

gp_iters

Number of rounds of tuning to perform

gp_params

Optional additional keyword arguments to pass to the Gaussian process (e.g. kernel)

hyperparams

List of Hyperparameter objects

num_workers

Number of parallel evaluations to perform in each round of tuning

online

Whether to use replanning mode for tuning

poll_frequency

How often to check for completed processes (defaults to 0.2 seconds)

pool_context

The type of pool context for multiprocessing (defaults to “spawn”)

rollouts_per_trial

For online=False, how many evaluation rollouts to perform per eval_trial

timeout_tuning

Maximum amount of time to allocate to tuning

verbose

Whether to print intermediate results to the standard console

Related example (Jupyter notebook): Tuning policy hyper-parameters in JaxPlan.

Visualizing with Dashboard#

As of version 1.0, the embedded visualization tools have been replaced with a plotly dashboard, offering a more comprehensive way to introspect trained policies. To activate the dashboard for planning, simply add the following line in the config file:

[Planner]
dashboard=True

Risk-Aware Planning with Utility Optimization#

By default, JaxPlan optimizes the expected discounted sum of future rewards, which may not be desirable for risk-sensitive applications. JaxPlan can also optimize several non-linear utility functions:

  • “mean” is the risk-neutral or ordinary expected return

  • “mean_std” is the standard deviation penalized return

  • “mean_var” is the variance penalized return

  • “mean_semidev” is the mean-semideviation risk measure

  • “mean_semivar” is the mean-semivariance risk measure

  • “sharpe” is the Sharpe ratio

  • “entropic” (or “exponential”) is the entropic or exponential utility

  • “var” is the value at risk

  • “cvar” is the conditional value at risk.

A utility function can be specified by passing one of the strings above to the utility argument of the planner, and an optional dict of hyper-parameters to the utility_kwargs argument, e.g. for CVaR at the 5 percent level:

[Planner]
utility='cvar'
utility_kwargs={'alpha': 0.05}
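For reference, a few of these utilities can be written down directly over a sample of returns. The sketch below uses common textbook definitions (mean-variance with penalty \(\beta\), entropic utility \(-\frac{1}{\beta}\log \mathbb{E}[e^{-\beta R}]\), and CVaR as the mean of the worst \(\alpha\) fraction of returns); JaxPlan's exact parameterizations and any argument names other than alpha are assumptions here.

```python
import math

def mean(returns):
    return sum(returns) / len(returns)

def mean_var(returns, beta=0.1):
    """Variance-penalized return: E[R] - beta * Var[R]."""
    m = mean(returns)
    var = sum((r - m) ** 2 for r in returns) / len(returns)
    return m - beta * var

def entropic(returns, beta=1.0):
    """-(1/beta) * log E[exp(-beta * R)], shifted by min(R) to avoid overflow."""
    shift = min(returns)
    inner = mean([math.exp(-beta * (r - shift)) for r in returns])
    return shift - math.log(inner) / beta

def cvar(returns, alpha=0.05):
    """Average of the worst ceil(alpha * n) returns (lower tail)."""
    k = max(1, math.ceil(alpha * len(returns)))
    return mean(sorted(returns)[:k])

returns = [10.0, 12.0, 8.0, -20.0, 11.0, 9.0, 10.0, 13.0, 7.0, 10.0]
print(mean(returns), mean_var(returns), entropic(returns, beta=0.5), cvar(returns, alpha=0.1))
```

A single bad outcome barely moves the mean but dominates CVaR at small alpha, which is what makes these utilities useful for risk-sensitive planning.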

The utility function could also be provided explicitly as a function mapping a JAX array to a scalar, with additional arguments specifying the hyper-parameters of the utility function referred to by name:

import jax

@jax.jit
def my_utility_function(x, aversion: float=1.0) -> float:
    return ...  # map the JAX array x to a scalar utility
planner = JaxBackpropPlanner(..., utility=my_utility_function, utility_kwargs={'aversion': 2.0})

Related example (Jupyter notebook): Risk-aware planning with RAPTOR in JaxPlan.

Dealing with Non-Differentiability#

Model Relaxations#

Many RDDL programs contain expressions that do not support derivatives. A common technique to deal with this is to approximate non-differentiable operations using similar differentiable ones.

For instance, consider the following problem of classifying points (x, y) in 2D-space as +1 if they lie in the top-right or bottom-left quadrants, and -1 otherwise:

def classify(x, y):
    if x > 0 and y > 0 or not x > 0 and not y > 0:
        return +1
    else:
        return -1

Relational expressions such as x > 0 and y > 0, and logical operators such as and and or, do not have obvious derivatives. To complicate matters further, the output of the if statement is piecewise constant in x and y, so its partial derivatives are zero almost everywhere and undefined along the axes, providing no useful gradient signal.

JaxPlan works around these limitations by approximating such operations with JAX expressions that support derivatives. The JaxRDDLCompilerWithGrad class defines how relaxations are performed; it is highly configurable and can be subclassed. A custom compiler class can be passed to the planner by specifying:

[Compiler]
method='MyJaxRDDLCompilerWithGradType'
method_kwargs=...

The default compiler, DefaultJaxRDDLCompilerWithGrad, implements a variety of differentiable relaxations from the literature. They have been tuned to work well in practice and continue to be improved with each new release.

DefaultJaxRDDLCompilerWithGrad relaxation rules#

Exact RDDL Operation

Approximate Operation

^, &, |, ~, forall, exists, etc.

Fuzzy t-norm logic

==, >, <, >=, <=, sgn, etc.

Tanh and Sigmoid

argmax, argmin

Softmax

floor, div, mod, etc.

SoftFloor

round

SoftRound

if-then-else

Linear

switch

Softmax

Bernoulli, Discrete

Gumbel-Softmax or Sigmoid

Geometric

SoftFloor

Binomial

Gumbel-Softmax for small population, Normal for large population

Poisson

rsample for small rate, Normal for large rate
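To see how these rules fit together, here is a hand-written relaxation of the classify example above in plain Python. The specific formulas (sigmoid for >, product t-norm for and, probabilistic sum for or, 1 - p for not, a linear blend for if-then-else) are illustrative choices consistent with the table, not necessarily the exact ones in DefaultJaxRDDLCompilerWithGrad, and the weight w = 10 is an assumed hyper-parameter.

```python
import math

def classify(x, y):
    if x > 0 and y > 0 or not x > 0 and not y > 0:
        return +1
    else:
        return -1

def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def soft_gt(x, w):        # relaxes x > 0
    return sigmoid(w * x)

def soft_and(p, q):       # product t-norm
    return p * q

def soft_or(p, q):        # probabilistic sum, dual of the product t-norm
    return p + q - p * q

def soft_not(p):
    return 1.0 - p

def soft_if(c, then_val, else_val):  # linear if-then-else
    return c * then_val + (1.0 - c) * else_val

def soft_classify(x, y, w=10.0):
    px, py = soft_gt(x, w), soft_gt(y, w)
    cond = soft_or(soft_and(px, py), soft_and(soft_not(px), soft_not(py)))
    return soft_if(cond, +1.0, -1.0)

print(soft_classify(2.0, 3.0), soft_classify(-2.0, 3.0), soft_classify(0.05, 0.05))
```

Larger w brings soft_classify closer to classify away from the axes but sharpens its gradients near them, which is the approximation-quality trade-off controlled by the relaxation hyper-parameters discussed next.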

Some relaxations naturally introduce hyper-parameters to control the quality of the approximation. These hyper-parameters can be retrieved and modified programmatically as follows:

model_params = planner.compiled.model_params
model_params[key] = ...
planner.optimize(..., model_params=model_params)

Parameter-Exploring Policy Gradient#

Since version 2.0, JaxPlan runs a parallel instance of parameter-exploring policy gradients (PGPE). In some cases, this allows JaxPlan to continue making progress when the model relaxations are poor or the gradient descent optimizer fails to make progress.

It is enabled by default, but can be configured in the config file as follows:

[Planner]
pgpe='GaussianPGPE'
pgpe_kwargs={'optimizer_kwargs_mu': {'learning_rate': 0.01}, 'optimizer_kwargs_sigma': {'learning_rate': 0.01}}
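The core idea of PGPE can be sketched in plain Python: sample policy parameters from a Gaussian meta-policy, evaluate their returns, and update the mean and standard deviation from the returns alone, so no model gradient is required. This is a minimal symmetric-sampling variant for illustration; JaxPlan's GaussianPGPE has additional options (see the pgpe_kwargs table), and the learning rates, batch size, sigma range and toy objective here are assumptions.

```python
import random

def pgpe(f, mu, sigma, lr_mu=0.1, lr_sigma=0.02, batch_size=16, iters=200,
         sigma_range=(0.01, 1.0), seed=0):
    """Maximize f by symmetric Gaussian PGPE over the meta-policy N(mu, diag(sigma^2))."""
    rng = random.Random(seed)
    mu, sigma = list(mu), list(sigma)
    n = len(mu)
    lo, hi = sigma_range
    for _ in range(iters):
        # sample symmetric parameter perturbations and evaluate both sides
        batch = []
        for _ in range(batch_size):
            eps = [rng.gauss(0.0, s) for s in sigma]
            r_plus = f([m + e for m, e in zip(mu, eps)])
            r_minus = f([m - e for m, e in zip(mu, eps)])
            batch.append((eps, r_plus, r_minus))
        baseline = sum((rp + rm) / 2.0 for _, rp, rm in batch) / batch_size
        grad_mu = [0.0] * n
        grad_sigma = [0.0] * n
        for eps, rp, rm in batch:
            for i in range(n):
                # symmetric estimate of d/dmu_i E[f(mu + eps)]
                grad_mu[i] += (rp - rm) / 2.0 * eps[i] / sigma[i] ** 2
                # likelihood-ratio estimate of d/dsigma_i E[f(mu + eps)], with baseline
                grad_sigma[i] += (((rp + rm) / 2.0 - baseline)
                                  * (eps[i] ** 2 - sigma[i] ** 2) / sigma[i] ** 3)
        # gradient ascent on mu and sigma; sigma is clipped to sigma_range
        mu = [m + lr_mu * g / batch_size for m, g in zip(mu, grad_mu)]
        sigma = [min(max(s + lr_sigma * g / batch_size, lo), hi)
                 for s, g in zip(sigma, grad_sigma)]
    return mu, sigma

def objective(theta):
    # toy return with its maximum at theta = (3, 3)
    return -sum((t - 3.0) ** 2 for t in theta)

mu, sigma = pgpe(objective, mu=[0.0, 0.0], sigma=[0.5, 0.5])
print(mu, sigma)
```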

Third-Party Optimizers#

Gradient-free methods such as global optimization may work better when gradients are uninformative. As of version 0.3, it is possible to export the optimization problem so that it can be solved by another optimizer, such as SciPy:

loss_fn, grad_fn, guess, unravel_fn = planner.as_optimization_problem()

The loss function loss_fn and gradient map grad_fn express policy parameters as 1D numpy arrays, so they can be used as inputs for other packages that do not make use of JAX. The unravel_fn allows the 1D array to be mapped back to a JAX pytree.

Related example (Jupyter notebook): Building an optimization problem for third-party optimizers.
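The exported interface (a scalar loss and its gradient over a flat 1D parameter vector, an initial guess, and a function mapping the vector back to structured parameters) can be illustrated without JaxPlan. In this plain-Python sketch, ravel, the toy quadratic loss_fn and grad_fn, and gradient_descent are hypothetical stand-ins. With the real exported problem one would typically pass loss_fn, grad_fn and guess to an optimizer such as scipy.optimize.minimize(loss_fn, guess, jac=grad_fn), assuming loss_fn is meant to be minimized.

```python
def ravel(tree):
    """Flatten a dict of lists into one list; return it with an inverse map."""
    keys = sorted(tree)
    flat = [v for k in keys for v in tree[k]]
    sizes = [(k, len(tree[k])) for k in keys]

    def unravel(vec):
        out, i = {}, 0
        for k, size in sizes:
            out[k] = list(vec[i:i + size])
            i += size
        return out

    return flat, unravel

# hypothetical policy parameters and a toy quadratic loss on the flat vector
params = {"bias": [0.0], "weights": [0.0, 0.0]}
guess, unravel_fn = ravel(params)
target = [1.0, -2.0, 0.5]

def loss_fn(x):
    return sum((xi - ti) ** 2 for xi, ti in zip(x, target))

def grad_fn(x):
    return [2.0 * (xi - ti) for xi, ti in zip(x, target)]

def gradient_descent(loss, grad, x0, lr=0.1, iters=200):
    """Any optimizer needing only (loss, gradient, initial guess) fits here."""
    x = list(x0)
    for _ in range(iters):
        x = [xi - lr * gi for xi, gi in zip(x, grad(x))]
    return x

best = unravel_fn(gradient_descent(loss_fn, grad_fn, guess))
print(best)
```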

Limitations#

The current version of JaxPlan has several limitations:

  • Not all operations have natural differentiable relaxations or are supported by the compiler:
    • nested fluents such as fluent1(fluent2(?p))

    • Multinomial sampling

  • Some relaxations can accumulate high error:
    • this is particularly problematic for long rollout horizons, so we recommend reducing or tuning the horizon

    • model relaxations and hyper-parameters can be tuned for optimal results

  • Some relaxations may not be mathematically consistent with one another:
    • trichotomy of comparisons, e.g. the relaxations of a == b, a > b and a < b do not necessarily “sum” to one, but in most cases should be close

    • it is recommended to override operations in the compiler if this is a concern

  • Termination conditions and complex (e.g. nonlinear) state or action constraints are not included in the optimization:
    • constraints can be logged in the optimizer callback and used during optimization (e.g. to build Lagrangians)

  • Optimizer can fail to make progress when the problem is largely discrete:
    • to diagnose, monitor and compare the training loss and the test loss over time
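A small numeric illustration of the trichotomy issue above, with hypothetical relaxations (sigmoids for > and <, and a Gaussian bump for ==) that are not necessarily the ones JaxPlan uses:

```python
import math

def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def soft_gt(a, b, w=10.0):
    return sigmoid(w * (a - b))

def soft_lt(a, b, w=10.0):
    return sigmoid(w * (b - a))

def soft_eq(a, b, w=10.0):
    # bump equal to 1 when a == b, decaying as |a - b| grows
    return math.exp(-(w * (a - b)) ** 2)

for a, b in [(1.0, 1.0), (1.0, 1.1), (1.0, 3.0)]:
    total = soft_gt(a, b) + soft_lt(a, b) + soft_eq(a, b)
    print(a, b, round(total, 4))
```

Here the relaxed > and < always sum to one, so the total exceeds one by exactly the mass of the relaxed ==, which is largest when a and b are close and negligible when they are far apart.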

The goal of JaxPlan is to provide a standard planning baseline that can be easily built upon. We also welcome any suggestions or modifications about how to improve the robustness of JaxPlan on a broader subset of RDDL.

Citation#

If you use the code provided by JaxPlan, please cite it using the following BibTeX entry:

@inproceedings{
    gimelfarb2024jaxplan,
    title={JaxPlan and GurobiPlan: Optimization Baselines for Replanning in Discrete and Mixed Discrete and Continuous Probabilistic Domains},
    author={Michael Gimelfarb and Ayal Taitler and Scott Sanner},
    booktitle={34th International Conference on Automated Planning and Scheduling},
    year={2024},
    url={https://openreview.net/forum?id=7IKtmUpLEH}
}