Calling External PyTorch Modules in pyRDDLGym
This advanced notebook shows how to call external functions from RDDL domain description files: torch2jax converts a PyTorch model into a JAX function, and JaxPlan then optimizes a plan through it.
%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
%pip install --quiet -U torch torch2jax
Note: you may need to restart the kernel to use updated packages.
Import the required packages:
import torch
import torch.nn as nn
from torch2jax import t2j
import jax
import pyRDDLGym
from rddlrepository.core.manager import RDDLRepoManager
from pyRDDLGym_jax.core.planner import JaxStraightLinePlan, JaxBackpropPlanner, JaxOfflineController, load_config_from_string
Let’s define the reservoir rlevel CPF as a PyTorch model. In principle, any pre-trained torch module can serve as a plug-in estimator in the same way.
class RlevelModel(nn.Module):
    def forward(self, rlevel, inflow, rain, evaporated, overflow, released_water, TOP_RES):
        base = rlevel + inflow + rain - evaporated - overflow - released_water
        # torch.maximum/minimum are not currently implemented by torch2jax Torchish tensors.
        # Use algebraic equivalents based on abs that preserve differentiability almost everywhere.
        base = 0.5 * (base + torch.abs(base))            # max(base, 0)
        diff = TOP_RES - base
        base = TOP_RES - 0.5 * (diff + torch.abs(diff))  # min(base, TOP_RES)
        return base
rlevel_model = RlevelModel()
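The abs-based expressions in forward are a clamp of the water balance to the interval [0, TOP_RES]. The following standard-library sketch (it uses neither torch nor JAX; the helper names clamp_via_abs and clamp_reference are hypothetical and exist only for this illustration) compares the algebraic form with the usual max/min form on a few scalar values:

```python
def clamp_via_abs(base: float, top: float) -> float:
    """Clamp base to [0, top] using only +, -, * and abs, mirroring RlevelModel.forward."""
    # 0.5 * (x + |x|) equals max(x, 0)
    base = 0.5 * (base + abs(base))
    # top - 0.5 * (d + |d|) with d = top - base equals min(base, top)
    diff = top - base
    return top - 0.5 * (diff + abs(diff))


def clamp_reference(base: float, top: float) -> float:
    """The same clamp written with the built-in max and min."""
    return min(max(base, 0.0), top)


# Values below zero, inside the range, exactly at the rim, and above the rim.
samples = [-12.5, 0.0, 37.25, 100.0, 140.0]
results = [(x, clamp_via_abs(x, 100.0), clamp_reference(x, 100.0)) for x in samples]
for x, got, want in results:
    print(f"{x:8.2f} -> abs form {got:7.2f}, max/min form {want:7.2f}")
```

Both forms agree on these samples. The abs form is used in the model only because it avoids the torch operations that torch2jax does not translate.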
Next, we use torch2jax to wrap this model as JAX code, which allows this function to take advantage of JAX vmap, scan and other JAX operations:
rlevel_model = t2j(rlevel_model.eval())
external_rlevel_function = jax.jit(rlevel_model)
Next, we create a modified reservoir problem where the rlevel(?r) CPF is computed externally by the PyTorch model:
domain_text = """
domain reservoir_control_cont {
types {
reservoir: object;
};
pvariables {
// Constants
TOP_RES(reservoir): { non-fluent, real, default = 100.0 }; // Overflowing amount
MAX_LEVEL(reservoir): { non-fluent, real, default = 80.0 }; // The upper bound for desired reservoir level
MIN_LEVEL(reservoir): { non-fluent, real, default = 20.0 }; // The lower bound for desired reservoir level
RAIN_VAR(reservoir): { non-fluent, real, default = 5.0 }; // Half normal variance parameter for rainfall
RES_CONNECT(reservoir, reservoir): { non-fluent, bool, default = false }; // Indicates 2nd reservoir is forward connected to 1st reservoir
EVAPORATION_FACTOR: { non-fluent, real, default = 0.05 }; // Maximum fraction of evaporation
CONNECTED_TO_SEA(reservoir): {non-fluent, bool, default = false}; // reservoirs connected to the sea
LOW_COST(reservoir) : { non-fluent, real, default = -5.0 }; // Penalty per unit of level < MIN_LEVEL
HIGH_COST(reservoir): { non-fluent, real, default = -10.0 }; // Penalty per unit of level > MAX_LEVEL
OVERFLOW_COST(reservoir): { non-fluent, real, default = -15.0 }; // Penalty per unit of level > TOP_RES
// Intermediate fluents
rain(reservoir): {interm-fluent, real}; // Amount of rain fell
evaporated(reservoir): {interm-fluent, real}; // Evaporated water from reservoir
overflow(reservoir): {interm-fluent, real}; // Excess overflow (over the rim)
inflow(reservoir): {interm-fluent, real}; // Amount received from backward reservoirs
individual_outflow(reservoir): {interm-fluent, real}; // Net amount of water released from reservoir to individually connected reservoirs
released_water(reservoir): {interm-fluent, real}; // Actual amount of water released (with action clipping)
// State fluents
rlevel(reservoir): {state-fluent, real, default = 50.0 }; // Reservoir level
// Action fluents
release(reservoir): { action-fluent, real, default = 0.0 }; // Action to set outflow of reservoir
};
cpfs {
rain(?r) = abs[Normal(0, RAIN_VAR(?r))];
evaporated(?r) = EVAPORATION_FACTOR * rlevel(?r) / TOP_RES(?r);
released_water(?r) = max[0, min[rlevel(?r), release(?r)]];
overflow(?r) = max[0, rlevel(?r) - released_water(?r) - TOP_RES(?r)];
individual_outflow(?r) = released_water(?r)* 1 / ((sum_{?out: reservoir} [RES_CONNECT(?r,?out)]) + CONNECTED_TO_SEA(?r));
inflow(?r) = (sum_{?in : reservoir} [RES_CONNECT(?in,?r) * individual_outflow(?in)]);
rlevel'(?r) = $PyTorchModel[?r](rlevel(_), inflow(_), rain(_), evaporated(_), overflow(_), released_water(_), TOP_RES(_));
};
reward = (sum_{?r: reservoir} [if ((rlevel'(?r) >= MIN_LEVEL(?r)) ^ (rlevel'(?r) <= MAX_LEVEL(?r)))
then 0
else if (rlevel'(?r) <= MIN_LEVEL(?r))
then LOW_COST(?r) * (MIN_LEVEL(?r) - rlevel'(?r))
else if ((rlevel'(?r) > MAX_LEVEL(?r)) ^ (rlevel'(?r) <= TOP_RES(?r)))
then HIGH_COST(?r) * (rlevel'(?r) - MAX_LEVEL(?r))
else HIGH_COST(?r) * (rlevel'(?r) - MAX_LEVEL(?r)) + OVERFLOW_COST(?r) * overflow(?r)
]);
}
"""
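The reward expression in the domain is a nested if/else per reservoir, with ^ read as logical conjunction. As a reading aid, here is a plain-Python sketch of the per-reservoir term; the function name reservoir_reward is hypothetical, and the default arguments are copied from the non-fluent defaults declared in the domain above:

```python
def reservoir_reward(level: float, overflow: float,
                     min_level: float = 20.0, max_level: float = 80.0,
                     top_res: float = 100.0, low_cost: float = -5.0,
                     high_cost: float = -10.0, overflow_cost: float = -15.0) -> float:
    """Per-reservoir reward term, following the branches of the RDDL reward."""
    if min_level <= level <= max_level:
        return 0.0                                   # inside the desired band
    if level <= min_level:
        return low_cost * (min_level - level)        # too low
    if level <= top_res:                             # here level > max_level
        return high_cost * (level - max_level)       # too high, but below the rim
    # above the rim: high-level penalty plus overflow penalty
    return high_cost * (level - max_level) + overflow_cost * overflow


print(reservoir_reward(50.0, 0.0))    # inside the band
print(reservoir_reward(10.0, 0.0))    # 10 units below MIN_LEVEL
print(reservoir_reward(90.0, 0.0))    # 10 units above MAX_LEVEL
```

The total reward is the sum of this term over all reservoirs. Because the external model clamps rlevel' to at most TOP_RES, the last branch should not be reached in this particular setup.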
instance_text = """
non-fluents Reservoir_3nf {
domain = reservoir_control_cont;
objects{
reservoir: {t1, t2, t3};
};
non-fluents {
RES_CONNECT(t1,t3);
RES_CONNECT(t2,t3);
CONNECTED_TO_SEA(t3);
};
}
instance Reservoir_cont_0 {
domain = reservoir_control_cont;
non-fluents = Reservoir_3nf;
init-state{
rlevel(t1) = 45.0;
};
max-nondef-actions = pos-inf;
horizon = 120;
discount = 1.0;
}
"""
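To see what the individual_outflow and inflow CPFs compute for this instance, the sketch below evaluates them in plain Python for the topology declared above (t1 and t2 feed t3, and t3 drains to the sea). The release amounts are made-up illustration values, and the data structures are assumptions of this sketch, not how pyRDDLGym stores the model:

```python
reservoirs = ["t1", "t2", "t3"]
res_connect = {("t1", "t3"), ("t2", "t3")}   # RES_CONNECT(from, to) set to true
connected_to_sea = {"t3"}                    # CONNECTED_TO_SEA set to true


def individual_outflow(released: dict) -> dict:
    """Split each reservoir's released water evenly over its outgoing links."""
    out = {}
    for r in reservoirs:
        # number of downstream reservoirs, plus one if the reservoir drains to the sea
        n_targets = sum((r, o) in res_connect for o in reservoirs) + (r in connected_to_sea)
        out[r] = released[r] / n_targets     # every reservoir here has at least one target
    return out


def inflow(released: dict) -> dict:
    """Sum the per-link outflow of all upstream reservoirs."""
    per_link = individual_outflow(released)
    return {r: sum((per_link[i] for i in reservoirs if (i, r) in res_connect), 0.0)
            for r in reservoirs}


released = {"t1": 5.0, "t2": 3.0, "t3": 10.0}   # hypothetical released_water values
flows = inflow(released)
print(flows)
```

Only t3 receives inflow in this instance, namely the water released by t1 and t2; the water released by t3 leaves the system.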
Now we register this domain with rddlrepository as usual:
manager = RDDLRepoManager(rebuild=True)
manager.register_context("torch")
manager.register_domain("Reservoir", "torch", domain_text, desc="a reservoir torch problem", viz=None)
manager.get_problem("Reservoir_torch").register_instance("0", instance_text)
_ = RDDLRepoManager(rebuild=True)
Context <torch> was successfully registered in rddlrepository.
Domain <Reservoir> was successfully registered in rddlrepository with context <torch>.
Instance <0> was successfully registered in rddlrepository for domain <Reservoir_torch>.
Finally, to instantiate the domain, we must explicitly provide the PyTorchModel function defined earlier to the environment:
env = pyRDDLGym.make('Reservoir_torch', '0', vectorized=True, backend_kwargs={
    'python_functions': {'PyTorchModel': external_rlevel_function}})
c:\Users\mgime\anaconda3\envs\jaxenv\Lib\site-packages\gymnasium\spaces\box.py:235: UserWarning: WARN: Box low's precision lowered by casting to float32, current low.dtype=float64
  gym.logger.warn(
c:\Users\mgime\anaconda3\envs\jaxenv\Lib\site-packages\gymnasium\spaces\box.py:305: UserWarning: WARN: Box high's precision lowered by casting to float32, current high.dtype=float64
  gym.logger.warn(
Let’s run the JAX planner:
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan(),
                             python_functions={'PyTorchModel': external_rlevel_function})
agent = JaxOfflineController(planner, print_summary=False, train_seconds=60)
print(agent.evaluate(env, episodes=1))
[INFO] Compiler will cast pvars {'CONNECTED_TO_SEA', 'RES_CONNECT'} to float.
[INFO] Bounds of action-fluent <release> set to (array([-inf, -inf, -inf], dtype=float32), array([inf, inf, inf], dtype=float32)).
[WARN] policy_hyperparams is not set: setting values to 1.0 for all action-fluents, which could be suboptimal.
720 it | -398.53857 train | -569.07648 test | -298.56113 best | 2 pgpe | 5 status: 100%|██████████| 00:58 , 12.01it/s
{'mean': np.float64(-1301.2390518188477), 'median': np.float64(-1301.2390518188477), 'min': np.float64(-1301.2390518188477), 'max': np.float64(-1301.2390518188477), 'std': np.float64(0.0)}