Calling External PyTorch Modules in pyRDDLGym.

Calling External PyTorch Modules in pyRDDLGym

This advanced notebook shows how to call external functions from an RDDL domain description file. It uses torch2jax to convert a PyTorch model into a JAX function, so that JaxPlan can optimize a plan through it.

%pip install --quiet --upgrade pip
%pip install --quiet pyRDDLGym rddlrepository pyRDDLGym-jax
%pip install --quiet -U torch torch2jax
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.

Import the required packages:

import torch
import torch.nn as nn
from torch2jax import t2j
import jax

import pyRDDLGym
from rddlrepository.core.manager import RDDLRepoManager
from pyRDDLGym_jax.core.planner import JaxStraightLinePlan, JaxBackpropPlanner, JaxOfflineController, load_config_from_string

Let’s define the reservoir rlevel CPF as a PyTorch model. In principle, any pre-trained torch module could be used as a plug-in estimator in the same way.

class RlevelModel(nn.Module):
    """Torch implementation of the reservoir ``rlevel`` transition.

    The module has no learnable parameters: it evaluates the water balance
    and clamps the result to the interval ``[0, TOP_RES]``.
    """

    def forward(self, rlevel, inflow, rain, evaporated, overflow, released_water, TOP_RES):
        """Return the next reservoir level, clamped to ``[0, TOP_RES]``.

        All arguments are combined elementwise; they are expected to be
        values that ``torch.abs`` and tensor arithmetic accept.
        """
        # Unclamped water balance: gains minus losses.
        balance = rlevel + inflow + rain - evaporated - overflow - released_water

        # torch.maximum/minimum are not currently implemented by torch2jax
        # Torchish tensors, so both clamps are spelled with abs instead:
        #   max(x, 0) == 0.5 * (x + |x|)
        # which stays differentiable almost everywhere.
        nonnegative = 0.5 * (balance + torch.abs(balance))

        # min(x, TOP_RES) == TOP_RES - max(TOP_RES - x, 0)
        headroom = TOP_RES - nonnegative
        return TOP_RES - 0.5 * (headroom + torch.abs(headroom))

# Instantiate the torch module. RlevelModel defines no parameters, so there
# are no weights to load before converting it.
rlevel_model = RlevelModel()

Next, we use torch2jax to wrap this model as JAX code, which allows this function to take advantage of JAX vmap, scan and other JAX operations:

# Convert the torch module (put into eval mode first) to a JAX callable with
# torch2jax. NOTE: this rebinds `rlevel_model` from the nn.Module to the
# converted function.
rlevel_model = t2j(rlevel_model.eval())
# JIT-compile the converted function; this is the callable that is later
# registered with pyRDDLGym and JaxPlan under the name 'PyTorchModel'.
external_rlevel_function = jax.jit(rlevel_model)

Next, we create a modified reservoir problem where the rlevel(?r) CPF is computed externally from the pytorch model:

# RDDL source of the modified reservoir domain (`reservoir_control_cont`).
# Every CPF is written in RDDL except rlevel'(?r), which is delegated to the
# external function referenced as `$PyTorchModel`; that name must be supplied
# through `python_functions` when the environment and planner are created.
# NOTE(review): `[?r]` presumably marks the output parameter and `_` passes the
# whole per-reservoir slice of each fluent — confirm against the pyRDDLGym docs.
domain_text = """
domain reservoir_control_cont {

	types {
		reservoir: object;
	};

    pvariables {

		// Constants
        TOP_RES(reservoir): { non-fluent, real, default = 100.0 }; // Overflowing amount
        MAX_LEVEL(reservoir): { non-fluent, real, default = 80.0 };  // The upper bound for desired reservoir level
		MIN_LEVEL(reservoir): { non-fluent, real, default = 20.0 };  // The lower bound for desired reservoir level
		RAIN_VAR(reservoir):  { non-fluent, real, default = 5.0 };  // Half normal variance parameter for rainfall
        RES_CONNECT(reservoir, reservoir): { non-fluent, bool, default = false }; // Indicates 2nd reservoir is forward connected to 1st reservoir
        EVAPORATION_FACTOR: { non-fluent, real, default = 0.05 }; // Maximum fraction of evaporation
        CONNECTED_TO_SEA(reservoir): {non-fluent, bool, default = false}; // reservoirs connected to the sea

        LOW_COST(reservoir) : { non-fluent, real, default =  -5.0 }; // Penalty per unit of level < MIN_LEVEL
		HIGH_COST(reservoir): { non-fluent, real, default = -10.0 }; // Penalty per unit of level > MAX_LEVEL
        OVERFLOW_COST(reservoir): { non-fluent, real, default = -15.0 }; // Penalty per unit of level > TOP_RES

        // Intermediate fluents
        rain(reservoir):   {interm-fluent, real}; // Amount of rain fell
        evaporated(reservoir): {interm-fluent, real}; // Evaporated water from reservoir
        overflow(reservoir):   {interm-fluent, real}; // Excess overflow (over the rim)
        inflow(reservoir):     {interm-fluent, real}; // Amount received from backward reservoirs
        individual_outflow(reservoir): {interm-fluent, real}; // Net amount of water released from reservoir to individually connected reservoirs
        released_water(reservoir): {interm-fluent, real}; // Actual amount of water released (with action clipping

        // State fluents
        rlevel(reservoir): {state-fluent, real, default = 50.0 }; // Reservoir level

        // Action fluents
        release(reservoir): { action-fluent, real, default = 0.0 }; // Action to set outflow of reservoir
    };

    cpfs {
        rain(?r) =  abs[Normal(0, RAIN_VAR(?r))];
		evaporated(?r) = EVAPORATION_FACTOR * rlevel(?r) / TOP_RES(?r);
		released_water(?r) = max[0, min[rlevel(?r), release(?r)]];
        overflow(?r) = max[0, rlevel(?r) - released_water(?r) - TOP_RES(?r)];
        individual_outflow(?r) = released_water(?r)* 1 / ((sum_{?out: reservoir} [RES_CONNECT(?r,?out)]) + CONNECTED_TO_SEA(?r));
        inflow(?r) = (sum_{?in : reservoir} [RES_CONNECT(?in,?r) * individual_outflow(?in)]);
        rlevel'(?r) = $PyTorchModel[?r](rlevel(_), inflow(_), rain(_), evaporated(_), overflow(_), released_water(_), TOP_RES(_));
    };

    reward = (sum_{?r: reservoir} [if ((rlevel'(?r) >= MIN_LEVEL(?r)) ^ (rlevel'(?r) <= MAX_LEVEL(?r)))
		then 0
		else if (rlevel'(?r) <= MIN_LEVEL(?r))
			then LOW_COST(?r) * (MIN_LEVEL(?r) - rlevel'(?r))
			else if ((rlevel'(?r) > MAX_LEVEL(?r)) ^ (rlevel'(?r) <= TOP_RES(?r)))
					then HIGH_COST(?r) * (rlevel'(?r) - MAX_LEVEL(?r))
					else HIGH_COST(?r) * (rlevel'(?r) - MAX_LEVEL(?r)) + OVERFLOW_COST(?r) * overflow(?r)
		]);
}
"""

# RDDL non-fluents and instance blocks for a three-reservoir problem
# (t1 and t2 feed t3, and t3 drains to the sea). The `domain = ...` fields
# name the domain declared in domain_text, `reservoir_control_cont`.
instance_text = """
non-fluents Reservoir_3nf {
	domain = reservoir_control_cont;
	objects{
		reservoir: {t1, t2, t3};
	};
	non-fluents {
		RES_CONNECT(t1,t3);
		RES_CONNECT(t2,t3);
		CONNECTED_TO_SEA(t3);
	};
}

instance Reservoir_cont_0 {
	domain = reservoir_control_cont;
	non-fluents = Reservoir_3nf;
	init-state{
		rlevel(t1) = 45.0;
	};
	max-nondef-actions = pos-inf;
	horizon = 120;
	discount = 1.0;
}
"""

Now we register this domain with rddlrepository as usual:

# Register the custom problem with rddlrepository: create a "torch" context,
# add the Reservoir domain under it, then attach instance "0" to the resulting
# problem, which is looked up as "Reservoir_torch".
manager = RDDLRepoManager(rebuild=True)
manager.register_context("torch")
manager.register_domain("Reservoir", "torch", domain_text, desc="a reservoir torch problem", viz=None)
manager.get_problem("Reservoir_torch").register_instance("0", instance_text)
# NOTE(review): constructing a second manager with rebuild=True presumably
# refreshes the repository index so the new problem is discoverable — confirm
# against the rddlrepository documentation.
_ = RDDLRepoManager(rebuild=True)
Context <torch> was successfully registered in rddlrepository.
Domain <Reservoir> was successfully registered in rddlrepository with context <torch>.
Instance <0> was successfully registered in rddlrepository for domain <Reservoir_torch>.

Finally, to instantiate the domain, we must explicitly provide the PyTorchModel defined earlier to the environment:

# Build the environment for the problem registered above. The key
# 'PyTorchModel' is the name referenced as `$PyTorchModel` in the rlevel' CPF
# of domain_text; it is bound here to the jitted JAX function.
env = pyRDDLGym.make('Reservoir_torch', '0', vectorized=True, backend_kwargs={
    'python_functions': {'PyTorchModel': external_rlevel_function}})
c:\Users\mgime\anaconda3\envs\jaxenv\Lib\site-packages\gymnasium\spaces\box.py:235: UserWarning: WARN: Box low's precision lowered by casting to float32, current low.dtype=float64
  gym.logger.warn(
c:\Users\mgime\anaconda3\envs\jaxenv\Lib\site-packages\gymnasium\spaces\box.py:305: UserWarning: WARN: Box high's precision lowered by casting to float32, current high.dtype=float64
  gym.logger.warn(

Let’s run the JAX planner:

# Build a JaxPlan planner with a straight-line plan over the environment's
# RDDL model. The planner takes the same 'PyTorchModel' mapping that was
# passed to the environment.
planner = JaxBackpropPlanner(rddl=env.model, plan=JaxStraightLinePlan(),
                             python_functions={'PyTorchModel': external_rlevel_function})
# Offline controller with a 60-second training budget (train_seconds=60) and
# the summary printout disabled.
agent = JaxOfflineController(planner, print_summary=False, train_seconds=60)
# Roll out the trained plan for one episode and print the returned statistics.
print(agent.evaluate(env, episodes=1))
[INFO] Compiler will cast pvars {'CONNECTED_TO_SEA', 'RES_CONNECT'} to float.
[INFO] Bounds of action-fluent <release> set to (array([-inf, -inf, -inf], dtype=float32), array([inf, inf, inf], dtype=float32)).
[WARN] policy_hyperparams is not set: setting values to 1.0 for all action-fluents, which could be suboptimal.
 720 it |    -398.53857 train |    -569.07648 test |    -298.56113 best | 2 pgpe | 5 status: 100%|██████████| 00:58 , 12.01it/s
{'mean': np.float64(-1301.2390518188477), 'median': np.float64(-1301.2390518188477), 'min': np.float64(-1301.2390518188477), 'max': np.float64(-1301.2390518188477), 'std': np.float64(0.0)}