mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Add DTTorchPolicy (#27889)
parent 4692e8d802
commit 9330d8f244
3 changed files with 1095 additions and 0 deletions
@@ -921,6 +921,13 @@ py_test(
    srcs = ["algorithms/dt/tests/test_dt_model.py"]
)

py_test(
    name = "test_dt_policy",
    tags = ["team:rllib", "algorithms_dir"],
    size = "medium",
    srcs = ["algorithms/dt/tests/test_dt_policy.py"]
)

# ES
py_test(
    name = "test_es",
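Note (not part of the diff): assuming Ray's standard Bazel test setup, the hunk above registers the new test alongside the existing DT tests, so it should be runnable with something like bazel test //rllib:test_dt_policy (target name per the BUILD entry), or directly through pytest as shown at the end of the test file below.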
rllib/algorithms/dt/dt_torch_policy.py (new file, 581 lines)

@@ -0,0 +1,581 @@
import gym
import numpy as np

from typing import (
    Dict,
    List,
    Tuple,
    Type,
    Union,
    Optional,
    Any,
    TYPE_CHECKING,
)

import tree
from gym.spaces import Discrete, Box

from ray.rllib.algorithms.dt.dt_torch_model import DTTorchModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.mingpt import configure_gpt_optimizer
from ray.rllib.models.torch.torch_action_dist import (
    TorchDistributionWrapper,
    TorchCategorical,
    TorchDeterministic,
)
from ray.rllib.policy.torch_mixins import LearningRateSchedule
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import (
    TrainerConfigDict,
    TensorType,
    TensorStructType,
    TensorShape,
)

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode  # noqa

torch, nn = try_import_torch()
F = nn.functional


class DTTorchPolicy(LearningRateSchedule, TorchPolicyV2):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
    ):
        LearningRateSchedule.__init__(
            self,
            config["lr"],
            config["lr_schedule"],
        )

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

    @override(TorchPolicyV2)
    def make_model_and_action_dist(
        self,
    ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
        # Model
        model_config = self.config["model"]
        # TODO: make these better with better AlgorithmConfig options.
        model_config.update(
            embed_dim=self.config["embed_dim"],
            max_ep_len=self.config["horizon"],
            num_layers=self.config["num_layers"],
            num_heads=self.config["num_heads"],
            embed_pdrop=self.config["embed_pdrop"],
            resid_pdrop=self.config["resid_pdrop"],
            attn_pdrop=self.config["attn_pdrop"],
            use_obs_output=self.config.get("loss_coef_obs", 0) > 0,
            use_return_output=self.config.get("loss_coef_returns_to_go", 0) > 0,
        )

        num_outputs = int(np.product(self.observation_space.shape))

        model = ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            framework=self.config["framework"],
            model_interface=None,
            default_model=DTTorchModel,
            name="model",
        )

        # Action Distribution
        if isinstance(self.action_space, Discrete):
            action_dist = TorchCategorical
        elif isinstance(self.action_space, Box):
            action_dist = TorchDeterministic
        else:
            raise NotImplementedError

        return model, action_dist

    @override(TorchPolicyV2)
    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        optimizer = configure_gpt_optimizer(
            model=self.model,
            learning_rate=self.config["lr"],
            weight_decay=self.config["optimizer"]["weight_decay"],
            betas=self.config["optimizer"]["betas"],
        )

        return optimizer

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional["Episode"] = None,
    ) -> SampleBatch:
        """Called by the offline data reader after loading in one episode.

        Adds a done flag at the end of the trajectory so that SegmentationBuffer
        can split on the done flag and avoid duplicate trajectories.
        """
        ep_len = sample_batch.env_steps()
        sample_batch[SampleBatch.DONES] = np.array([False] * (ep_len - 1) + [True])
        return sample_batch

    @PublicAPI
    def get_initial_input_dict(self, observation: TensorStructType) -> SampleBatch:
        """Get the initial input_dict to be passed into compute_single_action.

        Args:
            observation: first (unbatched) observation from env.reset()

        Returns:
            The input_dict for inference: {
                OBS: [max_seq_len, obs_dim] array,
                ACTIONS: [max_seq_len - 1, act_dim] array,
                RETURNS_TO_GO: [max_seq_len - 1] array,
                REWARDS: scalar,
                TIMESTEPS: [max_seq_len - 1] array,
            }
            Note that the sequence lengths differ and are specified per the
            view_requirements. See the action_distribution_fn method for an
            explanation.
        """
        observation = convert_to_numpy(observation)
        obs_shape = observation.shape
        obs_dtype = observation.dtype

        act_shape = self.action_space.shape
        act_dtype = self.action_space.dtype

        # Here we pad all the required inputs to their proper sequence lengths,
        # as per their ViewRequirement.

        observations = np.concatenate(
            [
                np.zeros((self.max_seq_len - 1, *obs_shape), dtype=obs_dtype),
                observation[None],
            ],
            axis=0,
        )

        actions = np.zeros((self.max_seq_len - 1, *act_shape), dtype=act_dtype)

        rtg = np.zeros(self.max_seq_len - 1, dtype=np.float32)

        rewards = np.zeros((), dtype=np.float32)

        # -1 for masking in action_distribution_fn
        timesteps = np.full(self.max_seq_len - 1, fill_value=-1, dtype=np.int32)

        input_dict = SampleBatch(
            {
                SampleBatch.OBS: observations,
                SampleBatch.ACTIONS: actions,
                SampleBatch.RETURNS_TO_GO: rtg,
                SampleBatch.REWARDS: rewards,
                SampleBatch.T: timesteps,
            }
        )
        return input_dict

    @PublicAPI
    def get_next_input_dict(
        self,
        input_dict: SampleBatch,
        action: TensorStructType,
        reward: TensorStructType,
        next_obs: TensorStructType,
        extra: Dict[str, TensorType],
    ) -> SampleBatch:
        """Returns a new input_dict after stepping through the environment once.

        Args:
            input_dict: the input dict passed into compute_single_action.
            action: the (unbatched) action taken this step.
            reward: the (unbatched) reward from env.step
            next_obs: the (unbatched) next observation from env.step
            extra: the extra action out from compute_single_action.
                In this case it contains the current returns to go *before* the
                current reward is subtracted from target_return.

        Returns:
            A new input_dict to be passed into compute_single_action.
            The input_dict for inference: {
                OBS: [max_seq_len, obs_dim] array,
                ACTIONS: [max_seq_len - 1, act_dim] array,
                RETURNS_TO_GO: [max_seq_len - 1] array,
                REWARDS: scalar,
                TIMESTEPS: [max_seq_len - 1] array,
            }
            Note that the sequence lengths differ and are specified per the
            view_requirements. See the action_distribution_fn method for an
            explanation.
        """
        # creates a copy of input_dict with only numpy arrays
        input_dict = tree.map_structure(convert_to_numpy, input_dict)
        # convert everything else to numpy as well
        action, reward, next_obs, extra = convert_to_numpy(
            (action, reward, next_obs, extra)
        )

        # check dimensions
        assert input_dict[SampleBatch.OBS].shape == (
            self.max_seq_len,
            *self.observation_space.shape,
        )
        assert input_dict[SampleBatch.ACTIONS].shape == (
            self.max_seq_len - 1,
            *self.action_space.shape,
        )
        assert input_dict[SampleBatch.RETURNS_TO_GO].shape == (self.max_seq_len - 1,)
        assert input_dict[SampleBatch.T].shape == (self.max_seq_len - 1,)

        # Shift observations
        input_dict[SampleBatch.OBS] = np.concatenate(
            [
                input_dict[SampleBatch.OBS][1:],
                next_obs[None],
            ],
            axis=0,
        )

        # Shift actions
        input_dict[SampleBatch.ACTIONS] = np.concatenate(
            [
                input_dict[SampleBatch.ACTIONS][1:],
                action[None],
            ],
            axis=0,
        )

        # Reward is not a sequence, it's only used to calculate the next rtg.
        input_dict[SampleBatch.REWARDS] = np.asarray(reward)

        # See methods action_distribution_fn and extra_action_out for an explanation
        # of why this is done.
        input_dict[SampleBatch.RETURNS_TO_GO] = np.concatenate(
            [
                input_dict[SampleBatch.RETURNS_TO_GO][1:],
                np.asarray(extra[SampleBatch.RETURNS_TO_GO])[None],
            ],
            axis=0,
        )

        # Shift and increment timesteps
        input_dict[SampleBatch.T] = np.concatenate(
            [
                input_dict[SampleBatch.T][1:],
                input_dict[SampleBatch.T][-1:] + 1,
            ],
            axis=0,
        )

        return input_dict

    @DeveloperAPI
    def get_initial_rtg_tensor(
        self,
        shape: TensorShape,
        dtype: Optional[Type] = torch.float32,
        device: Optional["torch.device"] = None,
    ):
        """Returns an initial/target returns-to-go tensor of the given shape.

        Args:
            shape: Shape of the rtg tensor.
            dtype: Type of the data in the tensor. Defaults to torch.float32.
            device: The device this tensor should be on. Defaults to self.device.
        """
        if device is None:
            device = self.device
        if dtype is None:
            dtype = torch.float32

        assert self.config["target_return"] is not None, "Must specify target_return."
        initial_rtg = torch.full(
            shape,
            fill_value=self.config["target_return"],
            dtype=dtype,
            device=device,
        )
        return initial_rtg

    @override(TorchPolicyV2)
    @DeveloperAPI
    def compute_actions(
        self,
        *args,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        raise ValueError("Please use compute_actions_from_input_dict instead.")

    @override(TorchPolicyV2)
    def compute_actions_from_input_dict(
        self,
        input_dict: SampleBatch,
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """
        Args:
            input_dict: input_dict (that contains a batch dimension for each value).
                Keys and shapes: {
                    OBS: [batch_size, max_seq_len, obs_dim],
                    ACTIONS: [batch_size, max_seq_len - 1, act_dim],
                    RETURNS_TO_GO: [batch_size, max_seq_len - 1],
                    REWARDS: [batch_size],
                    TIMESTEPS: [batch_size, max_seq_len - 1],
                }
            explore: unused.
            timestep: unused.

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
        """
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = input_dict.copy()
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(True)

            actions, state_out, extra_fetches = self._compute_action_helper(input_dict)
            return actions, state_out, extra_fetches

    # TODO: figure out what this with_lock does and why it's only on the helper method.
    @with_lock
    @override(TorchPolicyV2)
    def _compute_action_helper(self, input_dict):
        # Switch to eval mode.
        self.model.eval()

        batch_size = input_dict[SampleBatch.OBS].shape[0]

        # NOTE: This is probably the most confusing part of the code, made to work with
        # env_runner and SimpleListCollector during evaluation, and thus should
        # be changed for the new Policy and Connector API.
        # So I'll explain how it works.

        # Add current timestep (+1 because -1 is first observation)
        # NOTE: ViewRequirement of timestep is -(max_seq_len-2):0.
        # The weird limits are because RLlib treats the initial obs as time -1,
        # and then 0 is (act, rew, next_obs), etc.
        # So we only collect max_seq_len-1 from the rollout and create the current
        # step here by adding 1.
        # Decision transformer treats the initial observation as timestep 0, giving us
        # 0 is (obs, act, rew).
        timesteps = input_dict[SampleBatch.T]
        new_timestep = timesteps[:, -1:] + 1
        input_dict[SampleBatch.T] = torch.cat([timesteps, new_timestep], dim=1)

        # Mask out any padded values at the start of the rollout.
        # NOTE: the other reason for doing this is that the evaluation rollout front
        # pads timesteps with -1, so using this we can find out when we need to mask
        # out the front section of the batch.
        input_dict[SampleBatch.ATTENTION_MASKS] = torch.where(
            input_dict[SampleBatch.T] >= 0, 1.0, 0.0
        )

        # Remove out-of-bound -1 timesteps after the attention mask is calculated.
        unclipped_timesteps = input_dict[SampleBatch.T]
        input_dict[SampleBatch.T] = torch.where(
            unclipped_timesteps < 0,
            torch.zeros_like(unclipped_timesteps),
            unclipped_timesteps,
        )

        # Compute returns-to-go.
        # NOTE: There are two rtg calculations: updated_rtg and initial_rtg.
        # updated_rtg takes the previous rtg value (the ViewRequirement is
        # -(max_seq_len-1):-1), and subtracts the last reward from it.
        rtg = input_dict[SampleBatch.RETURNS_TO_GO]
        last_rtg = rtg[:, -1]
        last_reward = input_dict[SampleBatch.REWARDS]
        updated_rtg = last_rtg - last_reward
        # initial_rtg is simply filled with target_return.
        # These two are both only for the current timestep.
        initial_rtg = self.get_initial_rtg_tensor(
            (batch_size, 1), dtype=rtg.dtype, device=rtg.device
        )

        # Then, based on whether we are currently at the first timestep or not,
        # we use either the initial_rtg or the updated_rtg.
        new_rtg = torch.where(new_timestep == 0, initial_rtg, updated_rtg[:, None])
        # Append the new_rtg to the batch.
        input_dict[SampleBatch.RETURNS_TO_GO] = torch.cat([rtg, new_rtg], dim=1)[
            ..., None
        ]

        # Pad the current action (it is not actually attended to or used during
        # inference).
        past_actions = input_dict[SampleBatch.ACTIONS]
        action_pad = torch.zeros(
            (batch_size, 1, *past_actions.shape[2:]),
            dtype=past_actions.dtype,
            device=past_actions.device,
        )
        input_dict[SampleBatch.ACTIONS] = torch.cat([past_actions, action_pad], dim=1)

        # Run inference on the model.
        model_out, _ = self.model(input_dict)  # noop, just returns obs.
        preds = self.model.get_prediction(model_out, input_dict)
        dist_inputs = preds[SampleBatch.ACTIONS][:, -1]

        # Get the actions from the action_dist.
        action_dist = self.dist_class(dist_inputs, self.model)
        actions = action_dist.deterministic_sample()

        # This is used by env_runner and is actually how it adds custom keys to
        # SimpleListCollector and allows ViewRequirements to work.
        # This is also used in user inference in get_next_input_dict, which takes
        # this output as one of its inputs.
        extra_fetches = {
            # new_rtg still has the leftover extra 3rd dimension for inference
            SampleBatch.RETURNS_TO_GO: new_rtg.squeeze(-1),
            SampleBatch.ACTION_DIST_INPUTS: dist_inputs,
        }

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_numpy((actions, [], extra_fetches))

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Loss function.

        Args:
            model: The ModelV2 to run the forward pass on.
            dist_class: The distribution class of this policy.
            train_batch: Training SampleBatch.
                Keys and shapes: {
                    OBS: [batch_size, max_seq_len, obs_dim],
                    ACTIONS: [batch_size, max_seq_len, act_dim],
                    RETURNS_TO_GO: [batch_size, max_seq_len + 1, 1],
                    TIMESTEPS: [batch_size, max_seq_len],
                    ATTENTION_MASKS: [batch_size, max_seq_len],
                }

        Returns:
            Loss scalar tensor.
        """
        train_batch = self._lazy_tensor_dict(train_batch)

        # policy forward and get prediction
        model_out, _ = self.model(train_batch)  # noop, just returns obs.
        preds = self.model.get_prediction(model_out, train_batch)

        # get the regression targets
        targets = self.model.get_targets(model_out, train_batch)

        # get the attention masks for masked-loss
        masks = train_batch[SampleBatch.ATTENTION_MASKS]

        # compute loss
        loss = self._masked_loss(preds, targets, masks)

        self.log("cur_lr", torch.tensor(self.cur_lr))

        return loss

    def _masked_loss(self, preds, targets, masks):
        losses = []
        for key in targets:
            assert (
                key in preds
            ), f"for target {key} there is no prediction from the output of the model"
            loss_coef = self.config.get(f"loss_coef_{key}", 1.0)
            if self._is_discrete(key):
                loss = loss_coef * self._masked_cross_entropy_loss(
                    preds[key], targets[key], masks
                )
            else:
                loss = loss_coef * self._masked_mse_loss(
                    preds[key], targets[key], masks
                )

            losses.append(loss)
            self.log(f"{key}_loss", loss)

        return sum(losses)

    def _is_discrete(self, key):
        return key == SampleBatch.ACTIONS and isinstance(self.action_space, Discrete)

    def _masked_cross_entropy_loss(
        self,
        preds: TensorType,
        targets: TensorType,
        masks: TensorType,
    ) -> TensorType:
        """Computes cross-entropy loss between preds and targets, subject to a mask.

        Args:
            preds: logits of shape [B1, ..., Bn, M]
            targets: index targets for preds of shape [B1, ..., Bn]
            masks: 0 means don't compute loss, 1 means compute loss
                shape [B1, ..., Bn]

        Returns:
            Scalar cross entropy loss.
        """
        losses = F.cross_entropy(
            preds.reshape(-1, preds.shape[-1]), targets.reshape(-1), reduction="none"
        )
        losses = losses * masks.reshape(-1)
        return losses.mean()

    def _masked_mse_loss(
        self,
        preds: TensorType,
        targets: TensorType,
        masks: TensorType,
    ) -> TensorType:
        """Computes MSE loss between preds and targets, subject to a mask.

        Args:
            preds: predictions of shape [B1, ..., Bn, M]
            targets: regression targets for preds of shape [B1, ..., Bn]
            masks: 0 means don't compute loss, 1 means compute loss
                shape [B1, ..., Bn]

        Returns:
            Scalar MSE loss.
        """
        losses = F.mse_loss(preds, targets, reduction="none")
        losses = losses * masks.reshape(
            *masks.shape, *([1] * (len(preds.shape) - len(masks.shape)))
        )
        return losses.mean()

    @override(TorchPolicyV2)
    def extra_grad_process(self, local_optimizer, loss):
        return apply_grad_clipping(self, local_optimizer, loss)

    def log(self, key, value):
        # internal log function
        self.model.tower_stats[key] = value

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        stats_dict = {
            k: torch.stack(self.get_tower_stats(k)).mean().item()
            for k in self.model.tower_stats
        }
        return stats_dict
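Usage note (not part of the committed file): the @PublicAPI helpers get_initial_input_dict and get_next_input_dict above are meant to drive manual, single-episode inference. The sketch below is illustrative only; it assumes policy is an already-built DTTorchPolicy (e.g. constructed as in the test file further down), env is a gym.Env whose spaces match the policy, and compute_single_action accepts an input_dict keyword as in the base Policy API.

# Minimal inference-loop sketch; `policy` and `env` are assumed to already exist.
obs = env.reset()
input_dict = policy.get_initial_input_dict(obs)
done = False
while not done:
    # `extra` carries the current returns-to-go under SampleBatch.RETURNS_TO_GO.
    action, _, extra = policy.compute_single_action(input_dict=input_dict)
    next_obs, reward, done, _ = env.step(action)
    # Shift the fixed-length context window forward by one transition.
    input_dict = policy.get_next_input_dict(input_dict, action, reward, next_obs, extra)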
rllib/algorithms/dt/tests/test_dt_policy.py (new file, 507 lines)

@@ -0,0 +1,507 @@
import unittest
from typing import Dict

import gym
import numpy as np

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.algorithms.dt.dt_torch_policy import DTTorchPolicy

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


def _default_config():
    """Base config to use."""
    return {
        "model": {
            "max_seq_len": 4,
        },
        "embed_dim": 32,
        "num_layers": 2,
        "horizon": 10,
        "num_heads": 2,
        "embed_pdrop": 0.1,
        "resid_pdrop": 0.1,
        "attn_pdrop": 0.1,
        "framework": "torch",
        "lr": 1e-3,
        "lr_schedule": None,
        "optimizer": {
            "weight_decay": 1e-4,
            "betas": [0.9, 0.99],
        },
        "target_return": 200.0,
        "loss_coef_actions": 1.0,
        "loss_coef_obs": 0,
        "loss_coef_returns_to_go": 0,
        "num_gpus": 0,
        "_fake_gpus": None,
    }


def _assert_input_dict_equals(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]):
    for key in d1.keys():
        assert key in d2.keys()

    for key in d2.keys():
        assert key in d1.keys()

    for key in d1.keys():
        assert isinstance(d1[key], np.ndarray), "input_dict should only contain numpy arrays."
        assert isinstance(d2[key], np.ndarray), "input_dict should only contain numpy arrays."
        assert d1[key].shape == d2[key].shape, "input_dicts are of different shape."
        assert np.allclose(d1[key], d2[key]), "input_dict values are not equal."


class TestDTPolicy(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_torch_postprocess_trajectory(self):
        """Test postprocess_trajectory"""
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
        action_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))

        # Create policy
        policy = DTTorchPolicy(observation_space, action_space, config)

        # Generate input_dict with some data
        sample_batch = SampleBatch(
            {
                SampleBatch.REWARDS: np.array([1.0, 2.0, 1.0, 1.0]),
                SampleBatch.EPS_ID: np.array([0, 0, 0, 0]),
            }
        )

        # Do postprocess trajectory to calculate rtg
        sample_batch = policy.postprocess_trajectory(sample_batch)

        # Assert that dones is correctly set
        assert SampleBatch.DONES in sample_batch, "`dones` isn't part of the batch."
        assert np.allclose(
            sample_batch[SampleBatch.DONES],
            np.array([False, False, False, True]),
        ), "`dones` isn't set correctly."

    def test_torch_input_dict(self):
        """Test inference input_dict methods.

        This is a minimal version of the test in test_dt.py.
        The shapes of the input_dict might be confusing, but they make sense in
        the context of what the function is supposed to do.
        Check action_distribution_fn for an explanation.
        """
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # initial obs and input_dict
            obs = np.array([0.0, 1.0, 2.0])
            input_dict = policy.get_initial_input_dict(obs)

            # Check input_dict matches what it should be
            target_input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 1.0, 2.0],
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[0.0], [0.0], [0.0]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([0, 0, 0], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [0.0, 0.0, 0.0], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.zeros((), dtype=np.float32),
                    SampleBatch.T: np.array([-1, -1, -1], dtype=np.int32),
                }
            )
            _assert_input_dict_equals(input_dict, target_input_dict)

            # Get next input_dict
            input_dict = policy.get_next_input_dict(
                input_dict,
                action=(
                    np.asarray([1.0], dtype=np.float32)
                    if isinstance(action_space, gym.spaces.Box)
                    else np.asarray(1, dtype=np.int32)
                ),
                reward=1.0,
                next_obs=np.array([3.0, 4.0, 5.0]),
                extra={
                    SampleBatch.RETURNS_TO_GO: config["target_return"],
                },
            )

            # Check input_dict matches what it should be
            target_input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 1.0, 2.0],
                            [3.0, 4.0, 5.0],
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[0.0], [0.0], [1.0]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([0, 0, 1], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [0.0, 0.0, config["target_return"]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.asarray(1.0, dtype=np.float32),
                    SampleBatch.T: np.array([-1, -1, 0], dtype=np.int32),
                }
            )
            _assert_input_dict_equals(input_dict, target_input_dict)

    def test_torch_action(self):
        """Test the policy's action_distribution_fn and extra_action_out methods by
        calling compute_actions_from_input_dict, which uses those two methods
        in conjunction.
        """
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # input_dict for initial observation
            input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [0.0]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 0]], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[0.0, 0.0, 0.0]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.array([0.0], dtype=np.float32),
                    SampleBatch.T: np.array([[-1, -1, -1]], dtype=np.int32),
                }
            )

            # Run compute_actions_from_input_dict
            actions, _, extras = policy.compute_actions_from_input_dict(
                input_dict,
                explore=False,
                timestep=None,
            )

            # Check actions
            assert actions.shape == (
                1,
                *action_space.shape,
            ), "actions has incorrect shape."

            # Check extras
            assert (
                SampleBatch.RETURNS_TO_GO in extras
            ), "extras should contain returns_to_go."
            assert extras[SampleBatch.RETURNS_TO_GO].shape == (
                1,
            ), "extras['returns_to_go'] has incorrect shape."
            assert np.isclose(
                extras[SampleBatch.RETURNS_TO_GO],
                np.asarray([config["target_return"]], dtype=np.float32),
            ), "extras['returns_to_go'] should contain target_return."

            # input_dict for non-initial observation
            input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1]], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[0.0, 0.0, config["target_return"]]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.array([10.0], dtype=np.float32),
                    SampleBatch.T: np.array([[-1, -1, 0]], dtype=np.int32),
                }
            )

            # Run compute_actions_from_input_dict
            actions, _, extras = policy.compute_actions_from_input_dict(
                input_dict,
                explore=False,
                timestep=None,
            )

            # Check actions
            assert actions.shape == (
                1,
                *action_space.shape,
            ), "actions has incorrect shape."

            # Check extras
            assert (
                SampleBatch.RETURNS_TO_GO in extras
            ), "extras should contain returns_to_go."
            assert extras[SampleBatch.RETURNS_TO_GO].shape == (
                1,
            ), "extras['returns_to_go'] has incorrect shape."
            assert np.isclose(
                extras[SampleBatch.RETURNS_TO_GO],
                np.asarray([config["target_return"] - 10.0], dtype=np.float32),
            ), "extras['returns_to_go'] should contain target_return minus the reward."

    def test_loss(self):
        """Test loss function."""
        config = _default_config()
        config["embed_pdrop"] = 0
        config["resid_pdrop"] = 0
        config["attn_pdrop"] = 0

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # Run loss functions on batches with different items in the mask to make
            # sure the masks are working and making the loss the same.
            batch1 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[0.0], [0.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[0, 0, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            batch2 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [1.0, 1.0, -1.0],
                                [1.0, 10.0, 12.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[1.0], [-0.5], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[2, 1, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[200.0], [-10.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[9, 3, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            loss1 = policy.loss(policy.model, policy.dist_class, batch1)
            loss2 = policy.loss(policy.model, policy.dist_class, batch2)

            loss1 = loss1.detach().cpu().item()
            loss2 = loss2.detach().cpu().item()

            assert np.isclose(loss1, loss2), "Masks are not working for losses."

            # Run loss on a widely different batch and make sure the loss is different.
            batch3 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [1.0, 1.0, -20.0],
                                [0.1, 10.0, 12.0],
                                [1.4, 12.0, -9.0],
                                [6.0, 40.0, -2.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[2.0], [-1.5], [0.2], [0.1]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[1, 3, 0, 2]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[90.0], [80.0], [70.0], [60.0], [50.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[3, 4, 5, 6]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[1.0, 1.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            loss3 = policy.loss(policy.model, policy.dist_class, batch3)
            loss3 = loss3.detach().cpu().item()

            assert not np.isclose(
                loss1, loss3
            ), "Widely different inputs are giving the same loss value."

    def test_loss_coef(self):
        """Test the loss_coef_{key} config options."""

        config = _default_config()
        config["embed_pdrop"] = 0
        config["resid_pdrop"] = 0
        config["attn_pdrop"] = 0
        # set initial action coef to 0
        config["loss_coef_actions"] = 0

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            batch = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[0.0], [0.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[0, 0, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            keys = [SampleBatch.ACTIONS, SampleBatch.OBS, SampleBatch.RETURNS_TO_GO]
            for key in keys:
                # create policies and run the loss with different coefs
                # create policy 1 with coef = 1
                config1 = config.copy()
                config1[f"loss_coef_{key}"] = 1.0
                policy1 = DTTorchPolicy(observation_space, action_space, config1)

                loss1 = policy1.loss(policy1.model, policy1.dist_class, batch)
                loss1 = loss1.detach().cpu().item()

                # create policy 2 with coef = 10
                config2 = config.copy()
                config2[f"loss_coef_{key}"] = 10.0
                policy2 = DTTorchPolicy(observation_space, action_space, config2)
                # copy the weights over so they output the same loss without scaling
                policy2.set_state(policy1.get_state())
                policy2.set_weights(policy1.get_weights())

                loss2 = policy2.loss(policy2.model, policy2.dist_class, batch)
                loss2 = loss2.detach().cpu().item()

                # compare losses; they should differ by a factor of 10
                self.assertAlmostEqual(
                    loss2 / loss1,
                    10.0,
                    places=3,
                    msg="the two losses should differ by a factor of 10.",
                )


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
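Note (not part of the diff): the test module registers itself with pytest via the __main__ block above; assuming a source checkout of Ray, it can presumably also be run directly with pytest -v rllib/algorithms/dt/tests/test_dt_policy.py, in addition to the Bazel target added in the BUILD hunk at the top of this commit.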