Training a PPO policy with RLlib.

We demonstrate how to train a PPO policy on a pyRDDLGym environment using RLlib, the reinforcement learning library that ships with Ray.

First install the required packages:

%pip install --quiet --upgrade pip
%pip install --quiet -U "ray[rllib]==2.37.0"
%pip install --quiet git+
%pip install --quiet git+
Note: you may need to restart the kernel to use updated packages.
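Because the install cell pins a specific Ray version, it can help to confirm which version the running kernel actually sees after a restart. The following is a minimal sketch using only the standard library; the helper name installed_version is made up for this illustration.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


expected = "2.37.0"  # the version pinned in the install cell above
found = installed_version("ray")

if found is None:
    print("ray is not installed in this kernel")
elif found != expected:
    print(f"ray {found} is installed, but this notebook pins {expected}")
else:
    print(f"ray {found} matches the pinned version")
```

If the reported version differs from the pinned one, restarting the kernel (as the pip note suggests) is the first thing to try.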

Import the required packages:

import warnings
import os
from IPython.display import Image

from ray.tune.registry import register_env
from ray.rllib.algorithms.ppo import PPOConfig

import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator

from pyRDDLGym_rl.core.agent import RLLibAgent
from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv

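We will optimize instance 1 of the Reservoir_ippc2023 domain, a reservoir control problem. RLlib builds its environments from a registered creator function, so the environment creation has to be wrapped in an outside function as follows, and the observation space needs to be flattened; here the environment is built with SimplifiedActionRDDLEnv as its base class: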

def env_creator(env_config):
    return pyRDDLGym.make(env_config['domain'], env_config['instance'], base_class=SimplifiedActionRDDLEnv) 

register_env('RLLibEnv', env_creator)

Let’s set up and train a PPO agent:

# set up the agent
config = PPOConfig()
config = config.env_runners(num_env_runners=1, num_envs_per_env_runner=8)
config = config.environment('RLLibEnv', env_config={'domain': 'Reservoir_ippc2023', 'instance': '1'})
config = config.training(lr=0.0003, gamma=0.98, lambda_=0.5)
algo = config.build()

# train the agent
for n in range(100):
    result = algo.train()
    if n % 10 == 0:
        print(f'iteration {n}, mean return {result["env_runners"]["episode_reward_mean"]}')
iteration 0, mean return -42201.1226705054
iteration 10, mean return -34347.93848920746
iteration 20, mean return -27079.48502499436
iteration 30, mean return -11586.812786216222
iteration 40, mean return -5314.710062150564
iteration 50, mean return -1197.3102556924396
iteration 60, mean return -390.2064510148477
iteration 70, mean return -425.92802183586775
iteration 80, mean return -43.05392208331204
iteration 90, mean return -232.74436257098216
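The gamma and lambda_ settings passed to the training configuration are the discount factor and the parameter of generalized advantage estimation (GAE), which PPO implementations commonly use to estimate advantages. The sketch below shows the textbook GAE recursion in plain Python to make the role of the two numbers concrete. It is not RLlib's code, and the function name and toy inputs are made up for illustration.

```python
def gae_advantages(rewards, values, gamma=0.98, lam=0.5):
    """Textbook GAE recursion over one trajectory.

    `values` must hold len(rewards) + 1 entries: one value estimate per
    visited state plus a bootstrap value for the state after the last step.
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values must have one more entry than rewards")
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step temporal-difference error at time t.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # Exponentially weighted sum of future TD errors.
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


# Toy trajectory with made-up rewards and value estimates.
toy_rewards = [-3.0, -2.0, -1.0]
toy_values = [-5.0, -3.0, -1.0, 0.0]
print(gae_advantages(toy_rewards, toy_values))
```

With lam=0 each advantage reduces to the one-step TD error, while lam=1 sums all discounted future TD errors; a value of 0.5 sits between the two, trading some bias for lower variance.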

To evaluate the trained agent, we wrap it in an RLLibAgent wrapper, which is an instance of pyRDDLGym’s BaseAgent:

agent = RLLibAgent(algo)

Lastly, we evaluate the agent as always:

# for recording movies
if not os.path.exists('frames'):
    os.makedirs('frames')
env = env_creator({'domain': 'Reservoir_ippc2023', 'instance': '1'})
recorder = MovieGenerator("frames", "reservoir_rllib", max_frames=env.horizon)
env.set_visualizer(viz=None, movie_gen=recorder)

print(agent.evaluate(env, episodes=1, verbose=False, render=True))
{'mean': -275.78113195599644, 'median': -275.78113195599644, 'min': -275.78113195599644, 'max': -275.78113195599644, 'std': 0.0}
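All five statistics coincide and the standard deviation is 0.0 because only one episode was evaluated. The sketch below reproduces this kind of summary with the standard library, assuming a population standard deviation. It is not pyRDDLGym's implementation, the function name summarize_returns is hypothetical, and the three-episode returns are made-up numbers.

```python
import statistics


def summarize_returns(returns):
    """Summarize a non-empty list of episode returns."""
    return {
        'mean': statistics.fmean(returns),
        'median': statistics.median(returns),
        'min': min(returns),
        'max': max(returns),
        # Population standard deviation: zero for a single episode.
        'std': statistics.pstdev(returns),
    }


# One episode, using the return reported above.
single = summarize_returns([-275.78113195599644])
print(single)

# Several episodes (made-up returns) give a non-zero spread.
several = summarize_returns([-300.0, -250.0, -200.0])
print(several)
```

Evaluating over more episodes gives a more reliable estimate of the policy's return, since a single rollout of a stochastic domain can land far from the average.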