Calling External Functions in pyRDDLGym
This advanced notebook shows how to call external functions from RDDL domain description files. We use this feature to build a visual Pong environment for control from pixels.
Note: it is strongly recommended to run this in an environment with GPU support.
%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository
%pip install --quiet -U "jax[cuda13]" optax flax tqdm
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from rddlrepository.core.manager import RDDLRepoManager
Training to Generate Pong Images from State
We first train a visual decoder to reconstruct Pong screen images from numerical states. Let’s start by collecting some data of (state, image) pairs.
# convert to grayscale, resize to 64-by-64, scale to [0, 1], then invert intensities
def preprocess_frame(img, size=(64, 64)):
    img = img.convert("L").resize(size, Image.BILINEAR)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return 1 - arr
env = pyRDDLGym.make('Pong_arcade', '0')
env._visualizer._ball_radius *= 2
controller = RandomAgent(env.action_space, seed=42)
states, actions, obss = [], [], []
for _ in range(40):
    state, _ = env.reset()
    controller.reset()
    for _ in range(100):
        obs = preprocess_frame(env.render())
        action = controller.sample_action(state)
        next_state, *_ = env.step(action)
        states.append(np.concatenate([np.reshape(v, (-1,)) for k, v in state.items()
                                      if k in {'ball-x___b1', 'ball-y___b1', 'paddle-y'}]))
        actions.append(np.concatenate([np.reshape(v, (-1,)) for v in action.values()]))
        obss.append(obs)
        state = next_state
        if np.any(state['ball-x___b1'] > 1.5):
            break
env.close()
states, actions, obss = np.array(states), np.array(actions), np.array(obss)
Let’s inspect the shape of data:
print(f'state = {states.shape}')
print(f'action = {actions.shape}')
print(f'obs = {obss.shape}')
state = (2284, 3)
action = (2284, 1)
obs = (2284, 64, 64)
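The state vector is assembled by iterating over the state dictionary and keeping three fluents, so its column order follows the dictionary's iteration order rather than the order in which the keys are written in the filter set. A minimal sketch with a hand-written state dictionary (the values and the extra `vel-x___b1` entry are made up for illustration; real dictionaries come from `env.reset()` and `env.step()`) shows the resulting layout:

```python
import numpy as np

# hypothetical state dictionary, for illustration only
example_state = {
    'ball-x___b1': 0.5,
    'ball-y___b1': 0.25,
    'vel-x___b1': 0.03,   # present in the state but not selected below
    'paddle-y': 0.4,
}
keep = {'ball-x___b1', 'ball-y___b1', 'paddle-y'}

def flatten_state(state, keep):
    # walk the dictionary in its own order, flatten each kept value to 1-D, then join
    parts = [np.reshape(v, (-1,)) for k, v in state.items() if k in keep]
    return np.concatenate(parts)

vec = flatten_state(example_state, keep)
print(vec.shape)   # (3,)
```

Here the columns come out as ball x, ball y, paddle y, which is the order the external observation function later has to reproduce.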
Next, we will build a decoder to map a state to a 64-by-64 image:
class PongDecoder(nn.Module):
    @nn.compact
    def __call__(self, state):
        o = nn.Dense(256)(state)
        o = nn.relu(o)
        o = nn.Dense(16*16*32)(o)
        o = nn.relu(o)
        o = o.reshape((o.shape[0], 16, 16, 32))
        o = nn.ConvTranspose(16, kernel_size=(3, 3), strides=(2, 2))(o)
        o = nn.relu(o)
        o = nn.ConvTranspose(1, kernel_size=(3, 3), strides=(2, 2))(o)
        o = nn.sigmoid(o)
        return o
decoder = PongDecoder()
decoder_fn = jax.jit(decoder.apply)
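To see why the decoder ends at 64-by-64, it helps to trace the spatial size through the layers. The sketch below assumes the transposed convolutions use 'SAME' padding (to our knowledge the Flax default for nn.ConvTranspose), in which case each layer multiplies the spatial size by its stride. The helper name is hypothetical.

```python
def same_transpose_size(size, stride):
    # assumption: with 'SAME' padding a transposed convolution scales the size by the stride
    return size * stride

size, channels = 16, 32                       # after reshaping the Dense(16*16*32) output
trace = [(size, size, channels)]
for features, stride in [(16, 2), (1, 2)]:    # the two ConvTranspose layers
    size = same_transpose_size(size, stride)
    trace.append((size, size, features))
print(trace)   # [(16, 16, 32), (32, 32, 16), (64, 64, 1)]
```

The trailing channel of size 1 is why the training targets are given an extra axis before being compared with the decoder output.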
Next, we define the loss as the mean squared error (MSE) between the predicted and actual images, together with the optimizer update rule:
def loss_fn(params, states, obs):
    return jnp.mean(jnp.square(obs - decoder.apply(params, states)))
optimizer = optax.adam(0.005)
@jax.jit
def update_fn(params, opt_state, states, obs):
    loss, grads = jax.value_and_grad(loss_fn)(params, states, obs)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
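The loss is the mean squared error over every pixel in the batch, and jax.value_and_grad differentiates it with respect to the decoder parameters. As a small illustration of the quantity being minimized (not the notebook's training code), this NumPy sketch computes the same loss on made-up arrays and checks its gradient with respect to the predicted pixels against a finite difference:

```python
import numpy as np

def mse(pred, target):
    return np.mean(np.square(target - pred))

def mse_grad(pred, target):
    # derivative of the mean squared error with respect to each predicted pixel
    return 2.0 * (pred - target) / pred.size

rng = np.random.default_rng(0)
pred = rng.uniform(size=(2, 4, 4, 1))     # made-up "decoder output"
target = rng.uniform(size=(2, 4, 4, 1))   # made-up "observed image"

# forward finite difference on a single pixel
eps = 1e-6
bumped = pred.copy()
bumped[0, 1, 2, 0] += eps
numeric = (mse(bumped, target) - mse(pred, target)) / eps
analytic = mse_grad(pred, target)[0, 1, 2, 0]
print(abs(numeric - analytic) < 1e-5)
```

Automatic differentiation carries this per-pixel gradient back through the decoder layers to obtain the parameter gradients used by the optimizer.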
Next, we need a mini-batch sampler over the original data:
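def batch_sampler(batch_size=16):
    indices = np.arange(states.shape[0])
    np.random.default_rng().shuffle(indices)
    # the + 1 keeps the final full batch; a trailing partial batch is still dropped
    for i in range(0, len(indices) - batch_size + 1, batch_size):
        idx = indices[i:i+batch_size]
        # add a trailing channel axis to match the decoder output
        yield (states[idx], obss[idx][..., None])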
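The sampler reshuffles the data on every pass and yields only full-sized batches, so leftover examples are skipped for that epoch and may be picked up in the next one. A minimal standalone sketch of this pattern on index arrays (the helper name and the seed argument are hypothetical):

```python
import numpy as np

def iterate_minibatches(n_items, batch_size, seed=0):
    # hypothetical helper: shuffled, non-overlapping index batches, full batches only
    indices = np.arange(n_items)
    np.random.default_rng(seed).shuffle(indices)
    for i in range(0, n_items - batch_size + 1, batch_size):
        yield indices[i:i + batch_size]

batches = list(iterate_minibatches(n_items=35, batch_size=16))
seen = np.concatenate(batches)
print(len(batches), len(np.unique(seen)))   # 2 32
```

With 35 items and a batch size of 16, two batches are produced and three items are left out of this pass.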
Finally, let’s train the model:
params = decoder.init(jax.random.PRNGKey(0), next(batch_sampler())[0])
opt_state = optimizer.init(params)
for ep in (pbar := tqdm(range(500))):
    ep_loss = 0.
    for (state, obs) in batch_sampler():
        params, opt_state, loss = update_fn(params, opt_state, state, obs)
        ep_loss += float(loss)
    pbar.set_postfix({"loss": ep_loss})
100%|██████████| 500/500 [05:48<00:00, 1.43it/s, loss=0.00668]
Building an RDDL Description File with External Function Calls
Now that we have the trained parameters, we need to write a function whose signature matches the external call in the RDDL description. A JIT-compiled function can also be used.
@jax.jit
def external_obs_function(ball_x, ball_y, paddle_y):
    # stack the inputs into a single (1, 3) batch in the order used for training
    state = jnp.concatenate([jnp.reshape(ball_x, (-1,)),
                             jnp.reshape(ball_y, (-1,)),
                             jnp.reshape(paddle_y, (-1,))])[None]
    return decoder_fn(params, state).reshape((64, 64))
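Before wiring in the trained decoder, it can be handy to test the plumbing with a function that has the same three-argument signature and returns an array of the same 64-by-64 shape. The following NumPy stand-in is purely hypothetical: the drawing rule, the sizes, and the name are our own assumptions, and we have not verified how pyRDDLGym treats a NumPy function in place of a JAX one. It paints a small square for the ball and a bar for the paddle:

```python
import numpy as np

def standin_obs_function(ball_x, ball_y, paddle_y, size=64, paddle_h=0.2):
    # hypothetical renderer: coordinates in [0, 1] are mapped to pixel indices
    bx = float(np.reshape(ball_x, (-1,))[0])
    by = float(np.reshape(ball_y, (-1,))[0])
    py = float(np.reshape(paddle_y, (-1,))[0])
    img = np.zeros((size, size), dtype=np.float32)
    # ball: a 3-by-3 square centred on the (clipped) ball position
    col = int(np.clip(bx, 0.0, 1.0) * (size - 1))
    row = int(np.clip(by, 0.0, 1.0) * (size - 1))
    img[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2] = 1.0
    # paddle: a vertical bar in the last column
    top = int(np.clip(py, 0.0, 1.0) * (size - 1))
    bottom = int(np.clip(py + paddle_h, 0.0, 1.0) * (size - 1))
    img[top:bottom + 1, -1] = 1.0
    return img

frame = standin_obs_function(np.array([0.5]), np.array([0.5]), 0.4)
print(frame.shape)   # (64, 64)
```

Whatever function is used, the essential contract is the one shown in the notebook: three state inputs in, one array out whose shape matches the row and column parameters of the image fluent.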
Next, we create a modified Pong RDDL description file in which an observ-fluent takes its values from the external function call $MyObsFunction[?r, ?c](ball-x'(_), ball-y'(_), paddle-y'). The output of this function must be assigned to an “image” fluent parameterized by rows and columns ?r, ?c, so we explicitly enumerate the 64 row and column indices as objects in RDDL:
domain_text = """
domain pong_pomdp {
types {
ball : object;
rc : {@i1, @i2, @i3, @i4, @i5, @i6, @i7, @i8, @i9, @i10, @i11, @i12, @i13, @i14,
@i15, @i16, @i17, @i18, @i19, @i20, @i21, @i22, @i23, @i24, @i25, @i26,
@i27, @i28, @i29, @i30, @i31, @i32, @i33, @i34, @i35, @i36, @i37, @i38,
@i39, @i40, @i41, @i42, @i43, @i44, @i45, @i46, @i47, @i48, @i49, @i50,
@i51, @i52, @i53, @i54, @i55, @i56, @i57, @i58, @i59, @i60, @i61, @i62,
@i63, @i64};
};
pvariables {
NOISE-X(ball) : { non-fluent, real, default = 0.01 };
NOISE-Y(ball) : { non-fluent, real, default = 0.03 };
PADDLE-H : { non-fluent, real, default = 0.2 };
PADDLE-MAX-STEP : { non-fluent, real, default = 0.04 };
ball-x(ball) : { state-fluent, real, default = 0.5 };
ball-y(ball) : { state-fluent, real, default = 0.5 };
vel-x(ball) : { state-fluent, real, default = 0.03 };
vel-y(ball) : { state-fluent, real, default = 0.01 };
paddle-y : { state-fluent, real, default = 0.4 };
image(rc, rc) : { observ-fluent, real };
new-x(ball) : { interm-fluent, real };
new-y(ball) : { interm-fluent, real };
ball-crossing-y(ball) : { interm-fluent, real };
contact(ball) : { interm-fluent, bool };
move : { action-fluent, int, default = 0 };
};
cpfs {
// update position before contact and bounce
new-x(?b) = ball-x(?b) + vel-x(?b);
new-y(?b) = ball-y(?b) + vel-y(?b);
// check if the ball contacts the paddle
ball-crossing-y(?b) = ball-y(?b) + (vel-y(?b) * sgn[vel-x(?b)] / max[abs[vel-x(?b)], 0.03]) * (1.0 - ball-x(?b)) - paddle-y;
contact(?b) = (ball-x(?b) < 1.0) ^ (new-x(?b) >= 1.0)
^ (ball-crossing-y(?b) >= 0.0)
^ (ball-crossing-y(?b) <= PADDLE-H);
// update position
ball-x'(?b) = if (contact(?b)) then 2.0 - new-x(?b)
else if (new-x(?b) < 0.0) then -new-x(?b)
else new-x(?b);
ball-y'(?b) = if (new-y(?b) < 0.0) then -new-y(?b)
else if (new-y(?b) > 1.0) then 2.0 - new-y(?b)
else new-y(?b);
// update velocity
vel-x'(?b) = if (contact(?b) | new-x(?b) < 0.0) then -vel-x(?b) else vel-x(?b);
vel-y'(?b) = if (contact(?b)) then max[min[vel-y(?b) + Uniform(-NOISE-Y(?b), NOISE-Y(?b)), 1.0], -1.0]
else if (new-y(?b) < 0.0 | new-y(?b) > 1.0) then -vel-y(?b)
else vel-y(?b);
// update paddle position
paddle-y' = max[min[paddle-y + move * PADDLE-MAX-STEP, 1.0 - PADDLE-H], 0.0];
// draw the image from an external function call
image(?r, ?c) = $MyObsFunction[?r, ?c](ball-x'(_), ball-y'(_), paddle-y');
};
reward = -(sum_{?b : ball} ball-x(?b));
state-invariants {
forall_{?b : ball} [ball-x(?b) >= 0.0];
forall_{?b : ball} [ball-y(?b) >= 0.0];
forall_{?b : ball} [vel-x(?b) >= -1.0 ^ vel-x(?b) <= 1.0];
forall_{?b : ball} [vel-y(?b) >= -1.0 ^ vel-y(?b) <= 1.0];
paddle-y >= 0.0 ^ paddle-y <= 1.0 - PADDLE-H;
};
action-preconditions {
move >= -1 ^ move <= 1;
};
}
"""
instance_text = """
non-fluents pong {
domain = pong_pomdp;
objects {
ball : {b1};
};
}
instance pong_pomdp_0 {
domain = pong_pomdp;
non-fluents = pong;
max-nondef-actions = pos-inf;
horizon = 200;
discount = 1.0;
}
"""
Now we register this domain with rddlrepository as usual:
manager = RDDLRepoManager(rebuild=True)
manager.register_context("pong")
manager.register_domain("PongPOMDP", "pong", domain_text, desc="a visual pong domain", viz=None)
manager.get_problem("PongPOMDP_pong").register_instance("0", instance_text)
_ = RDDLRepoManager(rebuild=True)
Context <pong> was successfully registered in rddlrepository.
Domain <PongPOMDP> was successfully registered in rddlrepository with context <pong>.
Instance <0> was successfully registered in rddlrepository for domain <PongPOMDP_pong>.
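The rc type in the domain text lists its 64 enumerated values by hand, which is easy to get wrong when the image size changes. A short sketch of generating that declaration as a string, which could then be spliced into the domain text with ordinary string formatting (this is our own convenience, not something pyRDDLGym requires):

```python
# build "rc : {@i1, @i2, ..., @i64};" for the enumerated row/column type
n = 64
rc_values = ', '.join(f'@i{k}' for k in range(1, n + 1))
rc_decl = f'rc : {{{rc_values}}};'
print(rc_decl[:24])   # rc : {@i1, @i2, @i3, @i4
```

If the image size is changed this way, the reshape at the end of the external function has to change with it.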
Finally, to instantiate the domain, we must explicitly provide the external_obs_function defined earlier and map this to $MyObsFunction in the RDDL:
env = pyRDDLGym.make('PongPOMDP_pong', '0', vectorized=True, backend_kwargs={
    'python_functions': {'MyObsFunction': external_obs_function}})
Let’s visualize this environment:
%matplotlib inline
fig, axes = plt.subplots(ncols=6, nrows=4, figsize=(12, 8))
axes = axes.flatten()
state, _ = env.reset()
for i in range(len(axes)):
    action = {'move': int(np.round(np.random.uniform(-1., 1.)))}
    next_state, *_ = env.step(action)
    state = next_state
    axes[i].imshow(state['image'], cmap='Greys')
    axes[i].axis('off')
env.close()
plt.tight_layout()
plt.show()
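Note that the random action in the visualization loop rounds a uniform draw, which yields move = 0 about half the time and -1 or 1 about a quarter of the time each. That is fine for a quick visual check; if equal probabilities are wanted, integers can be drawn directly. A small NumPy comparison:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# rounding a uniform draw on [-1, 1]: 0 covers (-0.5, 0.5), twice the width of -1 or 1
rounded = np.round(rng.uniform(-1.0, 1.0, size=n)).astype(int)
# drawing integers directly: -1, 0 and 1 are equally likely (upper bound is exclusive)
direct = rng.integers(-1, 2, size=n)

frac_zero_rounded = np.mean(rounded == 0)
frac_zero_direct = np.mean(direct == 0)
print(frac_zero_rounded, frac_zero_direct)
```

The first fraction should come out near one half and the second near one third.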
Congratulations, you have now built a (differentiable) simulator for learning Pong from pixels!