diff --git a/rllib/BUILD b/rllib/BUILD index 10ec2625e..249b2c98e 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -921,6 +921,13 @@ py_test( srcs = ["algorithms/dt/tests/test_dt_model.py"] ) +py_test( + name = "test_dt_policy", + tags = ["team:rllib", "algorithms_dir"], + size = "medium", + srcs = ["algorithms/dt/tests/test_dt_policy.py"] +) + # ES py_test( name = "test_es", diff --git a/rllib/algorithms/dt/dt_torch_policy.py b/rllib/algorithms/dt/dt_torch_policy.py new file mode 100644 index 000000000..01bd6a4cd --- /dev/null +++ b/rllib/algorithms/dt/dt_torch_policy.py @@ -0,0 +1,581 @@ +import gym +import numpy as np + +from typing import ( + Dict, + List, + Tuple, + Type, + Union, + Optional, + Any, + TYPE_CHECKING, +) + +import tree +from gym.spaces import Discrete, Box + +from ray.rllib.algorithms.dt.dt_torch_model import DTTorchModel +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.mingpt import configure_gpt_optimizer +from ray.rllib.models.torch.torch_action_dist import ( + TorchDistributionWrapper, + TorchCategorical, + TorchDeterministic, +) +from ray.rllib.policy.torch_mixins import LearningRateSchedule +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.threading import with_lock +from ray.rllib.utils.torch_utils import apply_grad_clipping +from ray.rllib.utils.typing import ( + TrainerConfigDict, + TensorType, + TensorStructType, + TensorShape, +) + +if TYPE_CHECKING: + from ray.rllib.evaluation import Episode # noqa + +torch, nn = try_import_torch() +F = nn.functional + + +class DTTorchPolicy(LearningRateSchedule, TorchPolicyV2): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, + ): + LearningRateSchedule.__init__( + self, + config["lr"], + config["lr_schedule"], + ) + + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], + ) + + @override(TorchPolicyV2) + def make_model_and_action_dist( + self, + ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]: + # Model + model_config = self.config["model"] + # TODO: make these better with better AlgorithmConfig options. 
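+        # For reference, the keys read below are assumed to come from the top-level
+        # policy config; an illustrative set of values (taken from the defaults used
+        # in tests/test_dt_policy.py) would be:
+        #   {"embed_dim": 32, "num_layers": 2, "num_heads": 2, "horizon": 10,
+        #    "embed_pdrop": 0.1, "resid_pdrop": 0.1, "attn_pdrop": 0.1,
+        #    "loss_coef_obs": 0, "loss_coef_returns_to_go": 0, ...}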
+ model_config.update( + embed_dim=self.config["embed_dim"], + max_ep_len=self.config["horizon"], + num_layers=self.config["num_layers"], + num_heads=self.config["num_heads"], + embed_pdrop=self.config["embed_pdrop"], + resid_pdrop=self.config["resid_pdrop"], + attn_pdrop=self.config["attn_pdrop"], + use_obs_output=self.config.get("loss_coef_obs", 0) > 0, + use_return_output=self.config.get("loss_coef_returns_to_go", 0) > 0, + ) + + num_outputs = int(np.product(self.observation_space.shape)) + + model = ModelCatalog.get_model_v2( + obs_space=self.observation_space, + action_space=self.action_space, + num_outputs=num_outputs, + model_config=model_config, + framework=self.config["framework"], + model_interface=None, + default_model=DTTorchModel, + name="model", + ) + + # Action Distribution + if isinstance(self.action_space, Discrete): + action_dist = TorchCategorical + elif isinstance(self.action_space, Box): + action_dist = TorchDeterministic + else: + raise NotImplementedError + + return model, action_dist + + @override(TorchPolicyV2) + def optimizer( + self, + ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]: + optimizer = configure_gpt_optimizer( + model=self.model, + learning_rate=self.config["lr"], + weight_decay=self.config["optimizer"]["weight_decay"], + betas=self.config["optimizer"]["betas"], + ) + + return optimizer + + @override(TorchPolicyV2) + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[Any, SampleBatch]] = None, + episode: Optional["Episode"] = None, + ) -> SampleBatch: + """Called by offline data reader after loading in one episode. + Adds a done flag at the end of trajectory so that SegmentationBuffer can + split using the done flag to avoid duplicate trajectories. + """ + ep_len = sample_batch.env_steps() + sample_batch[SampleBatch.DONES] = np.array([False] * (ep_len - 1) + [True]) + return sample_batch + + @PublicAPI + def get_initial_input_dict(self, observation: TensorStructType) -> SampleBatch: + """Get the initial input_dict to be passed into compute_single_action. + + Args: + observation: first (unbatched) observation from env.reset() + + Returns: + The input_dict for inference: { + OBS: [max_seq_len, obs_dim] array, + ACTIONS: [max_seq_len - 1, act_dim] array, + RETURNS_TO_GO: [max_seq_len - 1] array, + REWARDS: scalar, + TIMESTEPS: [max_seq_len - 1] array, + } + Note the sequence lengths are different, and is specified as per + view_requirements. Explanations in action_distribution_fn method. + """ + observation = convert_to_numpy(observation) + obs_shape = observation.shape + obs_dtype = observation.dtype + + act_shape = self.action_space.shape + act_dtype = self.action_space.dtype + + # Here we will pad all the required inputs to its proper sequence length + # as their ViewRequirement. 
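+        # Illustrative example (this mirrors the expected values checked in
+        # tests/test_dt_policy.py): with max_seq_len=4, obs_dim=3, and a first
+        # observation [0., 1., 2.], the returned input_dict looks like
+        #   OBS:            [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 1, 2]]  # (4, 3)
+        #   ACTIONS:        zeros of shape (3, act_dim)
+        #   RETURNS_TO_GO:  [0, 0, 0]
+        #   REWARDS:        0.0  # scalar
+        #   T:              [-1, -1, -1]  # -1 marks padding for the attention mask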
+ + observations = np.concatenate( + [ + np.zeros((self.max_seq_len - 1, *obs_shape), dtype=obs_dtype), + observation[None], + ], + axis=0, + ) + + actions = np.zeros((self.max_seq_len - 1, *act_shape), dtype=act_dtype) + + rtg = np.zeros(self.max_seq_len - 1, dtype=np.float32) + + rewards = np.zeros((), dtype=np.float32) + + # -1 for masking in action_distribution_fn + timesteps = np.full(self.max_seq_len - 1, fill_value=-1, dtype=np.int32) + + input_dict = SampleBatch( + { + SampleBatch.OBS: observations, + SampleBatch.ACTIONS: actions, + SampleBatch.RETURNS_TO_GO: rtg, + SampleBatch.REWARDS: rewards, + SampleBatch.T: timesteps, + } + ) + return input_dict + + @PublicAPI + def get_next_input_dict( + self, + input_dict: SampleBatch, + action: TensorStructType, + reward: TensorStructType, + next_obs: TensorStructType, + extra: Dict[str, TensorType], + ) -> SampleBatch: + """Returns a new input_dict after stepping through the environment once. + + Args: + input_dict: the input dict passed into compute_single_action. + action: the (unbatched) action taken this step. + reward: the (unbatched) reward from env.step + next_obs: the (unbatached) next observation from env.step + extra: the extra action out from compute_single_action. + In this case contains current returns to go *before* the current + reward is subtracted from target_return. + + Returns: + A new input_dict to be passed into compute_single_action. + The input_dict for inference: { + OBS: [max_seq_len, obs_dim] array, + ACTIONS: [max_seq_len - 1, act_dim] array, + RETURNS_TO_GO: [max_seq_len - 1] array, + REWARDS: scalar, + TIMESTEPS: [max_seq_len - 1] array, + } + Note the sequence lengths are different, and is specified as per + view_requirements. Explanations in action_distribution_fn method. + """ + # creates a copy of input_dict with only numpy arrays + input_dict = tree.map_structure(convert_to_numpy, input_dict) + # convert everything else to numpy as well + action, reward, next_obs, extra = convert_to_numpy( + (action, reward, next_obs, extra) + ) + + # check dimensions + assert input_dict[SampleBatch.OBS].shape == ( + self.max_seq_len, + *self.observation_space.shape, + ) + assert input_dict[SampleBatch.ACTIONS].shape == ( + self.max_seq_len - 1, + *self.action_space.shape, + ) + assert input_dict[SampleBatch.RETURNS_TO_GO].shape == (self.max_seq_len - 1,) + assert input_dict[SampleBatch.T].shape == (self.max_seq_len - 1,) + + # Shift observations + input_dict[SampleBatch.OBS] = np.concatenate( + [ + input_dict[SampleBatch.OBS][1:], + next_obs[None], + ], + axis=0, + ) + + # Shift actions + input_dict[SampleBatch.ACTIONS] = np.concatenate( + [ + input_dict[SampleBatch.ACTIONS][1:], + action[None], + ], + axis=0, + ) + + # Reward is not a sequence, it's only used to calculate next rtg. + input_dict[SampleBatch.REWARDS] = np.asarray(reward) + + # See methods action_distribution_fn and extra_action_out for an explanation + # of why this is done. 
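+        # Illustrative example (mirroring test_torch_action in
+        # tests/test_dt_policy.py): with target_return=200, the first
+        # compute_single_action call returns extra[RETURNS_TO_GO]=200; after a
+        # reward of 10, the next call internally uses 200 - 10 = 190 as the
+        # current return-to-go.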
+        input_dict[SampleBatch.RETURNS_TO_GO] = np.concatenate(
+            [
+                input_dict[SampleBatch.RETURNS_TO_GO][1:],
+                np.asarray(extra[SampleBatch.RETURNS_TO_GO])[None],
+            ],
+            axis=0,
+        )
+
+        # Shift and increment timesteps
+        input_dict[SampleBatch.T] = np.concatenate(
+            [
+                input_dict[SampleBatch.T][1:],
+                input_dict[SampleBatch.T][-1:] + 1,
+            ],
+            axis=0,
+        )
+
+        return input_dict
+
+    @DeveloperAPI
+    def get_initial_rtg_tensor(
+        self,
+        shape: TensorShape,
+        dtype: Optional[Type] = torch.float32,
+        device: Optional["torch.device"] = None,
+    ):
+        """Returns an initial/target returns-to-go tensor of the given shape.
+
+        Args:
+            shape: Shape of the rtg tensor.
+            dtype: Type of the data in the tensor. Defaults to torch.float32.
+            device: The device this tensor should be on. Defaults to self.device.
+        """
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = torch.float32
+
+        assert self.config["target_return"] is not None, "Must specify target_return."
+        initial_rtg = torch.full(
+            shape,
+            fill_value=self.config["target_return"],
+            dtype=dtype,
+            device=device,
+        )
+        return initial_rtg
+
+    @override(TorchPolicyV2)
+    @DeveloperAPI
+    def compute_actions(
+        self,
+        *args,
+        **kwargs,
+    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
+        raise ValueError("Please use compute_actions_from_input_dict instead.")
+
+    @override(TorchPolicyV2)
+    def compute_actions_from_input_dict(
+        self,
+        input_dict: SampleBatch,
+        explore: bool = None,
+        timestep: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
+        """
+        Args:
+            input_dict: input_dict (that contains a batch dimension for each value).
+                Keys and shapes: {
+                    OBS: [batch_size, max_seq_len, obs_dim],
+                    ACTIONS: [batch_size, max_seq_len - 1, act_dim],
+                    RETURNS_TO_GO: [batch_size, max_seq_len - 1],
+                    REWARDS: [batch_size],
+                    TIMESTEPS: [batch_size, max_seq_len - 1],
+                }
+            explore: unused.
+            timestep: unused.
+        Returns:
+            A tuple consisting of a) actions, b) state_out, c) extra_fetches.
+        """
+        with torch.no_grad():
+            # Pass lazy (torch) tensor dict to Model as `input_dict`.
+            input_dict = input_dict.copy()
+            input_dict = self._lazy_tensor_dict(input_dict)
+            input_dict.set_training(True)
+
+            actions, state_out, extra_fetches = self._compute_action_helper(input_dict)
+        return actions, state_out, extra_fetches
+
+    # TODO: figure out what this with_lock does and why it's only on the helper method.
+    @with_lock
+    @override(TorchPolicyV2)
+    def _compute_action_helper(self, input_dict):
+        # Switch to eval mode.
+        self.model.eval()
+
+        batch_size = input_dict[SampleBatch.OBS].shape[0]
+
+        # NOTE: This is probably the most confusing part of the code, made to work
+        # with env_runner and SimpleListCollector during evaluation, and thus should
+        # be changed for the new Policy and Connector API.
+        # So I'll explain how it works.
+
+        # Add current timestep (+1 because -1 is the first observation).
+        # NOTE: The ViewRequirement of timestep is -(max_seq_len-2):0.
+        # The weird limits are because RLlib treats the initial obs as time -1,
+        # and then 0 is (act, rew, next_obs), etc.
+        # So we only collect max_seq_len-1 steps from the rollout and create the
+        # current step here by adding 1.
+        # The decision transformer instead treats the initial observation as
+        # timestep 0, i.e. timestep 0 is (obs, act, rew).
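+        # Concrete example with max_seq_len=4, right after env.reset(): T comes in
+        # as [-1, -1, -1]; appending -1 + 1 = 0 gives [-1, -1, -1, 0], the attention
+        # mask below becomes [0, 0, 0, 1], and the padded -1 entries are then
+        # clipped to 0 so they are valid timestep indices.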
+        timesteps = input_dict[SampleBatch.T]
+        new_timestep = timesteps[:, -1:] + 1
+        input_dict[SampleBatch.T] = torch.cat([timesteps, new_timestep], dim=1)
+
+        # Mask out any padded values at the start of the rollout.
+        # NOTE: Another reason for doing this is that the evaluation rollout
+        # front-pads timesteps with -1, so we can use this to find out where we
+        # need to mask out the front section of the batch.
+        input_dict[SampleBatch.ATTENTION_MASKS] = torch.where(
+            input_dict[SampleBatch.T] >= 0, 1.0, 0.0
+        )
+
+        # Remove out-of-bound -1 timesteps after the attention mask is calculated.
+        unclipped_timesteps = input_dict[SampleBatch.T]
+        input_dict[SampleBatch.T] = torch.where(
+            unclipped_timesteps < 0,
+            torch.zeros_like(unclipped_timesteps),
+            unclipped_timesteps,
+        )
+
+        # Compute returns-to-go.
+        # NOTE: There are two rtg calculations: updated_rtg and initial_rtg.
+        # updated_rtg takes the previous rtg value (the ViewRequirement is
+        # -(max_seq_len-1):-1), and subtracts the last reward from it.
+        rtg = input_dict[SampleBatch.RETURNS_TO_GO]
+        last_rtg = rtg[:, -1]
+        last_reward = input_dict[SampleBatch.REWARDS]
+        updated_rtg = last_rtg - last_reward
+        # initial_rtg is simply filled with target_return.
+        # These two are both only for the current timestep.
+        initial_rtg = self.get_initial_rtg_tensor(
+            (batch_size, 1), dtype=rtg.dtype, device=rtg.device
+        )
+
+        # Then, based on whether we are currently at the first timestep or not,
+        # we use either initial_rtg or updated_rtg.
+        new_rtg = torch.where(new_timestep == 0, initial_rtg, updated_rtg[:, None])
+        # Append the new_rtg to the batch.
+        input_dict[SampleBatch.RETURNS_TO_GO] = torch.cat([rtg, new_rtg], dim=1)[
+            ..., None
+        ]
+
+        # Pad the current action (it is not actually attended to or used during
+        # inference).
+        past_actions = input_dict[SampleBatch.ACTIONS]
+        action_pad = torch.zeros(
+            (batch_size, 1, *past_actions.shape[2:]),
+            dtype=past_actions.dtype,
+            device=past_actions.device,
+        )
+        input_dict[SampleBatch.ACTIONS] = torch.cat([past_actions, action_pad], dim=1)
+
+        # Run inference on the model.
+        model_out, _ = self.model(input_dict)  # noop, just returns obs.
+        preds = self.model.get_prediction(model_out, input_dict)
+        dist_inputs = preds[SampleBatch.ACTIONS][:, -1]
+
+        # Get the actions from the action_dist.
+        action_dist = self.dist_class(dist_inputs, self.model)
+        actions = action_dist.deterministic_sample()
+
+        # This is used by env_runner and is actually how it adds custom keys to
+        # SimpleListCollector and allows ViewRequirements to work.
+        # This is also used in user inference in get_next_input_dict, which takes
+        # this output as one of its inputs.
+        extra_fetches = {
+            # new_rtg still has the leftover extra 3rd dimension for inference
+            SampleBatch.RETURNS_TO_GO: new_rtg.squeeze(-1),
+            SampleBatch.ACTION_DIST_INPUTS: dist_inputs,
+        }
+
+        # Update our global timestep by the batch size.
+        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])
+
+        return convert_to_numpy((actions, [], extra_fetches))
+
+    @override(TorchPolicyV2)
+    def loss(
+        self,
+        model: ModelV2,
+        dist_class: Type[TorchDistributionWrapper],
+        train_batch: SampleBatch,
+    ) -> Union[TensorType, List[TensorType]]:
+        """Loss function.
+
+        Args:
+            model: The ModelV2 to run the forward pass on.
+            dist_class: The action distribution class of this policy.
+            train_batch: Training SampleBatch.
+                Keys and shapes: {
+                    OBS: [batch_size, max_seq_len, obs_dim],
+                    ACTIONS: [batch_size, max_seq_len, act_dim],
+                    RETURNS_TO_GO: [batch_size, max_seq_len + 1, 1],
+                    TIMESTEPS: [batch_size, max_seq_len],
+                    ATTENTION_MASKS: [batch_size, max_seq_len],
+                }
+        Returns:
+            Loss scalar tensor.
+        """
+        train_batch = self._lazy_tensor_dict(train_batch)
+
+        # policy forward and get prediction
+        model_out, _ = self.model(train_batch)  # noop, just returns obs.
+        preds = self.model.get_prediction(model_out, train_batch)
+
+        # get the regression targets
+        targets = self.model.get_targets(model_out, train_batch)
+
+        # get the attention masks for masked-loss
+        masks = train_batch[SampleBatch.ATTENTION_MASKS]
+
+        # compute loss
+        loss = self._masked_loss(preds, targets, masks)
+
+        self.log("cur_lr", torch.tensor(self.cur_lr))
+
+        return loss
+
+    def _masked_loss(self, preds, targets, masks):
+        losses = []
+        for key in targets:
+            assert (
+                key in preds
+            ), f"for target {key} there is no prediction from the output of the model"
+            loss_coef = self.config.get(f"loss_coef_{key}", 1.0)
+            if self._is_discrete(key):
+                loss = loss_coef * self._masked_cross_entropy_loss(
+                    preds[key], targets[key], masks
+                )
+            else:
+                loss = loss_coef * self._masked_mse_loss(
+                    preds[key], targets[key], masks
+                )
+
+            losses.append(loss)
+            self.log(f"{key}_loss", loss)
+
+        return sum(losses)
+
+    def _is_discrete(self, key):
+        return key == SampleBatch.ACTIONS and isinstance(self.action_space, Discrete)
+
+    def _masked_cross_entropy_loss(
+        self,
+        preds: TensorType,
+        targets: TensorType,
+        masks: TensorType,
+    ) -> TensorType:
+        """Computes cross-entropy loss between preds and targets, subject to a mask.
+
+        Args:
+            preds: logits of shape [B1, ..., Bn, M]
+            targets: index targets for preds of shape [B1, ..., Bn]
+            masks: 0 means don't compute loss, 1 means compute loss
+                shape [B1, ..., Bn]
+
+        Returns:
+            Scalar cross-entropy loss.
+        """
+        losses = F.cross_entropy(
+            preds.reshape(-1, preds.shape[-1]), targets.reshape(-1), reduction="none"
+        )
+        losses = losses * masks.reshape(-1)
+        return losses.mean()
+
+    def _masked_mse_loss(
+        self,
+        preds: TensorType,
+        targets: TensorType,
+        masks: TensorType,
+    ) -> TensorType:
+        """Computes MSE loss between preds and targets, subject to a mask.
+
+        Args:
+            preds: predictions of shape [B1, ..., Bn, M]
+            targets: regression targets for preds of shape [B1, ..., Bn, M]
+            masks: 0 means don't compute loss, 1 means compute loss
+                shape [B1, ..., Bn]
+
+        Returns:
+            Scalar MSE loss.
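+
+        Example: for preds/targets of shape [B, T, M] and masks of shape [B, T],
+        masks is reshaped to [B, T, 1] below so that it broadcasts over the last
+        (feature) dimension.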
+ """ + losses = F.mse_loss(preds, targets, reduction="none") + losses = losses * masks.reshape( + *masks.shape, *([1] * (len(preds.shape) - len(masks.shape))) + ) + return losses.mean() + + @override(TorchPolicyV2) + def extra_grad_process(self, local_optimizer, loss): + return apply_grad_clipping(self, local_optimizer, loss) + + def log(self, key, value): + # internal log function + self.model.tower_stats[key] = value + + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + stats_dict = { + k: torch.stack(self.get_tower_stats(k)).mean().item() + for k in self.model.tower_stats + } + return stats_dict diff --git a/rllib/algorithms/dt/tests/test_dt_policy.py b/rllib/algorithms/dt/tests/test_dt_policy.py new file mode 100644 index 000000000..2e1048336 --- /dev/null +++ b/rllib/algorithms/dt/tests/test_dt_policy.py @@ -0,0 +1,507 @@ +import unittest +from typing import Dict + +import gym +import numpy as np + +import ray +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.algorithms.dt.dt_torch_policy import DTTorchPolicy + +tf1, tf, tfv = try_import_tf() +torch, nn = try_import_torch() + + +def _default_config(): + """Base config to use.""" + return { + "model": { + "max_seq_len": 4, + }, + "embed_dim": 32, + "num_layers": 2, + "horizon": 10, + "num_heads": 2, + "embed_pdrop": 0.1, + "resid_pdrop": 0.1, + "attn_pdrop": 0.1, + "framework": "torch", + "lr": 1e-3, + "lr_schedule": None, + "optimizer": { + "weight_decay": 1e-4, + "betas": [0.9, 0.99], + }, + "target_return": 200.0, + "loss_coef_actions": 1.0, + "loss_coef_obs": 0, + "loss_coef_returns_to_go": 0, + "num_gpus": 0, + "_fake_gpus": None, + } + + +def _assert_input_dict_equals(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]): + for key in d1.keys(): + assert key in d2.keys() + + for key in d2.keys(): + assert key in d1.keys() + + for key in d1.keys(): + assert isinstance(d1[key], np.ndarray), "input_dict should only be numpy array." + assert isinstance(d2[key], np.ndarray), "input_dict should only be numpy array." + assert d1[key].shape == d2[key].shape, "input_dict are of different shape." + assert np.allclose(d1[key], d2[key]), "input_dict values are not equal." + + +class TestDTPolicy(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init() + + @classmethod + def tearDownClass(cls): + ray.shutdown() + + def test_torch_postprocess_trajectory(self): + """Test postprocess_trajectory""" + config = _default_config() + + observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,)) + action_space = gym.spaces.Box(-1.0, 1.0, shape=(3,)) + + # Create policy + policy = DTTorchPolicy(observation_space, action_space, config) + + # Generate input_dict with some data + sample_batch = SampleBatch( + { + SampleBatch.REWARDS: np.array([1.0, 2.0, 1.0, 1.0]), + SampleBatch.EPS_ID: np.array([0, 0, 0, 0]), + } + ) + + # Do postprocess trajectory to calculate rtg + sample_batch = policy.postprocess_trajectory(sample_batch) + + # Assert that dones is correctly set + assert SampleBatch.DONES in sample_batch, "`dones` isn't part of the batch." + assert np.allclose( + sample_batch[SampleBatch.DONES], + np.array([False, False, False, True]), + ), "`dones` isn't set correctly." + + def test_torch_input_dict(self): + """Test inference input_dict methods + + This is a minimal version the test in test_dt.py. 
+ The shapes of the input_dict might be confusing but it makes sense in + context of what the function is supposed to do. + Check action_distribution_fn for an explanation. + """ + config = _default_config() + + observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,)) + action_spaces = [ + gym.spaces.Box(-1.0, 1.0, shape=(1,)), + gym.spaces.Discrete(4), + ] + + for action_space in action_spaces: + # Create policy + policy = DTTorchPolicy(observation_space, action_space, config) + + # initial obs and input_dict + obs = np.array([0.0, 1.0, 2.0]) + input_dict = policy.get_initial_input_dict(obs) + + # Check input_dict matches what it should be + target_input_dict = SampleBatch( + { + SampleBatch.OBS: np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + ], + dtype=np.float32, + ), + SampleBatch.ACTIONS: ( + np.array([[0.0], [0.0], [0.0]], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.array([0, 0, 0], dtype=np.int32) + ), + SampleBatch.RETURNS_TO_GO: np.array( + [0.0, 0.0, 0.0], dtype=np.float32 + ), + SampleBatch.REWARDS: np.zeros((), dtype=np.float32), + SampleBatch.T: np.array([-1, -1, -1], dtype=np.int32), + } + ) + _assert_input_dict_equals(input_dict, target_input_dict) + + # Get next input_dict + input_dict = policy.get_next_input_dict( + input_dict, + action=( + np.asarray([1.0], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.asarray(1, dtype=np.int32) + ), + reward=1.0, + next_obs=np.array([3.0, 4.0, 5.0]), + extra={ + SampleBatch.RETURNS_TO_GO: config["target_return"], + }, + ) + + # Check input_dict matches what it should be + target_input_dict = SampleBatch( + { + SampleBatch.OBS: np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + ], + dtype=np.float32, + ), + SampleBatch.ACTIONS: ( + np.array([[0.0], [0.0], [1.0]], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.array([0, 0, 1], dtype=np.int32) + ), + SampleBatch.RETURNS_TO_GO: np.array( + [0.0, 0.0, config["target_return"]], dtype=np.float32 + ), + SampleBatch.REWARDS: np.asarray(1.0, dtype=np.float32), + SampleBatch.T: np.array([-1, -1, 0], dtype=np.int32), + } + ) + _assert_input_dict_equals(input_dict, target_input_dict) + + def test_torch_action(self): + """Test policy's action_distribution_fn and extra_action_out methods by + calling compute_actions_from_input_dict which works those two methods + in conjunction. 
+ """ + config = _default_config() + + observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,)) + action_spaces = [ + gym.spaces.Box(-1.0, 1.0, shape=(1,)), + gym.spaces.Discrete(4), + ] + + for action_space in action_spaces: + # Create policy + policy = DTTorchPolicy(observation_space, action_space, config) + + # input_dict for initial observation + input_dict = SampleBatch( + { + SampleBatch.OBS: np.array( + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + ] + ], + dtype=np.float32, + ), + SampleBatch.ACTIONS: ( + np.array([[[0.0], [0.0], [0.0]]], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.array([[0, 0, 0]], dtype=np.int32) + ), + SampleBatch.RETURNS_TO_GO: np.array( + [[0.0, 0.0, 0.0]], dtype=np.float32 + ), + SampleBatch.REWARDS: np.array([0.0], dtype=np.float32), + SampleBatch.T: np.array([[-1, -1, -1]], dtype=np.int32), + } + ) + + # Run compute_actions_from_input_dict + actions, _, extras = policy.compute_actions_from_input_dict( + input_dict, + explore=False, + timestep=None, + ) + + # Check actions + assert actions.shape == ( + 1, + *action_space.shape, + ), "actions has incorrect shape." + + # Check extras + assert ( + SampleBatch.RETURNS_TO_GO in extras + ), "extras should contain returns_to_go." + assert extras[SampleBatch.RETURNS_TO_GO].shape == ( + 1, + ), "extras['returns_to_go'] has incorrect shape." + assert np.isclose( + extras[SampleBatch.RETURNS_TO_GO], + np.asarray([config["target_return"]], dtype=np.float32), + ), "extras['returns_to_go'] should contain target_return." + + # input_dict for non-initial observation + input_dict = SampleBatch( + { + SampleBatch.OBS: np.array( + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + ] + ], + dtype=np.float32, + ), + SampleBatch.ACTIONS: ( + np.array([[[0.0], [0.0], [1.0]]], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.array([[0, 0, 1]], dtype=np.int32) + ), + SampleBatch.RETURNS_TO_GO: np.array( + [[0.0, 0.0, config["target_return"]]], dtype=np.float32 + ), + SampleBatch.REWARDS: np.array([10.0], dtype=np.float32), + SampleBatch.T: np.array([[-1, -1, 0]], dtype=np.int32), + } + ) + + # Run compute_actions_from_input_dict + actions, _, extras = policy.compute_actions_from_input_dict( + input_dict, + explore=False, + timestep=None, + ) + + # Check actions + assert actions.shape == ( + 1, + *action_space.shape, + ), "actions has incorrect shape." + + # Check extras + assert ( + SampleBatch.RETURNS_TO_GO in extras + ), "extras should contain returns_to_go." + assert extras[SampleBatch.RETURNS_TO_GO].shape == ( + 1, + ), "extras['returns_to_go'] has incorrect shape." + assert np.isclose( + extras[SampleBatch.RETURNS_TO_GO], + np.asarray([config["target_return"] - 10.0], dtype=np.float32), + ), "extras['returns_to_go'] should contain target_return." + + def test_loss(self): + """Test loss function.""" + config = _default_config() + config["embed_pdrop"] = 0 + config["resid_pdrop"] = 0 + config["attn_pdrop"] = 0 + + observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,)) + action_spaces = [ + gym.spaces.Box(-1.0, 1.0, shape=(1,)), + gym.spaces.Discrete(4), + ] + + for action_space in action_spaces: + # Create policy + policy = DTTorchPolicy(observation_space, action_space, config) + + # Run loss functions on batches with different items in the mask to make + # sure the masks are working and making the loss the same. 
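+            # Concretely: batch1 and batch2 below differ only at the first two
+            # timesteps, which ATTENTION_MASKS marks as 0.0, so the masked loss
+            # should come out identical for both.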
+ batch1 = SampleBatch( + { + SampleBatch.OBS: np.array( + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + ] + ], + dtype=np.float32, + ), + SampleBatch.ACTIONS: ( + np.array([[[0.0], [0.0], [1.0], [0.5]]], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.array([[0, 0, 1, 3]], dtype=np.int64) + ), + SampleBatch.RETURNS_TO_GO: np.array( + [[[0.0], [0.0], [100.0], [90.0], [80.0]]], dtype=np.float32 + ), + SampleBatch.T: np.array([[0, 0, 0, 1]], dtype=np.int32), + SampleBatch.ATTENTION_MASKS: np.array( + [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32 + ), + } + ) + + batch2 = SampleBatch( + { + SampleBatch.OBS: np.array( + [ + [ + [1.0, 1.0, -1.0], + [1.0, 10.0, 12.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + ] + ], + dtype=np.float32, + ), + SampleBatch.ACTIONS: ( + np.array([[[1.0], [-0.5], [1.0], [0.5]]], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.array([[2, 1, 1, 3]], dtype=np.int64) + ), + SampleBatch.RETURNS_TO_GO: np.array( + [[[200.0], [-10.0], [100.0], [90.0], [80.0]]], dtype=np.float32 + ), + SampleBatch.T: np.array([[9, 3, 0, 1]], dtype=np.int32), + SampleBatch.ATTENTION_MASKS: np.array( + [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32 + ), + } + ) + + loss1 = policy.loss(policy.model, policy.dist_class, batch1) + loss2 = policy.loss(policy.model, policy.dist_class, batch2) + + loss1 = loss1.detach().cpu().item() + loss2 = loss2.detach().cpu().item() + + assert np.isclose(loss1, loss2), "Masks are not working for losses." + + # Run loss on a widely different batch and make sure the loss is different. + batch3 = SampleBatch( + { + SampleBatch.OBS: np.array( + [ + [ + [1.0, 1.0, -20.0], + [0.1, 10.0, 12.0], + [1.4, 12.0, -9.0], + [6.0, 40.0, -2.0], + ] + ], + dtype=np.float32, + ), + SampleBatch.ACTIONS: ( + np.array([[[2.0], [-1.5], [0.2], [0.1]]], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.array([[1, 3, 0, 2]], dtype=np.int64) + ), + SampleBatch.RETURNS_TO_GO: np.array( + [[[90.0], [80.0], [70.0], [60.0], [50.0]]], dtype=np.float32 + ), + SampleBatch.T: np.array([[3, 4, 5, 6]], dtype=np.int32), + SampleBatch.ATTENTION_MASKS: np.array( + [[1.0, 1.0, 1.0, 1.0]], dtype=np.float32 + ), + } + ) + + loss3 = policy.loss(policy.model, policy.dist_class, batch3) + loss3 = loss3.detach().cpu().item() + + assert not np.isclose( + loss1, loss3 + ), "Widely different inputs are giving the same loss value." 
+ + def test_loss_coef(self): + """Test the loss_coef_{key} config options.""" + + config = _default_config() + config["embed_pdrop"] = 0 + config["resid_pdrop"] = 0 + config["attn_pdrop"] = 0 + # set initial action coef to 0 + config["loss_coef_actions"] = 0 + + observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,)) + action_spaces = [ + gym.spaces.Box(-1.0, 1.0, shape=(1,)), + gym.spaces.Discrete(4), + ] + + for action_space in action_spaces: + batch = SampleBatch( + { + SampleBatch.OBS: np.array( + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + ] + ], + dtype=np.float32, + ), + SampleBatch.ACTIONS: ( + np.array([[[0.0], [0.0], [1.0], [0.5]]], dtype=np.float32) + if isinstance(action_space, gym.spaces.Box) + else np.array([[0, 0, 1, 3]], dtype=np.int64) + ), + SampleBatch.RETURNS_TO_GO: np.array( + [[[0.0], [0.0], [100.0], [90.0], [80.0]]], dtype=np.float32 + ), + SampleBatch.T: np.array([[0, 0, 0, 1]], dtype=np.int32), + SampleBatch.ATTENTION_MASKS: np.array( + [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32 + ), + } + ) + + keys = [SampleBatch.ACTIONS, SampleBatch.OBS, SampleBatch.RETURNS_TO_GO] + for key in keys: + # create policy and run loss with different coefs + # create policy 1 with coef = 1 + config1 = config.copy() + config1[f"loss_coef_{key}"] = 1.0 + policy1 = DTTorchPolicy(observation_space, action_space, config1) + + loss1 = policy1.loss(policy1.model, policy1.dist_class, batch) + loss1 = loss1.detach().cpu().item() + + # create policy 2 with coef = 10 + config2 = config.copy() + config2[f"loss_coef_{key}"] = 10.0 + policy2 = DTTorchPolicy(observation_space, action_space, config2) + # copy the weights over so they output the same loss without scaling + policy2.set_state(policy1.get_state()) + policy2.set_weights(policy1.get_weights()) + + loss2 = policy2.loss(policy2.model, policy2.dist_class, batch) + loss2 = loss2.detach().cpu().item() + + # compare loss, should be factor of 10 difference + self.assertAlmostEqual( + loss2 / loss1, + 10.0, + places=3, + msg="the two losses should be different to a factor of 10.", + ) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__]))