mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Add DTTorchPolicy (#27889)
This commit is contained in:
parent
4692e8d802
commit
9330d8f244
3 changed files with 1095 additions and 0 deletions
@@ -921,6 +921,13 @@ py_test(
    srcs = ["algorithms/dt/tests/test_dt_model.py"]
)

py_test(
    name = "test_dt_policy",
    tags = ["team:rllib", "algorithms_dir"],
    size = "medium",
    srcs = ["algorithms/dt/tests/test_dt_policy.py"]
)

# ES
py_test(
    name = "test_es",
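Not part of the diff: besides the Bazel target registered above, the new test module can also be run directly through pytest, mirroring the `__main__` block of the test file added below. A minimal sketch, assuming it is executed from a Ray source checkout:

import sys

import pytest

# Path taken from the srcs entry of the new py_test target above.
sys.exit(pytest.main(["-v", "rllib/algorithms/dt/tests/test_dt_policy.py"]))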
581 rllib/algorithms/dt/dt_torch_policy.py Normal file
@@ -0,0 +1,581 @@
import gym
import numpy as np

from typing import (
    Dict,
    List,
    Tuple,
    Type,
    Union,
    Optional,
    Any,
    TYPE_CHECKING,
)

import tree
from gym.spaces import Discrete, Box

from ray.rllib.algorithms.dt.dt_torch_model import DTTorchModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.mingpt import configure_gpt_optimizer
from ray.rllib.models.torch.torch_action_dist import (
    TorchDistributionWrapper,
    TorchCategorical,
    TorchDeterministic,
)
from ray.rllib.policy.torch_mixins import LearningRateSchedule
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import (
    TrainerConfigDict,
    TensorType,
    TensorStructType,
    TensorShape,
)

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode  # noqa

torch, nn = try_import_torch()
F = nn.functional


class DTTorchPolicy(LearningRateSchedule, TorchPolicyV2):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
    ):
        LearningRateSchedule.__init__(
            self,
            config["lr"],
            config["lr_schedule"],
        )

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

    @override(TorchPolicyV2)
    def make_model_and_action_dist(
        self,
    ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
        # Model
        model_config = self.config["model"]
        # TODO: make these better with better AlgorithmConfig options.
        model_config.update(
            embed_dim=self.config["embed_dim"],
            max_ep_len=self.config["horizon"],
            num_layers=self.config["num_layers"],
            num_heads=self.config["num_heads"],
            embed_pdrop=self.config["embed_pdrop"],
            resid_pdrop=self.config["resid_pdrop"],
            attn_pdrop=self.config["attn_pdrop"],
            use_obs_output=self.config.get("loss_coef_obs", 0) > 0,
            use_return_output=self.config.get("loss_coef_returns_to_go", 0) > 0,
        )

        num_outputs = int(np.product(self.observation_space.shape))

        model = ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            framework=self.config["framework"],
            model_interface=None,
            default_model=DTTorchModel,
            name="model",
        )

        # Action Distribution
        if isinstance(self.action_space, Discrete):
            action_dist = TorchCategorical
        elif isinstance(self.action_space, Box):
            action_dist = TorchDeterministic
        else:
            raise NotImplementedError

        return model, action_dist

    @override(TorchPolicyV2)
    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        optimizer = configure_gpt_optimizer(
            model=self.model,
            learning_rate=self.config["lr"],
            weight_decay=self.config["optimizer"]["weight_decay"],
            betas=self.config["optimizer"]["betas"],
        )

        return optimizer

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional["Episode"] = None,
    ) -> SampleBatch:
        """Called by the offline data reader after loading in one episode.
        Adds a done flag at the end of the trajectory so that SegmentationBuffer
        can split on the done flag and avoid duplicate trajectories.
        """
        ep_len = sample_batch.env_steps()
        sample_batch[SampleBatch.DONES] = np.array([False] * (ep_len - 1) + [True])
        return sample_batch

    @PublicAPI
    def get_initial_input_dict(self, observation: TensorStructType) -> SampleBatch:
        """Get the initial input_dict to be passed into compute_single_action.

        Args:
            observation: first (unbatched) observation from env.reset()

        Returns:
            The input_dict for inference: {
                OBS: [max_seq_len, obs_dim] array,
                ACTIONS: [max_seq_len - 1, act_dim] array,
                RETURNS_TO_GO: [max_seq_len - 1] array,
                REWARDS: scalar,
                TIMESTEPS: [max_seq_len - 1] array,
            }
            Note that the sequence lengths differ; they are specified per
            view_requirements. See the action_distribution_fn method for an
            explanation.
        """
        observation = convert_to_numpy(observation)
        obs_shape = observation.shape
        obs_dtype = observation.dtype

        act_shape = self.action_space.shape
        act_dtype = self.action_space.dtype

        # Pad all the required inputs to their proper sequence lengths, as per
        # their ViewRequirements.

        observations = np.concatenate(
            [
                np.zeros((self.max_seq_len - 1, *obs_shape), dtype=obs_dtype),
                observation[None],
            ],
            axis=0,
        )

        actions = np.zeros((self.max_seq_len - 1, *act_shape), dtype=act_dtype)

        rtg = np.zeros(self.max_seq_len - 1, dtype=np.float32)

        rewards = np.zeros((), dtype=np.float32)

        # -1 for masking in action_distribution_fn
        timesteps = np.full(self.max_seq_len - 1, fill_value=-1, dtype=np.int32)

        input_dict = SampleBatch(
            {
                SampleBatch.OBS: observations,
                SampleBatch.ACTIONS: actions,
                SampleBatch.RETURNS_TO_GO: rtg,
                SampleBatch.REWARDS: rewards,
                SampleBatch.T: timesteps,
            }
        )
        return input_dict

    @PublicAPI
    def get_next_input_dict(
        self,
        input_dict: SampleBatch,
        action: TensorStructType,
        reward: TensorStructType,
        next_obs: TensorStructType,
        extra: Dict[str, TensorType],
    ) -> SampleBatch:
        """Returns a new input_dict after stepping through the environment once.

        Args:
            input_dict: the input dict passed into compute_single_action.
            action: the (unbatched) action taken this step.
            reward: the (unbatched) reward from env.step
            next_obs: the (unbatched) next observation from env.step
            extra: the extra action out from compute_single_action.
                In this case it contains the current returns-to-go *before* the
                current reward is subtracted from target_return.

        Returns:
            A new input_dict to be passed into compute_single_action.
            The input_dict for inference: {
                OBS: [max_seq_len, obs_dim] array,
                ACTIONS: [max_seq_len - 1, act_dim] array,
                RETURNS_TO_GO: [max_seq_len - 1] array,
                REWARDS: scalar,
                TIMESTEPS: [max_seq_len - 1] array,
            }
            Note that the sequence lengths differ; they are specified per
            view_requirements. See the action_distribution_fn method for an
            explanation.
        """
        # Create a copy of input_dict with only numpy arrays.
        input_dict = tree.map_structure(convert_to_numpy, input_dict)
        # Convert everything else to numpy as well.
        action, reward, next_obs, extra = convert_to_numpy(
            (action, reward, next_obs, extra)
        )

        # Check dimensions.
        assert input_dict[SampleBatch.OBS].shape == (
            self.max_seq_len,
            *self.observation_space.shape,
        )
        assert input_dict[SampleBatch.ACTIONS].shape == (
            self.max_seq_len - 1,
            *self.action_space.shape,
        )
        assert input_dict[SampleBatch.RETURNS_TO_GO].shape == (self.max_seq_len - 1,)
        assert input_dict[SampleBatch.T].shape == (self.max_seq_len - 1,)

        # Shift observations.
        input_dict[SampleBatch.OBS] = np.concatenate(
            [
                input_dict[SampleBatch.OBS][1:],
                next_obs[None],
            ],
            axis=0,
        )

        # Shift actions.
        input_dict[SampleBatch.ACTIONS] = np.concatenate(
            [
                input_dict[SampleBatch.ACTIONS][1:],
                action[None],
            ],
            axis=0,
        )

        # Reward is not a sequence; it's only used to calculate the next rtg.
        input_dict[SampleBatch.REWARDS] = np.asarray(reward)

        # See methods action_distribution_fn and extra_action_out for an explanation
        # of why this is done.
        input_dict[SampleBatch.RETURNS_TO_GO] = np.concatenate(
            [
                input_dict[SampleBatch.RETURNS_TO_GO][1:],
                np.asarray(extra[SampleBatch.RETURNS_TO_GO])[None],
            ],
            axis=0,
        )

        # Shift and increment timesteps.
        input_dict[SampleBatch.T] = np.concatenate(
            [
                input_dict[SampleBatch.T][1:],
                input_dict[SampleBatch.T][-1:] + 1,
            ],
            axis=0,
        )

        return input_dict

    @DeveloperAPI
    def get_initial_rtg_tensor(
        self,
        shape: TensorShape,
        dtype: Optional[Type] = torch.float32,
        device: Optional["torch.device"] = None,
    ):
        """Returns an initial/target returns-to-go tensor of the given shape.

        Args:
            shape: Shape of the rtg tensor.
            dtype: Type of the data in the tensor. Defaults to torch.float32.
            device: The device this tensor should be on. Defaults to self.device.
        """
        if device is None:
            device = self.device
        if dtype is None:
            dtype = torch.float32

        assert self.config["target_return"] is not None, "Must specify target_return."
        initial_rtg = torch.full(
            shape,
            fill_value=self.config["target_return"],
            dtype=dtype,
            device=device,
        )
        return initial_rtg

    @override(TorchPolicyV2)
    @DeveloperAPI
    def compute_actions(
        self,
        *args,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        raise ValueError("Please use compute_actions_from_input_dict instead.")

    @override(TorchPolicyV2)
    def compute_actions_from_input_dict(
        self,
        input_dict: SampleBatch,
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """
        Args:
            input_dict: input_dict (that contains a batch dimension for each value).
                Keys and shapes: {
                    OBS: [batch_size, max_seq_len, obs_dim],
                    ACTIONS: [batch_size, max_seq_len - 1, act_dim],
                    RETURNS_TO_GO: [batch_size, max_seq_len - 1],
                    REWARDS: [batch_size],
                    TIMESTEPS: [batch_size, max_seq_len - 1],
                }
            explore: unused.
            timestep: unused.
        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
        """
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = input_dict.copy()
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(True)

            actions, state_out, extra_fetches = self._compute_action_helper(input_dict)
            return actions, state_out, extra_fetches

    # TODO: figure out what this with_lock does and why it's only on the helper method.
    @with_lock
    @override(TorchPolicyV2)
    def _compute_action_helper(self, input_dict):
        # Switch to eval mode.
        self.model.eval()

        batch_size = input_dict[SampleBatch.OBS].shape[0]

        # NOTE: This is probably the most confusing part of the code, made to work with
        # env_runner and SimpleListCollector during evaluation, and thus should
        # be changed for the new Policy and Connector API.
        # So I'll explain how it works.

        # Add the current timestep (+1 because -1 is the first observation).
        # NOTE: ViewRequirement of timestep is -(max_seq_len-2):0.
        # The weird limits are because RLlib treats the initial obs as time -1,
        # and then 0 is (act, rew, next_obs), etc.
        # So we only collect max_seq_len-1 from the rollout and create the current
        # step here by adding 1.
        # Decision transformer treats the initial observation as timestep 0, giving us
        # 0 is (obs, act, rew).
        timesteps = input_dict[SampleBatch.T]
        new_timestep = timesteps[:, -1:] + 1
        input_dict[SampleBatch.T] = torch.cat([timesteps, new_timestep], dim=1)

        # Mask out any padded values at the start of the rollout.
        # NOTE: the other reason for doing this is that the evaluation rollout front
        # pads timesteps with -1, so using this we can find out when we need to mask
        # out the front section of the batch.
        input_dict[SampleBatch.ATTENTION_MASKS] = torch.where(
            input_dict[SampleBatch.T] >= 0, 1.0, 0.0
        )

        # Remove out-of-bound -1 timesteps after the attention mask is calculated.
        unclipped_timesteps = input_dict[SampleBatch.T]
        input_dict[SampleBatch.T] = torch.where(
            unclipped_timesteps < 0,
            torch.zeros_like(unclipped_timesteps),
            unclipped_timesteps,
        )

        # Compute returns-to-go.
        # NOTE: There are two rtg calculations: updated_rtg and initial_rtg.
        # updated_rtg takes the previous rtg value (the ViewRequirement is
        # -(max_seq_len-1):-1), and subtracts the last reward from it.
        rtg = input_dict[SampleBatch.RETURNS_TO_GO]
        last_rtg = rtg[:, -1]
        last_reward = input_dict[SampleBatch.REWARDS]
        updated_rtg = last_rtg - last_reward
        # initial_rtg is simply filled with target_return.
        # These two are both only for the current timestep.
        initial_rtg = self.get_initial_rtg_tensor(
            (batch_size, 1), dtype=rtg.dtype, device=rtg.device
        )

        # Then, based on whether we are currently at the first timestep or not,
        # we use the initial_rtg or the updated_rtg.
        new_rtg = torch.where(new_timestep == 0, initial_rtg, updated_rtg[:, None])
        # Append the new_rtg to the batch.
        input_dict[SampleBatch.RETURNS_TO_GO] = torch.cat([rtg, new_rtg], dim=1)[
            ..., None
        ]

        # Pad the current action (it is not actually attended to or used during
        # inference).
        past_actions = input_dict[SampleBatch.ACTIONS]
        action_pad = torch.zeros(
            (batch_size, 1, *past_actions.shape[2:]),
            dtype=past_actions.dtype,
            device=past_actions.device,
        )
        input_dict[SampleBatch.ACTIONS] = torch.cat([past_actions, action_pad], dim=1)

        # Run inference on the model.
        model_out, _ = self.model(input_dict)  # noop, just returns obs.
        preds = self.model.get_prediction(model_out, input_dict)
        dist_inputs = preds[SampleBatch.ACTIONS][:, -1]

        # Get the actions from the action_dist.
        action_dist = self.dist_class(dist_inputs, self.model)
        actions = action_dist.deterministic_sample()

        # This is used by env_runner and is actually how it adds custom keys to
        # SimpleListCollector and allows ViewRequirements to work.
        # This is also used in user inference in get_next_input_dict, which takes
        # this output as one of its inputs.
        extra_fetches = {
            # new_rtg still has the leftover extra 3rd dimension for inference
            SampleBatch.RETURNS_TO_GO: new_rtg.squeeze(-1),
            SampleBatch.ACTION_DIST_INPUTS: dist_inputs,
        }

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_numpy((actions, [], extra_fetches))

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Loss function.

        Args:
            model: The ModelV2 to run the forward pass on.
            dist_class: The distribution class of this policy.
            train_batch: Training SampleBatch.
                Keys and shapes: {
                    OBS: [batch_size, max_seq_len, obs_dim],
                    ACTIONS: [batch_size, max_seq_len, act_dim],
                    RETURNS_TO_GO: [batch_size, max_seq_len + 1, 1],
                    TIMESTEPS: [batch_size, max_seq_len],
                    ATTENTION_MASKS: [batch_size, max_seq_len],
                }
        Returns:
            Loss scalar tensor.
        """
        train_batch = self._lazy_tensor_dict(train_batch)

        # Policy forward pass to get the predictions.
        model_out, _ = self.model(train_batch)  # noop, just returns obs.
        preds = self.model.get_prediction(model_out, train_batch)

        # Get the regression targets.
        targets = self.model.get_targets(model_out, train_batch)

        # Get the attention masks for the masked loss.
        masks = train_batch[SampleBatch.ATTENTION_MASKS]

        # Compute the loss.
        loss = self._masked_loss(preds, targets, masks)

        self.log("cur_lr", torch.tensor(self.cur_lr))

        return loss

    def _masked_loss(self, preds, targets, masks):
        losses = []
        for key in targets:
            assert (
                key in preds
            ), f"for target {key} there is no prediction from the output of the model"
            loss_coef = self.config.get(f"loss_coef_{key}", 1.0)
            if self._is_discrete(key):
                loss = loss_coef * self._masked_cross_entropy_loss(
                    preds[key], targets[key], masks
                )
            else:
                loss = loss_coef * self._masked_mse_loss(
                    preds[key], targets[key], masks
                )

            losses.append(loss)
            self.log(f"{key}_loss", loss)

        return sum(losses)

    def _is_discrete(self, key):
        return key == SampleBatch.ACTIONS and isinstance(self.action_space, Discrete)

    def _masked_cross_entropy_loss(
        self,
        preds: TensorType,
        targets: TensorType,
        masks: TensorType,
    ) -> TensorType:
        """Computes cross-entropy loss between preds and targets, subject to a mask.

        Args:
            preds: logits of shape [B1, ..., Bn, M]
            targets: index targets for preds of shape [B1, ..., Bn]
            masks: 0 means don't compute loss, 1 means compute loss
                shape [B1, ..., Bn]

        Returns:
            Scalar cross-entropy loss.
        """
        losses = F.cross_entropy(
            preds.reshape(-1, preds.shape[-1]), targets.reshape(-1), reduction="none"
        )
        losses = losses * masks.reshape(-1)
        return losses.mean()

    def _masked_mse_loss(
        self,
        preds: TensorType,
        targets: TensorType,
        masks: TensorType,
    ) -> TensorType:
        """Computes MSE loss between preds and targets, subject to a mask.

        Args:
            preds: predicted values of shape [B1, ..., Bn, M]
            targets: regression targets for preds, of the same shape as preds
            masks: 0 means don't compute loss, 1 means compute loss
                shape [B1, ..., Bn]

        Returns:
            Scalar MSE loss.
        """
        losses = F.mse_loss(preds, targets, reduction="none")
        losses = losses * masks.reshape(
            *masks.shape, *([1] * (len(preds.shape) - len(masks.shape)))
        )
        return losses.mean()

    @override(TorchPolicyV2)
    def extra_grad_process(self, local_optimizer, loss):
        return apply_grad_clipping(self, local_optimizer, loss)

    def log(self, key, value):
        # internal log function
        self.model.tower_stats[key] = value

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        stats_dict = {
            k: torch.stack(self.get_tower_stats(k)).mean().item()
            for k in self.model.tower_stats
        }
        return stats_dict
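Not part of the diff: a minimal sketch of the manual inference loop that `get_initial_input_dict` and `get_next_input_dict` are designed for, driven through `Policy.compute_single_action` (which adds the batch dimension and unbatches the extra fetches). The environment choice and the config values below are illustrative placeholders, adapted from `_default_config()` in the test file that follows, with `horizon` and `target_return` bumped for CartPole-v1.

import gym

from ray.rllib.algorithms.dt.dt_torch_policy import DTTorchPolicy

# Assumed config: values mirror the test's _default_config(), not tuned settings.
config = {
    "model": {"max_seq_len": 4},
    "embed_dim": 32,
    "num_layers": 2,
    "horizon": 500,
    "num_heads": 2,
    "embed_pdrop": 0.1,
    "resid_pdrop": 0.1,
    "attn_pdrop": 0.1,
    "framework": "torch",
    "lr": 1e-3,
    "lr_schedule": None,
    "optimizer": {"weight_decay": 1e-4, "betas": [0.9, 0.99]},
    "target_return": 500.0,
    "loss_coef_actions": 1.0,
    "loss_coef_obs": 0,
    "loss_coef_returns_to_go": 0,
    "num_gpus": 0,
    "_fake_gpus": None,
}

env = gym.make("CartPole-v1")  # hypothetical placeholder environment
policy = DTTorchPolicy(env.observation_space, env.action_space, config)

obs = env.reset()
input_dict = policy.get_initial_input_dict(obs)

done = False
while not done:
    # compute_single_action batches the input_dict, then returns the unbatched
    # action together with the extra fetches (which carry RETURNS_TO_GO).
    action, _, extra = policy.compute_single_action(input_dict=input_dict)
    obs, reward, done, _ = env.step(action)
    input_dict = policy.get_next_input_dict(input_dict, action, reward, obs, extra)

The RETURNS_TO_GO entry in `extra` is exactly what `get_next_input_dict` appends to its rtg sequence, i.e. the returns-to-go before the new reward is subtracted, matching the extra_fetches produced in `_compute_action_helper`.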
507 rllib/algorithms/dt/tests/test_dt_policy.py Normal file
@@ -0,0 +1,507 @@
import unittest
from typing import Dict

import gym
import numpy as np

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.algorithms.dt.dt_torch_policy import DTTorchPolicy

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


def _default_config():
    """Base config to use."""
    return {
        "model": {
            "max_seq_len": 4,
        },
        "embed_dim": 32,
        "num_layers": 2,
        "horizon": 10,
        "num_heads": 2,
        "embed_pdrop": 0.1,
        "resid_pdrop": 0.1,
        "attn_pdrop": 0.1,
        "framework": "torch",
        "lr": 1e-3,
        "lr_schedule": None,
        "optimizer": {
            "weight_decay": 1e-4,
            "betas": [0.9, 0.99],
        },
        "target_return": 200.0,
        "loss_coef_actions": 1.0,
        "loss_coef_obs": 0,
        "loss_coef_returns_to_go": 0,
        "num_gpus": 0,
        "_fake_gpus": None,
    }


def _assert_input_dict_equals(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]):
    for key in d1.keys():
        assert key in d2.keys()

    for key in d2.keys():
        assert key in d1.keys()

    for key in d1.keys():
        assert isinstance(d1[key], np.ndarray), "input_dict should only contain numpy arrays."
        assert isinstance(d2[key], np.ndarray), "input_dict should only contain numpy arrays."
        assert d1[key].shape == d2[key].shape, "input_dicts are of different shape."
        assert np.allclose(d1[key], d2[key]), "input_dict values are not equal."


class TestDTPolicy(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_torch_postprocess_trajectory(self):
        """Test postprocess_trajectory"""
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
        action_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))

        # Create policy
        policy = DTTorchPolicy(observation_space, action_space, config)

        # Generate a sample batch with some data.
        sample_batch = SampleBatch(
            {
                SampleBatch.REWARDS: np.array([1.0, 2.0, 1.0, 1.0]),
                SampleBatch.EPS_ID: np.array([0, 0, 0, 0]),
            }
        )

        # Run postprocess_trajectory to add the dones flag.
        sample_batch = policy.postprocess_trajectory(sample_batch)

        # Assert that dones is correctly set.
        assert SampleBatch.DONES in sample_batch, "`dones` isn't part of the batch."
        assert np.allclose(
            sample_batch[SampleBatch.DONES],
            np.array([False, False, False, True]),
        ), "`dones` isn't set correctly."

    def test_torch_input_dict(self):
        """Test the inference input_dict methods.

        This is a minimal version of the test in test_dt.py.
        The shapes of the input_dict might be confusing, but they make sense in
        the context of what the function is supposed to do.
        Check action_distribution_fn for an explanation.
        """
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # initial obs and input_dict
            obs = np.array([0.0, 1.0, 2.0])
            input_dict = policy.get_initial_input_dict(obs)

            # Check input_dict matches what it should be
            target_input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 1.0, 2.0],
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[0.0], [0.0], [0.0]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([0, 0, 0], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [0.0, 0.0, 0.0], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.zeros((), dtype=np.float32),
                    SampleBatch.T: np.array([-1, -1, -1], dtype=np.int32),
                }
            )
            _assert_input_dict_equals(input_dict, target_input_dict)

            # Get next input_dict
            input_dict = policy.get_next_input_dict(
                input_dict,
                action=(
                    np.asarray([1.0], dtype=np.float32)
                    if isinstance(action_space, gym.spaces.Box)
                    else np.asarray(1, dtype=np.int32)
                ),
                reward=1.0,
                next_obs=np.array([3.0, 4.0, 5.0]),
                extra={
                    SampleBatch.RETURNS_TO_GO: config["target_return"],
                },
            )

            # Check input_dict matches what it should be
            target_input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 1.0, 2.0],
                            [3.0, 4.0, 5.0],
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[0.0], [0.0], [1.0]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([0, 0, 1], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [0.0, 0.0, config["target_return"]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.asarray(1.0, dtype=np.float32),
                    SampleBatch.T: np.array([-1, -1, 0], dtype=np.int32),
                }
            )
            _assert_input_dict_equals(input_dict, target_input_dict)

    def test_torch_action(self):
        """Test the policy's action_distribution_fn and extra_action_out methods by
        calling compute_actions_from_input_dict, which exercises those two methods
        in conjunction.
        """
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # input_dict for the initial observation
            input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [0.0]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 0]], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[0.0, 0.0, 0.0]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.array([0.0], dtype=np.float32),
                    SampleBatch.T: np.array([[-1, -1, -1]], dtype=np.int32),
                }
            )

            # Run compute_actions_from_input_dict
            actions, _, extras = policy.compute_actions_from_input_dict(
                input_dict,
                explore=False,
                timestep=None,
            )

            # Check actions
            assert actions.shape == (
                1,
                *action_space.shape,
            ), "actions has incorrect shape."

            # Check extras
            assert (
                SampleBatch.RETURNS_TO_GO in extras
            ), "extras should contain returns_to_go."
            assert extras[SampleBatch.RETURNS_TO_GO].shape == (
                1,
            ), "extras['returns_to_go'] has incorrect shape."
            assert np.isclose(
                extras[SampleBatch.RETURNS_TO_GO],
                np.asarray([config["target_return"]], dtype=np.float32),
            ), "extras['returns_to_go'] should contain target_return."

            # input_dict for a non-initial observation
            input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1]], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[0.0, 0.0, config["target_return"]]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.array([10.0], dtype=np.float32),
                    SampleBatch.T: np.array([[-1, -1, 0]], dtype=np.int32),
                }
            )

            # Run compute_actions_from_input_dict
            actions, _, extras = policy.compute_actions_from_input_dict(
                input_dict,
                explore=False,
                timestep=None,
            )

            # Check actions
            assert actions.shape == (
                1,
                *action_space.shape,
            ), "actions has incorrect shape."

            # Check extras
            assert (
                SampleBatch.RETURNS_TO_GO in extras
            ), "extras should contain returns_to_go."
            assert extras[SampleBatch.RETURNS_TO_GO].shape == (
                1,
            ), "extras['returns_to_go'] has incorrect shape."
            assert np.isclose(
                extras[SampleBatch.RETURNS_TO_GO],
                np.asarray([config["target_return"] - 10.0], dtype=np.float32),
            ), "extras['returns_to_go'] should contain target_return minus the reward."

    def test_loss(self):
        """Test the loss function."""
        config = _default_config()
        config["embed_pdrop"] = 0
        config["resid_pdrop"] = 0
        config["attn_pdrop"] = 0

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # Run the loss function on batches that differ only in masked-out
            # entries to make sure the masks are working and the losses come
            # out the same.
            batch1 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[0.0], [0.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[0, 0, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            batch2 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [1.0, 1.0, -1.0],
                                [1.0, 10.0, 12.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[1.0], [-0.5], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[2, 1, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[200.0], [-10.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[9, 3, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            loss1 = policy.loss(policy.model, policy.dist_class, batch1)
            loss2 = policy.loss(policy.model, policy.dist_class, batch2)

            loss1 = loss1.detach().cpu().item()
            loss2 = loss2.detach().cpu().item()

            assert np.isclose(loss1, loss2), "Masks are not working for losses."

            # Run the loss on a widely different batch and make sure the loss
            # is different.
            batch3 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [1.0, 1.0, -20.0],
                                [0.1, 10.0, 12.0],
                                [1.4, 12.0, -9.0],
                                [6.0, 40.0, -2.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[2.0], [-1.5], [0.2], [0.1]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[1, 3, 0, 2]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[90.0], [80.0], [70.0], [60.0], [50.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[3, 4, 5, 6]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[1.0, 1.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            loss3 = policy.loss(policy.model, policy.dist_class, batch3)
            loss3 = loss3.detach().cpu().item()

            assert not np.isclose(
                loss1, loss3
            ), "Widely different inputs are giving the same loss value."

    def test_loss_coef(self):
        """Test the loss_coef_{key} config options."""

        config = _default_config()
        config["embed_pdrop"] = 0
        config["resid_pdrop"] = 0
        config["attn_pdrop"] = 0
        # Set the initial action coef to 0.
        config["loss_coef_actions"] = 0

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            batch = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[0.0], [0.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[0, 0, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            keys = [SampleBatch.ACTIONS, SampleBatch.OBS, SampleBatch.RETURNS_TO_GO]
            for key in keys:
                # Create two policies and run the loss with different coefs.
                # Create policy 1 with coef = 1.
                config1 = config.copy()
                config1[f"loss_coef_{key}"] = 1.0
                policy1 = DTTorchPolicy(observation_space, action_space, config1)

                loss1 = policy1.loss(policy1.model, policy1.dist_class, batch)
                loss1 = loss1.detach().cpu().item()

                # Create policy 2 with coef = 10.
                config2 = config.copy()
                config2[f"loss_coef_{key}"] = 10.0
                policy2 = DTTorchPolicy(observation_space, action_space, config2)
                # Copy the weights over so both policies output the same loss
                # before scaling.
                policy2.set_state(policy1.get_state())
                policy2.set_weights(policy1.get_weights())

                loss2 = policy2.loss(policy2.model, policy2.dist_class, batch)
                loss2 = loss2.detach().cpu().item()

                # Compare the losses; they should differ by a factor of 10.
                self.assertAlmostEqual(
                    loss2 / loss1,
                    10.0,
                    places=3,
                    msg="the two losses should differ by a factor of 10.",
                )


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
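As a side note on the masking behavior exercised by test_loss above, here is a small standalone sketch (toy values, not from the diff) of the arithmetic in `DTTorchPolicy._masked_cross_entropy_loss`: per-position losses are multiplied by the attention mask before averaging, so fully padded positions cannot change the result.

import torch
import torch.nn.functional as F

logits = torch.tensor([[[2.0, 0.1], [0.3, 1.5], [0.0, 0.0]]])  # [B=1, T=3, M=2]
targets = torch.tensor([[0, 1, 0]])                            # [B=1, T=3]
masks = torch.tensor([[1.0, 1.0, 0.0]])                        # last step is padding

# Same steps as the policy: flatten, per-element losses, mask, then mean.
losses = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]), targets.reshape(-1), reduction="none"
)
masked = losses * masks.reshape(-1)
print(masked.mean())  # padded position contributes zero to the sum

Note that `.mean()` divides by the total number of positions rather than by the number of unmasked ones; `_masked_mse_loss` behaves the same way, which is why the two batches in test_loss, sharing one mask, produce identical losses.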