Closed-loop replanning with JaxPlan.
This follow-up example shows another way to perform closed-loop control in JaxPlan. Starting from the initial state of the system, we optimize the action-fluents over a short lookahead horizon (e.g. 5 decision steps), take the best immediate action from the resulting plan, and let the system evolve one step. We then repeat the process from the new state, taking the best action of a freshly optimized plan, and so on. This technique is called replanning in the planning literature, and it closely mirrors model-predictive control (MPC), except that the dynamics model is obtained directly from the RDDL description.
Start by installing the required packages:
%pip install --quiet --upgrade pip
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym.git
%pip install --quiet git+https://github.com/pyrddlgym-project/rddlrepository.git
%pip install --quiet git+https://github.com/pyrddlgym-project/pyRDDLGym-jax.git
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from IPython.display import Image
import os
import pyRDDLGym
from pyRDDLGym.core.visualizer.movie import MovieGenerator
from pyRDDLGym_jax.core.planner import (
    JaxStraightLinePlan, JaxBackpropPlanner, JaxOnlineController,
    load_config_from_string)
We will again optimize the stochastic Wildfire problem from IPPC 2014, noting again the use of the vectorized option:
env = pyRDDLGym.make('Wildfire_MDP_ippc2014', '1', vectorized=True)
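As a quick sanity check, vectorized=True means the environment exchanges states and actions as arrays keyed by lifted fluent name (e.g. one boolean array for burning covering all grid cells), which is the representation JaxPlan works with. A minimal sketch for inspecting this, assuming the standard gymnasium accessors:
state, _ = env.reset()
print(env.observation_space)   # Dict space keyed by lifted state-fluent name
print(env.action_space)        # Dict space keyed by lifted action-fluent name
print(state['burning'].shape)  # one array over all cells, not one key per cell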
The config file is similar to the open-loop examples, except we also specify the rollout_horizon parameter to indicate how far ahead we search during optimization:
config = """
[Model]
comparison_kwargs={'weight': 100}
rounding_kwargs={'weight': 100}
control_kwargs={'weight': 100}
[Optimizer]
method='JaxStraightLinePlan'
optimizer_kwargs={'learning_rate': 0.1}
rollout_horizon=5
[Training]
policy_hyperparams={'cut-out': 5.0, 'put-out': 5.0}
"""
planner_args, _, train_args = load_config_from_string(config)
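As an aside, if you prefer to keep the configuration in a separate file rather than an inline string, the same module also provides load_config, which parses the identical format from disk (the file name below is just a hypothetical example):
from pyRDDLGym_jax.core.planner import load_config

# same return signature as load_config_from_string
planner_args, _, train_args = load_config('wildfire_replan.cfg')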
We now initialize and run our controller. We will set train_seconds to 1 to indicate that we want to optimize for 1 second per decision time step:
# construct the planner from the RDDL model and the config arguments
planner = JaxBackpropPlanner(rddl=env.model, **planner_args)
# the online controller re-plans from the current state at every decision step
agent = JaxOnlineController(planner, print_summary=False, train_seconds=1, **train_args)
Notice that no optimization is done before calling the evaluate function: the replanning method only optimizes once it begins interacting with the environment, i.e. it observes the current state, finds and executes the best immediate action, waits for the state to transition, and then repeats.
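To make the online behavior concrete, the evaluate() call below is roughly equivalent to the following manual interaction loop (a minimal sketch using the standard pyRDDLGym agent interface; the real evaluate() also handles rendering and statistics):
state, _ = env.reset()
total_reward = 0.0
for _ in range(env.horizon):
    # each call re-optimizes a fresh 5-step plan from the current state
    # for roughly train_seconds=1 and returns its first action
    action = agent.sample_action(state)
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break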
Let’s assign a visualizer so we can keep track of the behavior of the planner in real time. Then we just call evaluate() to actually do the planning:
# record frames from the environment's visualizer into an animated GIF
if not os.path.exists('frames'):
    os.makedirs('frames')
recorder = MovieGenerator("frames", "wildfire", max_frames=env.horizon)
env.set_visualizer(viz=None, movie_gen=recorder)  # viz=None keeps the default visualizer
agent.evaluate(env, episodes=1, render=True)
env.close()
Image(filename='frames/wildfire_0.gif')
0 it / -356.603119 train / -115.781250 test / -115.781250 best / 5 status: 0%| | 0/100 [00:02<?, ?it/s]
1 it / -84.883430 train / -75.000000 test / -75.000000 best / 5 status: 1%|▏ | 1/100 [00:01<02:15, 1.37s/it]
400 it / -26.407290 train / -55.359375 test / -33.406250 best / 5 status: : 400it [00:01, 380.95it/s]
408 it / -16.250628 train / -4.750000 test / -1.546875 best / 5 status: : 408it [00:01, 384.90it/s]
407 it / -16.250628 train / -4.750000 test / -1.546875 best / 5 status: : 407it [00:01, 389.44it/s]
412 it / -8.496055 train / -3.703125 test / -1.546875 best / 5 status: : 412it [00:01, 392.78it/s]
400 it / -2.656929 train / -3.875000 test / -1.546875 best / 5 status: : 400it [00:01, 376.49it/s]
417 it / -2.344416 train / -3.640625 test / -0.625000 best / 5 status: : 417it [00:01, 394.45it/s]
417 it / -0.469423 train / -4.234375 test / -0.312500 best / 5 status: : 417it [00:01, 394.43it/s]
412 it / -5.482677 train / -2.828125 test / -1.546875 best / 5 status: : 412it [00:01, 386.64it/s]
437 it / -0.955009 train / -5.593750 test / -1.546875 best / 5 status: : 437it [00:01, 410.00it/s]
423 it / -6.719067 train / -4.671875 test / -0.781250 best / 5 status: : 423it [00:01, 398.14it/s]
439 it / -3.438159 train / -8.953125 test / -1.546875 best / 5 status: : 439it [00:01, 410.98it/s]
418 it / -2.182323 train / -6.296875 test / -1.546875 best / 5 status: : 418it [00:01, 398.11it/s]
437 it / -3.438159 train / -8.953125 test / -1.546875 best / 5 status: : 437it [00:01, 409.82it/s]
422 it / -16.717945 train / -3.703125 test / -1.546875 best / 5 status: : 422it [00:01, 396.19it/s]
437 it / -13.177738 train / -8.250000 test / -1.546875 best / 5 status: : 437it [00:01, 412.76it/s]
421 it / -8.124269 train / -3.031250 test / -1.546875 best / 5 status: : 421it [00:01, 395.86it/s]
418 it / -1.875669 train / -4.265625 test / -0.625000 best / 5 status: : 418it [00:01, 396.90it/s]
417 it / -1.875669 train / -4.265625 test / -0.703125 best / 5 status: : 417it [00:01, 398.09it/s]
426 it / -1.406957 train / -5.609375 test / -0.625000 best / 5 status: : 426it [00:01, 397.19it/s]
442 it / -0.500641 train / -3.781250 test / -0.781250 best / 5 status: : 442it [00:01, 410.73it/s]
407 it / -2.500665 train / -6.078125 test / -1.546875 best / 5 status: : 407it [00:01, 386.28it/s]
406 it / -2.500665 train / -6.078125 test / -1.546875 best / 5 status: : 406it [00:01, 384.35it/s]
393 it / -79.372299 train / -48.687500 test / -28.609375 best / 5 status: : 393it [00:01, 369.21it/s]
409 it / -4.375327 train / -4.328125 test / -0.843750 best / 5 status: : 409it [00:01, 385.14it/s]
411 it / -11.563165 train / -3.093750 test / -0.843750 best / 5 status: : 411it [00:01, 391.02it/s]
417 it / -5.789439 train / -3.515625 test / -0.843750 best / 5 status: : 417it [00:01, 390.76it/s]
398 it / -52.657360 train / -60.125000 test / -35.843750 best / 5 status: : 398it [00:01, 368.87it/s]
406 it / -0.156924 train / -2.531250 test / -0.406250 best / 5 status: : 406it [00:01, 390.79it/s]
418 it / -4.531907 train / -1.562500 test / -0.312500 best / 5 status: : 418it [00:01, 392.07it/s]
413 it / -4.539444 train / -1.250000 test / -0.406250 best / 5 status: : 413it [00:01, 393.52it/s]
409 it / -1.064004 train / -1.062500 test / -0.406250 best / 5 status: : 409it [00:01, 388.82it/s]
424 it / -0.781922 train / -1.312500 test / -0.406250 best / 5 status: : 424it [00:01, 395.99it/s]
425 it / -0.000675 train / -3.359375 test / -0.406250 best / 5 status: : 425it [00:01, 401.94it/s]
409 it / -4.539444 train / -1.250000 test / -0.000000 best / 5 status: : 409it [00:01, 389.00it/s]
421 it / -0.781922 train / -1.312500 test / -0.406250 best / 5 status: : 421it [00:01, 402.78it/s]
419 it / -2.188166 train / -1.484375 test / -0.406250 best / 5 status: : 419it [00:01, 399.86it/s]
419 it / -0.781922 train / -1.312500 test / -0.406250 best / 5 status: : 419it [00:01, 397.31it/s]
428 it / -4.844400 train / -2.765625 test / -0.406250 best / 5 status: : 428it [00:01, 399.16it/s]