Accelerating simulation with JAX
In this example, we show how to compile and simulate a pyRDDLGym environment using the JAX backend.
%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import time
import matplotlib.pyplot as plt
import numpy as np
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
Does the JAX backend produce the same simulation result as the pyRDDLGym simulator?
We now check that the JAX backend produces the same outputs as the default pyRDDLGym simulation backend, up to floating-point round-off. Let’s make two copies of the same deterministic environment, one for each backend:
env = pyRDDLGym.make('RaceCar_ippc2023', '3', backend=JaxRDDLSimulator)
base_env = pyRDDLGym.make('RaceCar_ippc2023', '3')
Let’s generate actions using the random agent:
agent = RandomAgent(action_space=env.action_space, num_actions=env.max_allowed_actions)
Let’s step both copies with the same sequence of random actions and compare the resulting trajectories:
# reset both environments to their initial states
state, _ = env.reset()
base_state, _ = base_env.reset()
max_diff = 0.0
for step in range(env.horizon):
    # track the largest absolute difference over all state fluents so far
    max_diff = max(max_diff, max(abs(state[key] - base_state[key]) for key in state))
    # sample one random action and apply it to both environments
    action = agent.sample_action(state)
    state, reward, *_ = env.step(action)
    base_state, base_reward, *_ = base_env.step(action)
print(f'the maximum difference in the state was {max_diff}')
the maximum difference in the state was 1.5356283866729825e-07
Therefore, both backends agree to within about 1.5e-07. The remaining discrepancy is most likely floating-point round-off: JAX computes in 32-bit precision by default, whereas NumPy defaults to 64-bit.
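A discrepancy of this size is consistent with single-precision round-off, since the machine epsilon of a 32-bit float is about 1.19e-07. The following is a minimal sketch that integrates a toy point-mass update in 32-bit and 64-bit NumPy arithmetic and compares the resulting states with a tolerance instead of exact equality. The update rule is hypothetical and is not the RaceCar dynamics. The helpers `rollout` and `states_close` and their default tolerances are also illustrative assumptions.

```python
import numpy as np


def rollout(dtype, steps=100):
    """Hypothetical toy dynamics: point mass with constant acceleration (Euler steps)."""
    dt, a = dtype(0.1), dtype(0.3)
    x, v = dtype(0.0), dtype(0.0)
    for _ in range(steps):
        # every operation stays in `dtype`, so round-off accumulates at that precision
        x = x + dt * v
        v = v + dt * a
    return {'x': x, 'v': v}


def states_close(a, b, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: True if every fluent agrees within the given tolerances."""
    return all(np.allclose(a[key], b[key], rtol=rtol, atol=atol) for key in a)


state32 = rollout(np.float32)
state64 = rollout(np.float64)
max_diff = max(abs(float(state32[k]) - float(state64[k])) for k in state64)

print(f'float32 machine epsilon: {np.finfo(np.float32).eps}')
print(f'max difference between precisions: {max_diff}')
print(f'exactly equal: {all(state32[k] == state64[k] for k in state64)}')
print(f'equal within tolerance: {states_close(state32, state64)}')
```

Comparing with `np.allclose`-style tolerances makes a backend-equivalence check robust to this kind of precision difference. An exact-equality check would fail even when both simulators implement the same model.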
Does the JAX backend run faster?
Finally, let’s time 50 episodes on each of the five RaceCar instances, once with the JAX backend and once with the default pyRDDLGym backend:
times, base_times = [], []
for instance in range(1, 6):
    # time 50 episodes with the JAX backend
    env = pyRDDLGym.make('RaceCar_ippc2023', str(instance), backend=JaxRDDLSimulator)
    start = time.perf_counter()
    agent.evaluate(env, episodes=50, render=False)
    times.append(time.perf_counter() - start)
    # time 50 episodes with the default pyRDDLGym backend
    base_env = pyRDDLGym.make('RaceCar_ippc2023', str(instance))
    base_start = time.perf_counter()
    agent.evaluate(base_env, episodes=50, render=False)
    base_times.append(time.perf_counter() - base_start)
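Wall-clock timings like these can be skewed by one-off costs. A function wrapped in `jax.jit` is traced and compiled the first time it is called, so if the JAX backend relies on it, the first episode may be slower than the rest. The following is a minimal sketch of a hypothetical `benchmark` helper that times a warm-up phase separately from repeated runs, using `time.perf_counter`, a monotonic high-resolution clock. It is generic Python and not part of pyRDDLGym.

```python
import statistics
import time


def benchmark(fn, warmup=1, repeats=5):
    """Hypothetical helper: time `fn()` with warm-up calls measured separately.

    Returns (warmup_seconds, median_seconds), so that one-off costs such as
    JIT compilation do not inflate the steady-state estimate.
    """
    start = time.perf_counter()
    for _ in range(warmup):
        fn()
    warmup_seconds = time.perf_counter() - start

    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return warmup_seconds, statistics.median(samples)


# stand-in workload so the sketch runs on its own
def workload():
    return sum(i * i for i in range(100_000))


warmup_s, median_s = benchmark(workload, warmup=1, repeats=5)
print(f'warm-up: {warmup_s:.4f}s, median of 5 runs: {median_s:.4f}s')
```

In the loop above, `fn` could be `lambda: agent.evaluate(env, episodes=50, render=False)`, called once per backend and instance. The median is less sensitive than the mean to occasional slow runs.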
Let’s plot this to see the trends more clearly:
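%matplotlib inline
instances = np.arange(1, 6)
# side-by-side bars: JAX on the left, default backend on the right
plt.bar(instances - 0.2, times, 0.4, label='JAX')
plt.bar(instances + 0.2, base_times, 0.4, label='numpy')
plt.xlabel('instance')
plt.ylabel('time for 50 episodes (s)')
plt.legend()
plt.show()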
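The bar chart shows absolute times. A per-instance speedup ratio, computed as default-backend time divided by JAX time, can make the comparison easier to read. The following is a minimal sketch in which `speedups` and `mean_speedup` are hypothetical helpers meant to be applied to the `times` and `base_times` lists collected above. Ratios are averaged with a geometric mean, which is the usual choice for averaging ratios.

```python
import statistics


def speedups(jax_times, numpy_times):
    """Hypothetical helper: numpy_time / jax_time for each instance.

    A value above 1 means the JAX backend was faster on that instance.
    """
    if len(jax_times) != len(numpy_times):
        raise ValueError('timing lists must have the same length')
    return [base / fast for fast, base in zip(jax_times, numpy_times)]


def mean_speedup(jax_times, numpy_times):
    """Geometric mean of the per-instance speedups."""
    return statistics.geometric_mean(speedups(jax_times, numpy_times))
```

After running the timing loop, `speedups(times, base_times)` gives one ratio per instance and `mean_speedup(times, base_times)` summarizes them in a single number.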
Therefore, on these RaceCar instances, the JAX backend can simulate faster than the default pyRDDLGym backend.