Accelerating simulation with JAX

In this example, we show how to compile and simulate a pyRDDLGym environment using the JAX backend.

%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import time
import matplotlib.pyplot as plt
import numpy as np

import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

Does the JAX backend produce the same simulation result as the pyRDDLGym simulator?

We first check that the JAX backend produces the same outputs as the regular pyRDDLGym simulation backend, up to floating-point error. Let’s make two copies of the same deterministic environment, one for each backend:

env = pyRDDLGym.make('RaceCar_ippc2023', '3', backend=JaxRDDLSimulator)
base_env = pyRDDLGym.make('RaceCar_ippc2023', '3')

Let’s generate actions using the random agent:

agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
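The random agent draws a fresh action at every step. As a rough illustration of the idea (not pyRDDLGym's implementation), the hypothetical helpers below sample each action component uniformly from assumed box bounds using a seeded generator, so two generators built with the same seed replay the same action sequence. The action names `fx`, `fy` and their bounds are made up for this example and are not taken from the RaceCar domain.

```python
import random

# Hypothetical action bounds: names and ranges are made up for illustration,
# not read from the RaceCar domain.
BOUNDS = {'fx': (-1.0, 1.0), 'fy': (-1.0, 1.0)}


def sample_action(rng: random.Random, bounds: dict) -> dict:
    """Draw one action, sampling each component uniformly within its bounds."""
    return {name: rng.uniform(low, high) for name, (low, high) in bounds.items()}


def sample_sequence(seed: int, length: int) -> list:
    """Draw a reproducible sequence of actions from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [sample_action(rng, BOUNDS) for _ in range(length)]


# The same seed replays the same actions.
print(sample_sequence(42, 3) == sample_sequence(42, 3))  # True
```

In the comparison that follows, reproducibility across runs is not what makes the two backends comparable; what matters is that each sampled action is fed to both environments.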

Let’s simulate both copies of the environment side by side, applying the same randomly sampled action to each at every step, and compare the trajectories:

state, _ = env.reset()
base_state, _ = base_env.reset()
max_diff = 0.0
for step in range(env.horizon):
    # largest absolute gap between corresponding state fluents seen so far
    max_diff = max(max_diff, max(abs(state[key] - base_state[key]) for key in state))
    # sample one random action and apply the same action to both environments
    action = agent.sample_action(state)
    state, reward, *_ = env.step(action)
    base_state, base_reward, *_ = base_env.step(action)
print(f'the maximum difference in the state was {max_diff}')
the maximum difference in the state was 1.5356283866729825e-07
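Rather than inspecting `max_diff` by eye, one option is to turn the comparison into a pass/fail check with explicit tolerances. The sketch below uses plain Python dictionaries standing in for the state dictionaries returned by `reset()` and `step()`; the helper names `max_abs_diff` and `states_match`, the tolerance values, and the example states are all hypothetical.

```python
import math


def max_abs_diff(state: dict, base_state: dict) -> float:
    """Largest absolute difference between matching entries of two state dicts."""
    return max(abs(state[key] - base_state[key]) for key in state)


def states_match(state: dict, base_state: dict,
                 rel_tol: float = 1e-5, abs_tol: float = 1e-6) -> bool:
    """True if both dicts have the same keys and every entry agrees within tolerance."""
    if state.keys() != base_state.keys():
        return False
    return all(math.isclose(state[k], base_state[k], rel_tol=rel_tol, abs_tol=abs_tol)
               for k in state)


# Made-up states: the second differs from the first by roughly float32 rounding error.
s1 = {'x': 0.5, 'y': -1.25, 'vx': 0.1}
s2 = {'x': 0.5 + 1.5e-7, 'y': -1.25, 'vx': 0.1 - 2e-8}
print(max_abs_diff(s1, s2))
print(states_match(s1, s2))  # True
```

Inside the loop above, the same check could be applied to `state` and `base_state` at each step; the tolerances should be chosen to sit comfortably above single-precision rounding error.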

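With a maximum gap of about 1.5e-7, both backends yield nearly identical trajectories. A discrepancy of this size is consistent with JAX computing in 32-bit floating-point precision by default, whereas NumPy defaults to 64-bit precision.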
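To see why a gap of about 1e-7 is plausible, the sketch below rounds a few double-precision values to 32-bit floats using the standard-library `struct` module and prints the rounding error. For normal numbers, rounding to the nearest float32 has relative error at most 2**-24 (about 6e-8), so state values of order one can drift apart by roughly 1e-7 once a few such roundings accumulate over a trajectory. This only illustrates single-precision rounding; it does not reproduce what the JAX backend computes internally, and the helper `to_float32` is hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float and widen it back."""
    return struct.unpack('f', struct.pack('f', x))[0]


UNIT_ROUNDOFF_F32 = 2.0 ** -24  # about 5.96e-8

for value in [0.1, 1.0 / 3.0, 2.718281828459045, 123.456]:
    rounded = to_float32(value)
    rel_err = abs(rounded - value) / abs(value)
    print(f'{value!r:>20} -> {rounded!r:<22} relative error {rel_err:.2e}')
    # round-to-nearest keeps the relative error within the float32 unit roundoff
    assert rel_err <= UNIT_ROUNDOFF_F32
```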

Does the JAX backend run faster?

Finally, let’s run 50 episodes on each of the 5 instances and time how long the simulation takes with the default pyRDDLGym backend vs the JAX backend:

times, base_times = [], []
for instance in range(1, 6):
    # JAX backend: wall-clock time for 50 episodes on this instance
    env = pyRDDLGym.make('RaceCar_ippc2023', str(instance), backend=JaxRDDLSimulator)
    start = time.time()
    agent.evaluate(env, episodes=50, render=False)
    times.append(time.time() - start)

    # default backend: same instance, same number of episodes
    base_env = pyRDDLGym.make('RaceCar_ippc2023', str(instance))
    base_start = time.time()
    agent.evaluate(base_env, episodes=50, render=False)
    base_times.append(time.time() - base_start)
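Each timing above is a single measurement that includes everything `agent.evaluate` does in that call. If you suspect one-off costs in the first call (for example JIT compilation, which JAX performs when a function is first traced), a more careful harness times a warm-up call separately from repeated runs. The sketch below is generic: `time_calls` is a hypothetical helper, and the toy `workload` stands in for something like `lambda: agent.evaluate(env, episodes=50, render=False)`.

```python
import statistics
import time
from typing import Callable


def time_calls(fn: Callable[[], object], repeats: int = 5) -> tuple:
    """Time one warm-up call, then `repeats` further calls.

    Returns (first_call_seconds, median_of_later_calls_seconds).
    """
    start = time.perf_counter()
    fn()  # warm-up: any one-off setup cost lands in this measurement
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        later.append(time.perf_counter() - start)
    return first, statistics.median(later)


# Toy workload standing in for a simulation run.
def workload():
    return sum(i * i for i in range(100_000))


first, steady = time_calls(workload, repeats=3)
print(f'first call: {first:.4f}s, median of later calls: {steady:.4f}s')
```

`time.perf_counter` is used here because it is a monotonic, high-resolution clock intended for measuring short intervals.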

Let’s plot this to see the trends more clearly:

%matplotlib inline
instances = np.arange(1, 6)
plt.bar(instances - 0.2, times, 0.4, label='JAX')
plt.bar(instances + 0.2, base_times, 0.4, label='numpy')
plt.xlabel('instance')
plt.ylabel('wall-clock time for 50 episodes (s)')
plt.legend()
plt.show()

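Therefore, on these RaceCar instances the JAX backend simulates faster than the default (NumPy-based) simulation backend, although the absolute timings depend on the hardware.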
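To summarize the bar chart as one number per instance, you can compute the speedup `base_time / jax_time` from the `times` and `base_times` lists collected above; values above 1 mean the JAX backend was faster. The helper `speedups` is hypothetical, and the numbers in the example call are made up purely to show the arithmetic, not measured results.

```python
def speedups(times: list, base_times: list) -> list:
    """Per-instance speedup of the JAX backend over the default backend."""
    if len(times) != len(base_times):
        raise ValueError('both lists must have one entry per instance')
    return [base / jax for jax, base in zip(times, base_times)]


# Made-up timings in seconds (NOT measurements), purely to illustrate the arithmetic.
example_jax = [2.0, 2.5, 4.0]
example_base = [6.0, 5.0, 4.0]
print(speedups(example_jax, example_base))  # [3.0, 2.0, 1.0]

# In the notebook you would call: speedups(times, base_times)
```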