Partially-Observed Planning from Pixels with JaxPlan

Partially-Observed Planning from Pixels with JaxPlan

In this notebook, we use image segmentation to recover the environment state from rendered frames, so that JaxPlan can plan directly from pixels. The Pong domain serves as an illustrative example.

State Inference using Image Segmentation from Pong Pixels

Install the required packages:

%pip install --quiet --upgrade pip
%pip install --quiet opencv-python pyRDDLGym rddlrepository pyRDDLgym-jax
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import matplotlib.pyplot as plt
from IPython.display import Image
import cv2
import numpy as np

import pyRDDLGym
from pyRDDLGym.core.policy import RandomAgent
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxOnlineController, load_config_from_string

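We use image segmentation to identify the ball's (x, y) coordinates and the paddle's y-coordinate in each frame. The ball is detected as a red blob and the paddle as the rows containing only a few dark pixels; pixel positions are then normalized by the board height.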

def segment_frame(img):
    """Estimate the Pong state from a single rendered frame by segmentation.

    The ball is the first red contour whose area exceeds 5 px. The paddle is
    the set of rows containing between 5 and 9 dark (gray level <= 50)
    pixels. Pixel positions are converted to board units by taking the index
    of the row with the most dark pixels as the border offset, and assuming
    the border has the same size on both sides of the board.

    Args:
        img: the rendered frame; anything ``np.array`` turns into an
            ``(H, W, 3)`` RGB image (three channels are required by the
            channel indexing and ``cv2.COLOR_RGB2GRAY`` below).

    Returns:
        A tuple ``(ball_x, ball_y, paddle_y, img_rgb)``:

        * ``ball_x``, ``ball_y``: top-left corner of the ball's bounding box
          in board units, with y measured upward (``1 - row fraction``);
          NaN when no sufficiently large red contour is found.
        * ``paddle_y``: bottom edge of the paddle in the same units; NaN
          when no paddle rows are found (previously this raised).
        * ``img_rgb``: a copy of the frame annotated with the detected ball
          (blue box) and paddle (green box).
    """
    # np.array copies, so the annotations below never modify the caller's frame.
    img_rgb = np.array(img)

    # --- ball: pixels that are bright red and clearly redder than green ---
    # Cast to int so that r - g cannot wrap around as uint8 would.
    r, g = img_rgb[:, :, 0].astype(int), img_rgb[:, :, 1].astype(int)
    red_mask = (r - g > 80) & (r > 150)
    red_u8 = red_mask.astype(np.uint8) * 255
    bx = by = bw = bh = np.nan
    # findContours returns (contours, hierarchy) in OpenCV 4.x but
    # (image, contours, hierarchy) in 3.x; index -2 is the contours in both.
    contours = cv2.findContours(red_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    for cnt in contours:
        if cv2.contourArea(cnt) > 5:  # skip tiny red specks
            bx, by, bw, bh = cv2.boundingRect(cnt)
            break

    # --- paddle: rows that contain only a handful of dark pixels ---
    gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    gray[red_mask] = 255  # paint the ball white so it is not counted as dark
    dark = cv2.threshold(gray, 50, 1, cv2.THRESH_BINARY_INV)[1]  # 1 where gray <= 50
    row_count = dark.sum(axis=1)
    paddle_rows = np.where((row_count > 4) & (row_count < 10))[0]
    paddle_found = paddle_rows.size > 0
    if paddle_found:
        # Plain ints keep cv2.rectangle's point parsing happy.
        py = int(np.min(paddle_rows))
        ph = int(np.max(paddle_rows)) - py

    # --- convert to board units, assuming symmetric borders ---
    border_start = np.argmax(row_count)
    board_height = img_rgb.shape[0] - 2 * border_start
    ball_x = float((bx - border_start) / board_height)
    ball_y = 1 - float((by - border_start) / board_height)
    if paddle_found:
        # Equivalent to 1 - (py + ph - border_start) / board_height, i.e. the
        # paddle's bottom row measured upward.
        paddle_y = 1 - float((py - border_start) / board_height) - float(ph / board_height)
    else:
        # Mirror the ball path: report NaN instead of crashing on an empty array.
        paddle_y = float('nan')

    # --- visualize segmentation results ---
    px = int(img_rgb.shape[1] - border_start)  # x at which the paddle box is drawn
    if not np.isnan(bx):
        cv2.rectangle(img_rgb, (bx, by), (bx + bw, by + bh), (0, 0, 255), 2)
    if paddle_found:
        cv2.rectangle(img_rgb, (px, py), (px + 6, py + ph), (0, 255, 0), 2)

    return ball_x, ball_y, paddle_y, img_rgb

Let’s test this on some frames from the Pong environment:

# Roll out a random policy for a few steps and, at each step, compare the
# state recovered from the rendered frame with the simulator's true state.
env = pyRDDLGym.make('Pong_arcade', '0')
agent = RandomAgent(env.action_space)
state, _ = env.reset()
frames = []
for step in range(18):
    # render(False) is used to obtain the frame as an image (presumably
    # without opening a display window — confirm against pyRDDLGym docs).
    frame = env.render(False)
    ball_x, ball_y, paddle_y, img_rgb = segment_frame(frame)
    frames.append(img_rgb)
    print(f"pred ball_x: {ball_x:.2f} vs test {state['ball-x___b1']:.2f}, "
          f"pred ball_y: {ball_y:.2f} vs test {state['ball-y___b1']:.2f}, "
          f"pred paddle_y: {paddle_y:.2f} vs test {state['paddle-y']:.2f}")
    # The first element of the step result is the next state.
    state = env.step(agent.sample_action(state))[0]
env.close()
plt.close('all')

# plot segmented frames
%matplotlib inline
fig, axs = plt.subplots(ncols=6, nrows=3, figsize=(12, 6))
for fr, ax in zip(frames, axs.flatten()):
    ax.imshow(fr); ax.axis('off')
plt.tight_layout(); plt.show()
pred ball_x: 0.49 vs test 0.50, pred ball_y: 0.52 vs test 0.50, pred paddle_y: 0.39 vs test 0.40
pred ball_x: 0.52 vs test 0.53, pred ball_y: 0.53 vs test 0.51, pred paddle_y: 0.44 vs test 0.44
pred ball_x: 0.56 vs test 0.56, pred ball_y: 0.54 vs test 0.52, pred paddle_y: 0.48 vs test 0.48
pred ball_x: 0.59 vs test 0.59, pred ball_y: 0.55 vs test 0.53, pred paddle_y: 0.44 vs test 0.44
pred ball_x: 0.62 vs test 0.62, pred ball_y: 0.56 vs test 0.54, pred paddle_y: 0.44 vs test 0.44
pred ball_x: 0.64 vs test 0.65, pred ball_y: 0.57 vs test 0.55, pred paddle_y: 0.48 vs test 0.48
pred ball_x: 0.68 vs test 0.68, pred ball_y: 0.58 vs test 0.56, pred paddle_y: 0.48 vs test 0.48
pred ball_x: 0.71 vs test 0.71, pred ball_y: 0.59 vs test 0.57, pred paddle_y: 0.52 vs test 0.52
pred ball_x: 0.74 vs test 0.74, pred ball_y: 0.60 vs test 0.58, pred paddle_y: 0.48 vs test 0.48
pred ball_x: 0.77 vs test 0.77, pred ball_y: 0.61 vs test 0.59, pred paddle_y: 0.52 vs test 0.52
pred ball_x: 0.80 vs test 0.80, pred ball_y: 0.62 vs test 0.60, pred paddle_y: 0.48 vs test 0.48
pred ball_x: 0.83 vs test 0.83, pred ball_y: 0.63 vs test 0.61, pred paddle_y: 0.44 vs test 0.44
pred ball_x: 0.86 vs test 0.86, pred ball_y: 0.64 vs test 0.62, pred paddle_y: 0.48 vs test 0.48
pred ball_x: 0.89 vs test 0.89, pred ball_y: 0.65 vs test 0.63, pred paddle_y: 0.44 vs test 0.44
pred ball_x: 0.92 vs test 0.92, pred ball_y: 0.66 vs test 0.64, pred paddle_y: 0.39 vs test 0.40
pred ball_x: 0.95 vs test 0.95, pred ball_y: 0.67 vs test 0.65, pred paddle_y: 0.39 vs test 0.40
pred ball_x: 0.98 vs test 0.98, pred ball_y: 0.68 vs test 0.66, pred paddle_y: 0.44 vs test 0.44
pred ball_x: nan vs test 1.01, pred ball_y: nan vs test 0.67, pred paddle_y: 0.48 vs test 0.48
../_images/6b9163a3ac072da8e865d535d1e8fe6188be52d1d4dffc9efecf3f0c83191b5f.png

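We also need to infer the ball's velocity, which a single frame cannot reveal. We therefore segment two consecutive frames and take the difference of the ball positions as the per-step velocity (we also vectorize the state for the planner):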

def extract_state(prev_obs, next_obs):
    """Build a vectorized planner state from two consecutive frames.

    Positions and the paddle height are taken from ``next_obs``. The ball
    velocities are the finite difference of the ball position between
    ``prev_obs`` and ``next_obs``, so they are NaN whenever the ball is not
    detected in either frame. Ball fluents are wrapped in length-1 arrays
    (presumably one entry per ball object; the frames show a single ball),
    while ``paddle-y`` is left as a plain scalar.
    """
    x_prev, y_prev = segment_frame(prev_obs)[:2]
    x_next, y_next, paddle, _annotated = segment_frame(next_obs)
    ball_fluents = {
        'ball-x': x_next,
        'ball-y': y_next,
        'vel-x': x_next - x_prev,
        'vel-y': y_next - y_prev,
    }
    state = {name: np.array([value]) for name, value in ball_fluents.items()}
    state['paddle-y'] = paddle
    return state

Control from Pixels with JaxPlan

Let’s initialize the model-predictive (MPC) controller in JaxPlan. When it runs, each pair of consecutive frames is mapped to an inferred state, and JaxPlan plans from that inferred state:

# JaxPlan configuration in its INI-style config format: compiler warnings are
# silenced, the planner optimizes a 'JaxStraightLinePlan' with learning rate
# 0.01, train/test batch sizes of 16 and a 40-step rollout horizon, and the
# optimizer runs for 200 epochs with a fixed PRNG key (42).
config_str = """
[Compiler]
print_warnings=False
[Planner]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': 0.01}
batch_size_train=16
batch_size_test=16
rollout_horizon=40
[Optimize]
key=42
epochs=200
"""
# Three argument groups are returned; the middle one is not needed here.
# planner_args configure JaxBackpropPlanner, train_args the controller/optimizer.
planner_args, _, train_args = load_config_from_string(config_str)
# NOTE(review): `env` here is the non-vectorized environment closed at the end
# of the random-policy cell; only its RDDL model is read, which appears to
# remain usable after close() — confirm.
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
# Online controller wrapping the planner; presumably re-plans from the state
# passed to each sample_action call (MPC) — TODO confirm against JaxPlan docs.
agent = JaxOnlineController(planner, print_summary=False, print_progress=False, **train_args)

Let’s run the controller:

# Record the rollout as frames/pong_pomdp_*.gif (at most 150 frames).
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs('frames', exist_ok=True)
env = pyRDDLGym.make('Pong_arcade', '0', vectorized=True)
recorder = MovieGenerator("frames", "pong_pomdp", max_frames=150)
env.set_visualizer(viz=None, movie_gen=recorder)

# Closed-loop control from pixels: infer the state from the last two frames,
# query the MPC controller for an action, and apply it to the environment.
env.reset(seed=0)
prev_obs = env.render()
for step in range(150):
    obs = env.render()
    # On the first iteration prev_obs and obs are the same frame, so the
    # inferred ball velocity is zero for that step only.
    state_pred = extract_state(prev_obs, obs)
    action = agent.sample_action(state_pred)
    # Gymnasium-style step: (obs, reward, terminated, truncated, info).
    _, _, terminated, truncated, _ = env.step(action)
    prev_obs = obs
    # Stepping past the end of an episode is undefined under the Gym API.
    if terminated or truncated:
        break
env.close()
plt.close('all')
Image(filename='frames/pong_pomdp_0.gif')
../_images/8779b3194a671ae40b1f852428e6d6bb81dd545b00cb9e0a82fe325d043a15b6.gif