From b2bcab711d333442c282cf64c66a9fac2c93218f Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 21 Dec 2020 02:22:32 +0100 Subject: [PATCH] [RLlib] Attention Nets: tf (#12753) --- rllib/BUILD | 7 - rllib/agents/callbacks.py | 4 +- rllib/agents/ppo/ppo_tf_policy.py | 3 +- .../collectors/simple_list_collector.py | 171 +++++++++--- rllib/evaluation/postprocessing.py | 4 - rllib/examples/attention_net.py | 10 +- rllib/examples/cartpole_lstm.py | 2 +- rllib/models/modelv2.py | 69 +++-- rllib/models/tests/test_attention_nets.py | 263 ------------------ rllib/models/tf/attention_net.py | 137 +++++---- rllib/models/tf/layers/__init__.py | 4 +- .../layers/relative_multi_head_attention.py | 45 ++- rllib/models/tf/layers/skip_connection.py | 1 - rllib/models/torch/modules/skip_connection.py | 1 - rllib/policy/dynamic_tf_policy.py | 26 +- rllib/policy/eager_tf_policy.py | 16 +- rllib/policy/policy.py | 53 +++- rllib/policy/rnn_sequencing.py | 125 +++++---- rllib/policy/sample_batch.py | 43 ++- rllib/policy/tf_policy.py | 32 ++- rllib/policy/torch_policy.py | 3 +- rllib/policy/view_requirement.py | 20 +- rllib/tests/test_attention_net_learning.py | 6 +- rllib/tests/test_lstm.py | 51 ++-- rllib/utils/sgd.py | 29 +- rllib/utils/typing.py | 3 + 26 files changed, 567 insertions(+), 561 deletions(-) delete mode 100644 rllib/models/tests/test_attention_nets.py diff --git a/rllib/BUILD b/rllib/BUILD index bd612d0ff..c645c27a0 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1074,13 +1074,6 @@ py_test( # Tag: models # -------------------------------------------------------------------- -py_test( - name = "test_attention_nets", - tags = ["models"], - size = "small", - srcs = ["models/tests/test_attention_nets.py"] -) - py_test( name = "test_convtranspose2d_stack", tags = ["models"], diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index bf5284740..e84cf4148 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -191,8 +191,8 @@ class DefaultCallbacks: **kwargs) -> None: """Called at the beginning of Policy.learn_on_batch(). - Note: This is called before the Model's `preprocess_train_batch()` - is called. + Note: This is called before 0-padding via + `pad_batch_to_sequences_of_same_size`. Args: policy (Policy): Reference to the current Policy object. diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 29266dfcc..957d68ce3 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -198,7 +198,8 @@ def postprocess_ppo_gae( # input_dict. if policy.config["_use_trajectory_view_api"]: # Create an input dict according to the Model's requirements. - input_dict = policy.model.get_input_dict(sample_batch, index=-1) + input_dict = policy.model.get_input_dict( + sample_batch, index="last") last_r = policy._value(**input_dict) # TODO: (sven) Remove once trajectory view API is all-algo default. else: diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index efcadf32f..1d5fe3f76 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -1,6 +1,7 @@ import collections from gym.spaces import Space import logging +import math import numpy as np from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union @@ -34,6 +35,9 @@ def to_float_np_array(v: List[Any]) -> np.ndarray: return arr +_INIT_COLS = [SampleBatch.OBS] + + class _AgentCollector: """Collects samples for one agent in one trajectory (episode). @@ -45,9 +49,18 @@ class _AgentCollector: _next_unroll_id = 0 # disambiguates unrolls within a single episode - def __init__(self, shift_before: int = 0): - self.shift_before = max(shift_before, 1) + def __init__(self, view_reqs): + # Determine the size of the buffer we need for data before the actual + # episode starts. This is used for 0-buffering of e.g. prev-actions, + # or internal state inputs. + self.shift_before = -min( + (int(vr.shift.split(":")[0]) + if isinstance(vr.shift, str) else vr.shift) + + (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0) + for k, vr in view_reqs.items()) + # The actual data buffers (lists holding each timestep's data). self.buffers: Dict[str, List] = {} + # The episode ID for the agent for which we collect data. self.episode_id = None # The simple timestep count for this agent. Gets increased by one # each time a (non-initial!) observation is added. @@ -137,31 +150,88 @@ class _AgentCollector: # -> skip. if data_col not in self.buffers: continue + # OBS are already shifted by -1 (the initial obs starts one ts # before all other data columns). - shift = view_req.shift - \ - (1 if data_col == SampleBatch.OBS else 0) + obs_shift = -1 if data_col == SampleBatch.OBS else 0 + + # Keep an np-array cache so we don't have to regenerate the + # np-array for different view_cols using to the same data_col. if data_col not in np_data: np_data[data_col] = to_float_np_array(self.buffers[data_col]) - # Shift is exactly 0: Send trajectory as is. - if shift == 0: - data = np_data[data_col][self.shift_before:] - # Shift is positive: We still need to 0-pad at the end here. - elif shift > 0: - data = to_float_np_array( - self.buffers[data_col][self.shift_before + shift:] + [ - np.zeros( - shape=view_req.space.shape, - dtype=view_req.space.dtype) for _ in range(shift) + + # Range of indices on time-axis, e.g. "-50:-1". Together with + # the `batch_repeat_value`, this determines the data produced. + # Example: + # batch_repeat_value=10, shift_from=-3, shift_to=-1 + # buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # resulting data=[[-3, -2, -1], [7, 8, 9]] + # Range of 3 consecutive items repeats every 10 timesteps. + if view_req.shift_from is not None: + if view_req.batch_repeat_value > 1: + count = int( + math.ceil((len(np_data[data_col]) - self.shift_before) + / view_req.batch_repeat_value)) + data = np.asarray([ + np_data[data_col][self.shift_before + + (i * view_req.batch_repeat_value) + + view_req.shift_from + + obs_shift:self.shift_before + + (i * view_req.batch_repeat_value) + + view_req.shift_to + 1 + obs_shift] + for i in range(count) ]) - # Shift is negative: Shift into the already existing and 0-padded - # "before" area of our buffers. + else: + data = np_data[data_col][self.shift_before + + view_req.shift_from + + obs_shift:self.shift_before + + view_req.shift_to + 1 + obs_shift] + # Set of (probably non-consecutive) indices. + # Example: + # shift=[-3, 0] + # buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...] + elif isinstance(view_req.shift, np.ndarray): + data = np_data[data_col][self.shift_before + obs_shift + + view_req.shift] + # Single shift int value. Use the trajectory as-is, and if + # `shift` != 0: shifted by that value. else: - data = np_data[data_col][self.shift_before + shift:shift] + shift = view_req.shift + obs_shift + + # Batch repeat (only provide a value every n timesteps). + if view_req.batch_repeat_value > 1: + count = int( + math.ceil((len(np_data[data_col]) - self.shift_before) + / view_req.batch_repeat_value)) + data = np.asarray([ + np_data[data_col][self.shift_before + ( + i * view_req.batch_repeat_value) + shift] + for i in range(count) + ]) + # Shift is exactly 0: Use trajectory as is. + elif shift == 0: + data = np_data[data_col][self.shift_before:] + # Shift is positive: We still need to 0-pad at the end. + elif shift > 0: + data = to_float_np_array( + self.buffers[data_col][self.shift_before + shift:] + [ + np.zeros( + shape=view_req.space.shape, + dtype=view_req.space.dtype) + for _ in range(shift) + ]) + # Shift is negative: Shift into the already existing and + # 0-padded "before" area of our buffers. + else: + data = np_data[data_col][self.shift_before + shift:shift] + if len(data) > 0: batch_data[view_col] = data - batch = SampleBatch(batch_data) + # Due to possible batch-repeats > 1, columns in the resulting batch + # may not all have the same batch size. + batch = SampleBatch(batch_data, _dont_check_lens=True) # Add EPS_ID and UNROLL_ID to batch. batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id, @@ -230,15 +300,22 @@ class _PolicyCollector: appended to this policy's buffers. """ - def __init__(self): - """Initializes a _PolicyCollector instance.""" + def __init__(self, policy): + """Initializes a _PolicyCollector instance. + + Args: + policy (Policy): The policy object. + """ self.buffers: Dict[str, List] = collections.defaultdict(list) + self.policy = policy # The total timestep count for all agents that use this policy. # NOTE: This is not an env-step count (across n agents). AgentA and # agentB, both using this policy, acting in the same episode and both # doing n steps would increase the count by 2*n. self.agent_steps = 0 + # Seq-lens list of already added agent batches. + self.seq_lens = [] if policy.is_recurrent() else None def add_postprocessed_batch_for_training( self, batch: SampleBatch, @@ -257,11 +334,18 @@ class _PolicyCollector: # 1) If col is not in view_requirements, we must have a direct # child of the base Policy that doesn't do auto-view req creation. # 2) Col is in view-reqs and needed for training. - if view_col not in view_requirements or \ - view_requirements[view_col].used_for_training: + view_req = view_requirements.get(view_col) + if view_req is None or view_req.used_for_training: self.buffers[view_col].extend(data) # Add the agent's trajectory length to our count. self.agent_steps += batch.count + # Adjust the seq-lens array depending on the incoming agent sequences. + if self.seq_lens is not None: + max_seq_len = self.policy.config["model"]["max_seq_len"] + count = batch.count + while count > 0: + self.seq_lens.append(min(count, max_seq_len)) + count -= max_seq_len def build(self): """Builds a SampleBatch for this policy from the collected data. @@ -273,20 +357,22 @@ class _PolicyCollector: this policy. """ # Create batch from our buffers. - batch = SampleBatch(self.buffers) - assert SampleBatch.UNROLL_ID in batch.data + batch = SampleBatch( + self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True) # Clear buffers for future samples. self.buffers.clear() - # Reset agent steps to 0. + # Reset agent steps to 0 and seq-lens to empty list. self.agent_steps = 0 + if self.seq_lens is not None: + self.seq_lens = [] return batch class _PolicyCollectorGroup: def __init__(self, policy_map): self.policy_collectors = { - pid: _PolicyCollector() - for pid in policy_map.keys() + pid: _PolicyCollector(policy) + for pid, policy in policy_map.items() } # Total env-steps (1 env-step=up to N agents stepped). self.env_steps = 0 @@ -396,11 +482,14 @@ class _SimpleListCollector(_SampleCollector): self.agent_key_to_policy_id[agent_key] = policy_id else: assert self.agent_key_to_policy_id[agent_key] == policy_id + policy = self.policy_map[policy_id] + view_reqs = policy.model.inference_view_requirements if \ + getattr(policy, "model", None) else policy.view_requirements # Add initial obs to Trajectory. assert agent_key not in self.agent_collectors # TODO: determine exact shift-before based on the view-req shifts. - self.agent_collectors[agent_key] = _AgentCollector() + self.agent_collectors[agent_key] = _AgentCollector(view_reqs) self.agent_collectors[agent_key].add_init_obs( episode_id=episode.episode_id, agent_index=episode._agent_index(agent_id), @@ -466,11 +555,19 @@ class _SimpleListCollector(_SampleCollector): for view_col, view_req in view_reqs.items(): # Create the batch of data from the different buffers. data_col = view_req.data_col or view_col - time_indices = \ - view_req.shift - ( - 1 if data_col in [SampleBatch.OBS, "t", "env_id", - SampleBatch.AGENT_INDEX] else 0) + delta = -1 if data_col in [ + SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX + ] else 0 + # Range of shifts, e.g. "-100:0". Note: This includes index 0! + if view_req.shift_from is not None: + time_indices = (view_req.shift_from + delta, + view_req.shift_to + delta) + # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0]. + else: + time_indices = view_req.shift + delta data_list = [] + # Loop through agents and add-up their data (batch). for k in keys: if data_col == SampleBatch.EPS_ID: data_list.append(self.agent_collectors[k].episode_id) @@ -482,7 +579,15 @@ class _SimpleListCollector(_SampleCollector): self.agent_collectors[k]._build_buffers({ data_col: fill_value }) - data_list.append(buffers[k][data_col][time_indices]) + if isinstance(time_indices, tuple): + if time_indices[1] == -1: + data_list.append( + buffers[k][data_col][time_indices[0]:]) + else: + data_list.append(buffers[k][data_col][time_indices[ + 0]:time_indices[1] + 1]) + else: + data_list.append(buffers[k][data_col][time_indices]) input_dict[view_col] = np.array(data_list) self._reset_inference_calls(policy_id) diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index a19411433..0cb25d5c7 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -50,8 +50,6 @@ def compute_advantages(rollout: SampleBatch, processed rewards. """ - rollout_size = len(rollout[SampleBatch.ACTIONS]) - assert SampleBatch.VF_PREDS in rollout or not use_critic, \ "use_critic=True but values not found" assert use_critic or not use_gae, \ @@ -90,6 +88,4 @@ def compute_advantages(rollout: SampleBatch, rollout[Postprocessing.ADVANTAGES] = rollout[ Postprocessing.ADVANTAGES].astype(np.float32) - assert all(val.shape[0] == rollout_size for key, val in rollout.items()), \ - "Rollout stacked incorrectly!" return rollout diff --git a/rllib/examples/attention_net.py b/rllib/examples/attention_net.py index 49884d9f3..de3f06c29 100644 --- a/rllib/examples/attention_net.py +++ b/rllib/examples/attention_net.py @@ -39,6 +39,7 @@ if __name__ == "__main__": config = { "env": args.env, + # This env_config is only used for the RepeatAfterMeEnv env. "env_config": { "repeat_delay": 2, }, @@ -48,7 +49,7 @@ if __name__ == "__main__": "num_workers": 0, "num_envs_per_worker": 20, "entropy_coeff": 0.001, - "num_sgd_iter": 5, + "num_sgd_iter": 10, "vf_loss_coeff": 1e-5, "model": { "custom_model": GTrXLNet, @@ -56,9 +57,10 @@ if __name__ == "__main__": "custom_model_config": { "num_transformer_units": 1, "attn_dim": 64, - "num_heads": 2, - "memory_tau": 50, + "memory_inference": 100, + "memory_training": 50, "head_dim": 32, + "num_heads": 2, "ff_hidden_dim": 32, }, }, @@ -71,7 +73,7 @@ if __name__ == "__main__": "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop, verbose=1) + results = tune.run(args.run, config=config, stop=stop, verbose=2) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/cartpole_lstm.py b/rllib/examples/cartpole_lstm.py index 1c9edc655..de53c6ff1 100644 --- a/rllib/examples/cartpole_lstm.py +++ b/rllib/examples/cartpole_lstm.py @@ -59,7 +59,7 @@ if __name__ == "__main__": "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop, verbose=1) + results = tune.run(args.run, config=config, stop=stop, verbose=2) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 38478857c..fc45149a5 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -13,7 +13,8 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ TensorType from ray.rllib.utils.spaces.repeated import Repeated -from ray.rllib.utils.typing import ModelConfigDict, TensorStructType +from ray.rllib.utils.typing import ModelConfigDict, ModelInputDict, \ + TensorStructType tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -238,14 +239,14 @@ class ModelV2: right input dict, state, and seq len arguments. """ - train_batch["is_training"] = is_training + input_dict = train_batch.copy() + input_dict["is_training"] = is_training states = [] i = 0 - while "state_in_{}".format(i) in train_batch: - states.append(train_batch["state_in_{}".format(i)]) + while "state_in_{}".format(i) in input_dict: + states.append(input_dict["state_in_{}".format(i)]) i += 1 - ret = self.__call__(train_batch, states, train_batch.get("seq_lens")) - del train_batch["is_training"] + ret = self.__call__(input_dict, states, input_dict.get("seq_lens")) return ret def import_from_h5(self, h5_file: str) -> None: @@ -316,21 +317,57 @@ class ModelV2: # TODO: (sven) Experimental method. def get_input_dict(self, sample_batch, - index: int = -1) -> Dict[str, TensorType]: - if index < 0: - index = sample_batch.count - 1 + index: Union[int, str] = "last") -> ModelInputDict: + """Creates single ts input-dict at given index from a SampleBatch. + + Args: + sample_batch (SampleBatch): A single-trajectory SampleBatch object + to generate the compute_actions input dict from. + index (Union[int, str]): An integer index value indicating the + position in the trajectory for which to generate the + compute_actions input dict. Set to "last" to generate the dict + at the very end of the trajectory (e.g. for value estimation). + Note that "last" is different from -1, as "last" will use the + final NEXT_OBS as observation input. + + Returns: + ModelInputDict: The (single-timestep) input dict for ModelV2 calls. + """ + last_mappings = { + SampleBatch.OBS: SampleBatch.NEXT_OBS, + SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS, + SampleBatch.PREV_REWARDS: SampleBatch.REWARDS, + } input_dict = {} for view_col, view_req in self.inference_view_requirements.items(): # Create batches of size 1 (single-agent input-dict). - - # Index range. - if isinstance(index, tuple): - data = sample_batch[view_col][index[0]:index[1] + 1] - input_dict[view_col] = np.array([data]) - # Single index. + data_col = view_req.data_col or view_col + if index == "last": + data_col = last_mappings.get(data_col, data_col) + if view_req.shift_from is not None: + data = sample_batch[view_col][-1] + traj_len = len(sample_batch[data_col]) + missing_at_end = traj_len % view_req.batch_repeat_value + input_dict[view_col] = np.array([ + np.concatenate([ + data, sample_batch[data_col][-missing_at_end:] + ])[view_req.shift_from:view_req.shift_to + + 1 if view_req.shift_to != -1 else None] + ]) + else: + data = sample_batch[data_col][-1] + input_dict[view_col] = np.array([data]) else: - input_dict[view_col] = sample_batch[view_col][index:index + 1] + # Index range. + if isinstance(index, tuple): + data = sample_batch[data_col][index[0]:index[1] + 1 + if index[1] != -1 else None] + input_dict[view_col] = np.array([data]) + # Single index. + else: + input_dict[view_col] = sample_batch[data_col][ + index:index + 1 if index != -1 else None] # Add valid `seq_lens`, just in case RNNs need it. input_dict["seq_lens"] = np.array([1], dtype=np.int32) diff --git a/rllib/models/tests/test_attention_nets.py b/rllib/models/tests/test_attention_nets.py deleted file mode 100644 index ac6ec134d..000000000 --- a/rllib/models/tests/test_attention_nets.py +++ /dev/null @@ -1,263 +0,0 @@ -import gym -import numpy as np -import unittest - -from ray.rllib.models.tf.attention_net import relative_position_embedding, \ - GTrXLNet -from ray.rllib.models.tf.layers import MultiHeadAttention -from ray.rllib.models.torch.attention_net import relative_position_embedding \ - as relative_position_embedding_torch, GTrXLNet as TorchGTrXLNet -from ray.rllib.models.torch.modules.multi_head_attention import \ - MultiHeadAttention as TorchMultiHeadAttention -from ray.rllib.utils.framework import try_import_torch, try_import_tf -from ray.rllib.utils.test_utils import framework_iterator - -torch, nn = try_import_torch() -tf1, tf, tfv = try_import_tf() - - -class TestAttentionNets(unittest.TestCase): - """Tests various torch/modules and tf/layers required for AttentionNet""" - - def train_torch_full_model(self, - model, - inputs, - outputs, - num_epochs=250, - state=None, - seq_lens=None): - """Convenience method that trains a Torch model for num_epochs epochs - and tests whether loss decreased, as expected. - - Args: - model (nn.Module): Torch model to be trained. - inputs (torch.Tensor): Training data - outputs (torch.Tensor): Training labels - num_epochs (int): Number of epochs to train for - state (torch.Tensor): Internal state of module - seq_lens (torch.Tensor): Tensor of sequence lengths - """ - - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) - - # Check that the layer trains correctly - for t in range(num_epochs): - y_pred = model(inputs, state, seq_lens) - loss = criterion(y_pred[0], torch.squeeze(outputs[0])) - - if t % 10 == 1: - print(t, loss.item()) - - # if t == 0: - # init_loss = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # final_loss = loss.item() - - # The final loss has decreased, which tests - # that the model is learning from the training data. - # self.assertLess(final_loss / init_loss, 0.99) - - def train_torch_layer(self, model, inputs, outputs, num_epochs=250): - """Convenience method that trains a Torch model for num_epochs epochs - and tests whether loss decreased, as expected. - - Args: - model (nn.Module): Torch model to be trained. - inputs (torch.Tensor): Training data - outputs (torch.Tensor): Training labels - num_epochs (int): Number of epochs to train for - """ - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) - - # Check that the layer trains correctly - for t in range(num_epochs): - y_pred = model(inputs) - loss = criterion(y_pred, outputs) - - if t == 1: - init_loss = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - final_loss = loss.item() - - # The final loss has decreased by a factor of 2, which tests - # that the model is learning from the training data. - self.assertLess(final_loss / init_loss, 0.5) - - def train_tf_model(self, - model, - inputs, - outputs, - num_epochs=250, - minibatch_size=32): - """Convenience method that trains a Tensorflow model for num_epochs - epochs and tests whether loss decreased, as expected. - - Args: - model (tf.Model): Torch model to be trained. - inputs (np.array): Training data - outputs (np.array): Training labels - num_epochs (int): Number of training epochs - batch_size (int): Number of samples in each minibatch - """ - - # Configure a model for mean-squared error loss. - model.compile(optimizer="SGD", loss="mse", metrics=["mae"]) - - hist = model.fit( - inputs, - outputs, - verbose=0, - epochs=num_epochs, - batch_size=minibatch_size).history - init_loss = hist["loss"][0] - final_loss = hist["loss"][-1] - - self.assertLess(final_loss / init_loss, 0.5) - - def test_multi_head_attention(self): - """Tests the MultiHeadAttention mechanism of Vaswani et al.""" - # B is batch size - B = 1 - # D_in is attention dim, L is memory_tau - L, D_in, D_out = 2, 32, 10 - - for fw, sess in framework_iterator( - frameworks=("tfe", "torch", "tf"), session=True): - # Create a single attention layer with 2 heads. - if fw == "torch": - - # Create random Tensors to hold inputs and outputs - x = torch.randn(B, L, D_in) - y = torch.randn(B, L, D_out) - - model = TorchMultiHeadAttention( - in_dim=D_in, out_dim=D_out, num_heads=2, head_dim=32) - - self.train_torch_layer(model, x, y, num_epochs=500) - - # Framework is tensorflow or tensorflow-eager. - else: - x = np.random.random((B, L, D_in)) - y = np.random.random((B, L, D_out)) - - inputs = tf.keras.layers.Input(shape=(L, D_in)) - - model = tf.keras.Sequential([ - inputs, - MultiHeadAttention( - out_dim=D_out, num_heads=2, head_dim=32) - ]) - self.train_tf_model(model, x, y) - - def test_attention_net(self): - """Tests the GTrXL. - - Builds a full AttentionNet and checks that it trains in a supervised - setting.""" - - # Checks that torch and tf embedding matrices are the same - with tf1.Session().as_default() as sess: - assert np.allclose( - relative_position_embedding(20, 15).eval(session=sess), - relative_position_embedding_torch(20, 15).numpy()) - - # B is batch size - B = 32 - # D_in is attention dim, L is memory_tau - L, D_in, D_out = 2, 16, 2 - - for fw, sess in framework_iterator(session=True): - - # Create a single attention layer with 2 heads - if fw == "torch": - # Create random Tensors to hold inputs and outputs - x = torch.randn(B, L, D_in) - y = torch.randn(B, L, D_out) - - value_labels = torch.randn(B, L, D_in) - memory_labels = torch.randn(B, L, D_out) - - attention_net = TorchGTrXLNet( - observation_space=gym.spaces.Box( - low=float("-inf"), high=float("inf"), shape=(D_in, )), - action_space=gym.spaces.Discrete(D_out), - num_outputs=D_out, - model_config={"max_seq_len": 2}, - name="TestTorchAttentionNet", - num_transformer_units=2, - attn_dim=D_in, - num_heads=2, - memory_tau=L, - head_dim=D_out, - ff_hidden_dim=16, - init_gate_bias=2.0) - - init_state = attention_net.get_initial_state() - - # Get initial state and add a batch dimension. - init_state = [np.expand_dims(s, 0) for s in init_state] - seq_lens_init = torch.full( - size=(B, ), fill_value=L, dtype=torch.int32) - - # Torch implementation expects a formatted input_dict instead - # of a numpy array as input. - input_dict = {"obs": x} - self.train_torch_full_model( - attention_net, - input_dict, [y, value_labels, memory_labels], - num_epochs=250, - state=init_state, - seq_lens=seq_lens_init) - # Framework is tensorflow or tensorflow-eager. - else: - x = np.random.random((B, L, D_in)) - y = np.random.random((B, L, D_out)) - - value_labels = np.random.random((B, L, 1)) - memory_labels = np.random.random((B, L, D_in)) - - # We need to create (N-1) MLP labels for N transformer units - mlp_labels = np.random.random((B, L, D_in)) - - attention_net = GTrXLNet( - observation_space=gym.spaces.Box( - low=float("-inf"), high=float("inf"), shape=(D_in, )), - action_space=gym.spaces.Discrete(D_out), - num_outputs=D_out, - model_config={"max_seq_len": 2}, - name="TestTFAttentionNet", - num_transformer_units=2, - attn_dim=D_in, - num_heads=2, - memory_tau=L, - head_dim=D_out, - ff_hidden_dim=16, - init_gate_bias=2.0) - model = attention_net.trxl_model - - # Get initial state and add a batch dimension. - init_state = attention_net.get_initial_state() - init_state = [np.tile(s, (B, 1, 1)) for s in init_state] - - self.train_tf_model( - model, [x] + init_state, - [y, value_labels, memory_labels, mlp_labels], - num_epochs=200, - minibatch_size=B) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index 2ddbaf33b..ef49f4610 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -8,14 +8,17 @@ Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019. https://www.aclweb.org/anthology/P19-1285.pdf """ +from gym.spaces import Box import numpy as np import gym -from typing import Optional, Any +from typing import Any, Optional from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \ SkipConnection from ray.rllib.models.tf.recurrent_net import RecurrentNetwork +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import ModelConfigDict, TensorType, List @@ -60,7 +63,7 @@ class TrXLNet(RecurrentNetwork): model_config: ModelConfigDict, name: str, num_transformer_units: int, attn_dim: int, num_heads: int, head_dim: int, ff_hidden_dim: int): - """Initializes a TfXLNet object. + """Initializes a TrXLNet object. Args: num_transformer_units (int): The number of Transformer repeats to @@ -88,8 +91,6 @@ class TrXLNet(RecurrentNetwork): self.max_seq_len = model_config["max_seq_len"] self.obs_dim = observation_space.shape[0] - pos_embedding = relative_position_embedding(self.max_seq_len, attn_dim) - inputs = tf.keras.layers.Input( shape=(self.max_seq_len, self.obs_dim), name="inputs") E_out = tf.keras.layers.Dense(attn_dim)(inputs) @@ -100,7 +101,6 @@ class TrXLNet(RecurrentNetwork): out_dim=attn_dim, num_heads=num_heads, head_dim=head_dim, - rel_pos_encoder=pos_embedding, input_layernorm=False, output_activation=None), fan_in_layer=None)(E_out) @@ -160,7 +160,8 @@ class GTrXLNet(RecurrentNetwork): >> num_transformer_units=1, >> attn_dim=32, >> num_heads=2, - >> memory_tau=50, + >> memory_inference=100, + >> memory_training=50, >> etc.. >> } """ @@ -174,11 +175,12 @@ class GTrXLNet(RecurrentNetwork): num_transformer_units: int, attn_dim: int, num_heads: int, - memory_tau: int, + memory_inference: int, + memory_training: int, head_dim: int, ff_hidden_dim: int, init_gate_bias: float = 2.0): - """Initializes a GTrXLNet. + """Initializes a GTrXLNet instance. Args: num_transformer_units (int): The number of Transformer repeats to @@ -187,9 +189,15 @@ class GTrXLNet(RecurrentNetwork): unit. num_heads (int): The number of attention heads to use in parallel. Denoted as `H` in [3]. - memory_tau (int): The number of timesteps to store in each - transformer block's memory M (concat'd over time and fed into - next transformer block as input). + memory_inference (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as inference + input. The first transformer unit will receive this number of + past observations (plus the current one), instead. + memory_training (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as training + input (plus the actual input sequence of len=max_seq_len). + The first transformer unit will receive this number of + past observations (plus the input sequence), instead. head_dim (int): The dimension of a single(!) head. Denoted as `d` in [3]. ff_hidden_dim (int): The dimension of the hidden layer within @@ -208,21 +216,18 @@ class GTrXLNet(RecurrentNetwork): self.num_transformer_units = num_transformer_units self.attn_dim = attn_dim self.num_heads = num_heads - self.memory_tau = memory_tau + self.memory_inference = memory_inference + self.memory_training = memory_training self.head_dim = head_dim self.max_seq_len = model_config["max_seq_len"] self.obs_dim = observation_space.shape[0] - # Constant (non-trainable) sinusoid rel pos encoding matrix. - Phi = relative_position_embedding(self.max_seq_len + self.memory_tau, - self.attn_dim) - - # Raw observation input. + # Raw observation input (plus (None) time axis). input_layer = tf.keras.layers.Input( - shape=(self.max_seq_len, self.obs_dim), name="inputs") + shape=(None, self.obs_dim), name="inputs") memory_ins = [ tf.keras.layers.Input( - shape=(self.memory_tau, self.attn_dim), + shape=(None, self.attn_dim), dtype=tf.float32, name="memory_in_{}".format(i)) for i in range(self.num_transformer_units) @@ -242,7 +247,6 @@ class GTrXLNet(RecurrentNetwork): out_dim=self.attn_dim, num_heads=num_heads, head_dim=head_dim, - rel_pos_encoder=Phi, input_layernorm=True, output_activation=tf.nn.relu), fan_in_layer=GRUGate(init_gate_bias), @@ -280,69 +284,52 @@ class GTrXLNet(RecurrentNetwork): self.register_variables(self.trxl_model.variables) self.trxl_model.summary() - @override(RecurrentNetwork) - def forward_rnn(self, inputs: TensorType, state: List[TensorType], - seq_lens: TensorType) -> (TensorType, List[TensorType]): - # To make Attention work with current RLlib's ModelV2 API: - # We assume `state` is the history of L recent observations (all - # concatenated into one tensor) and append the current inputs to the - # end and only keep the most recent (up to `max_seq_len`). This allows - # us to deal with timestep-wise inference and full sequence training - # within the same logic. - observations = state[0] - memory = state[1:] + # Setup inference view (`memory-inference` x past observations + + # current one (0)) + # 1 to `num_transformer_units`: Memory data (one per transformer unit). + for i in range(self.num_transformer_units): + space = Box(-1.0, 1.0, shape=(self.attn_dim, )) + self.inference_view_requirements["state_in_{}".format(i)] = \ + ViewRequirement( + "state_out_{}".format(i), + shift="-{}:-1".format(self.memory_inference), + # Repeat the incoming state every max-seq-len times. + batch_repeat_value=self.max_seq_len, + space=space) + self.inference_view_requirements["state_out_{}".format(i)] = \ + ViewRequirement( + space=space, + used_for_training=False) - observations = tf.concat( - (observations, inputs), axis=1)[:, -self.max_seq_len:] - all_out = self.trxl_model([observations] + memory) - logits, self._value_out = all_out[0], all_out[1] + @override(ModelV2) + def forward(self, input_dict, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): + assert seq_lens is not None + + # Add the time dim to observations. + B = tf.shape(seq_lens)[0] + observations = input_dict[SampleBatch.OBS] + + shape = tf.shape(observations) + T = shape[0] // B + observations = tf.reshape(observations, + tf.concat([[-1, T], shape[1:]], axis=0)) + + all_out = self.trxl_model([observations] + state) + + logits = all_out[0] + self._value_out = all_out[1] memory_outs = all_out[2:] - # If memory_tau > max_seq_len -> overlap w/ previous `memory` input. - if self.memory_tau > self.max_seq_len: - memory_outs = [ - tf.concat( - [memory[i][:, -(self.memory_tau - self.max_seq_len):], m], - axis=1) for i, m in enumerate(memory_outs) - ] - else: - memory_outs = [m[:, -self.memory_tau:] for m in memory_outs] - T = tf.shape(inputs)[1] # Length of input segment (time). - logits = logits[:, -T:] - self._value_out = self._value_out[:, -T:] - - return logits, [observations] + memory_outs + return tf.reshape(logits, [-1, self.num_outputs]), [ + tf.reshape(m, [-1, self.attn_dim]) for m in memory_outs + ] # TODO: (sven) Deprecate this once trajectory view API has fully matured. @override(RecurrentNetwork) def get_initial_state(self) -> List[np.ndarray]: - # State is the T last observations concat'd together into one Tensor. - # Plus all Transformer blocks' E(l) outputs concat'd together (up to - # tau timesteps). - return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \ - [np.zeros((self.memory_tau, self.attn_dim), np.float32) - for _ in range(self.num_transformer_units)] + return [] @override(ModelV2) def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) - - -def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType: - """Creates a [seq_length x seq_length] matrix for rel. pos encoding. - - Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding - matrix. - - Args: - seq_length (int): The max. sequence length (time axis). - out_dim (int): The number of nodes to go into the first Tranformer - layer with. - - Returns: - tf.Tensor: The encoding matrix Phi. - """ - inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim)) - pos_offsets = tf.range(seq_length - 1., -1., -1.) - inputs = pos_offsets[:, None] * inverse_freq[None, :] - return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1) diff --git a/rllib/models/tf/layers/__init__.py b/rllib/models/tf/layers/__init__.py index 68ae2ea53..0661aac98 100644 --- a/rllib/models/tf/layers/__init__.py +++ b/rllib/models/tf/layers/__init__.py @@ -1,11 +1,11 @@ from ray.rllib.models.tf.layers.gru_gate import GRUGate from ray.rllib.models.tf.layers.noisy_layer import NoisyLayer from ray.rllib.models.tf.layers.relative_multi_head_attention import \ - RelativeMultiHeadAttention + PositionalEmbedding, RelativeMultiHeadAttention from ray.rllib.models.tf.layers.skip_connection import SkipConnection from ray.rllib.models.tf.layers.multi_head_attention import MultiHeadAttention __all__ = [ - "GRUGate", "MultiHeadAttention", "NoisyLayer", + "GRUGate", "MultiHeadAttention", "NoisyLayer", "PositionalEmbedding", "RelativeMultiHeadAttention", "SkipConnection" ] diff --git a/rllib/models/tf/layers/relative_multi_head_attention.py b/rllib/models/tf/layers/relative_multi_head_attention.py index f7d70ab60..840449e1c 100644 --- a/rllib/models/tf/layers/relative_multi_head_attention.py +++ b/rllib/models/tf/layers/relative_multi_head_attention.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Optional from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import TensorType @@ -16,9 +16,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): out_dim: int, num_heads: int, head_dim: int, - rel_pos_encoder: Any, input_layernorm: bool = False, - output_activation: Optional[Any] = None, + output_activation: Optional["tf.nn.activation"] = None, **kwargs): """Initializes a RelativeMultiHeadAttention keras Layer object. @@ -28,7 +27,6 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): Denoted `H` in [2]. head_dim (int): The dimension of a single(!) attention head Denoted `D` in [2]. - rel_pos_encoder (: input_layernorm (bool): Whether to prepend a LayerNorm before everything else. Should be True for building a GTrXL. output_activation (Optional[tf.nn.activation]): Optional tf.nn @@ -50,9 +48,14 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): self._uvar = self.add_weight(shape=(num_heads, head_dim)) self._vvar = self.add_weight(shape=(num_heads, head_dim)) + # Constant (non-trainable) sinusoid rel pos encoding matrix, which + # depends on this incoming time dimension. + # For inference, we prepend the memory to the current timestep's + # input: Tau + 1. For training, we prepend the memory to the input + # sequence: Tau + T. + self._pos_embedding = PositionalEmbedding(out_dim) self._pos_proj = tf.keras.layers.Dense( num_heads * head_dim, use_bias=False) - self._rel_pos_encoder = rel_pos_encoder self._input_layernorm = None if input_layernorm: @@ -66,9 +69,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): # Add previous memory chunk (as const, w/o gradient) to input. # Tau (number of (prev) time slices in each memory chunk). - Tau = memory.shape.as_list()[1] if memory is not None else 0 - if memory is not None: - inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1) + Tau = tf.shape(memory)[1] + inputs = tf.concat([tf.stop_gradient(memory), inputs], axis=1) # Apply the Layer-Norm. if self._input_layernorm is not None: @@ -77,15 +79,17 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): qkv = self._qkv_layer(inputs) queries, keys, values = tf.split(qkv, 3, -1) - # Cut out Tau memory timesteps from query. + # Cut out memory timesteps from query. queries = queries[:, -T:] + # Splitting up queries into per-head dims (d). queries = tf.reshape(queries, [-1, T, H, d]) - keys = tf.reshape(keys, [-1, T + Tau, H, d]) - values = tf.reshape(values, [-1, T + Tau, H, d]) + keys = tf.reshape(keys, [-1, Tau + T, H, d]) + values = tf.reshape(values, [-1, Tau + T, H, d]) - R = self._pos_proj(self._rel_pos_encoder) - R = tf.reshape(R, [T + Tau, H, d]) + R = self._pos_embedding(Tau + T) + R = self._pos_proj(R) + R = tf.reshape(R, [Tau + T, H, d]) # b=batch # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) @@ -96,9 +100,9 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): score = score + self.rel_shift(pos_score) score = score / d**0.5 - # causal mask of the same length as the sequence + # Causal mask of the same length as the sequence. mask = tf.sequence_mask( - tf.range(Tau + 1, T + Tau + 1), dtype=score.dtype) + tf.range(Tau + 1, Tau + T + 1), dtype=score.dtype) mask = mask[None, :, :, None] masked_score = score * mask + 1e30 * (mask - 1.) @@ -121,3 +125,14 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): x = tf.reshape(x, x_size) return x + + +class PositionalEmbedding(tf.keras.layers.Layer if tf else object): + def __init__(self, out_dim, **kwargs): + super().__init__(**kwargs) + self.inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim)) + + def call(self, seq_length): + pos_offsets = tf.cast(tf.range(seq_length - 1, -1, -1), tf.float32) + inputs = pos_offsets[:, None] * self.inverse_freq[None, :] + return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1) diff --git a/rllib/models/tf/layers/skip_connection.py b/rllib/models/tf/layers/skip_connection.py index efb89f2e3..a44ae2bc1 100644 --- a/rllib/models/tf/layers/skip_connection.py +++ b/rllib/models/tf/layers/skip_connection.py @@ -16,7 +16,6 @@ class SkipConnection(tf.keras.layers.Layer if tf else object): def __init__(self, layer: Any, fan_in_layer: Optional[Any] = None, - add_memory: bool = False, **kwargs): """Initializes a SkipConnection keras layer object. diff --git a/rllib/models/torch/modules/skip_connection.py b/rllib/models/torch/modules/skip_connection.py index 126274b1d..8d79b7826 100644 --- a/rllib/models/torch/modules/skip_connection.py +++ b/rllib/models/torch/modules/skip_connection.py @@ -15,7 +15,6 @@ class SkipConnection(nn.Module): def __init__(self, layer: nn.Module, fan_in_layer: Optional[nn.Module] = None, - add_memory: bool = False, **kwargs): """Initializes a SkipConnection nn Module object. diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 432e384f2..39b31f63b 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -183,11 +183,12 @@ class DynamicTFPolicy(TFPolicy): else: if self.config["_use_trajectory_view_api"]: self._state_inputs = [ - tf1.placeholder( - shape=(None, ) + vr.space.shape, dtype=vr.space.dtype) - for k, vr in + get_placeholder( + space=vr.space, + time_axis=not isinstance(vr.shift, int), + ) for k, vr in self.model.inference_view_requirements.items() - if k[:9] == "state_in_" + if k.startswith("state_in_") ] else: self._state_inputs = [ @@ -423,9 +424,14 @@ class DynamicTFPolicy(TFPolicy): input_dict[view_col] = existing_inputs[view_col] # All others. else: + time_axis = not isinstance(view_req.shift, int) if view_req.used_for_training: + # Create a +time-axis placeholder if the shift is not an + # int (range or list of ints). input_dict[view_col] = get_placeholder( - space=view_req.space, name=view_col) + space=view_req.space, + name=view_col, + time_axis=time_axis) dummy_batch = self._get_dummy_batch_from_view_requirements( batch_size=32) @@ -490,10 +496,10 @@ class DynamicTFPolicy(TFPolicy): dummy_batch["seq_lens"] = np.array([1], dtype=np.int32) for k, v in self.extra_compute_action_fetches().items(): dummy_batch[k] = fake_array(v) + dummy_batch = SampleBatch(dummy_batch) - sb = SampleBatch(dummy_batch) - batch_for_postproc = UsageTrackingDict(sb) - batch_for_postproc.count = sb.count + batch_for_postproc = UsageTrackingDict(dummy_batch) + batch_for_postproc.count = dummy_batch.count logger.info("Testing `postprocess_trajectory` w/ dummy batch.") self.exploration.postprocess_trajectory(self, batch_for_postproc, self._sess) @@ -519,6 +525,7 @@ class DynamicTFPolicy(TFPolicy): train_batch.update({ SampleBatch.PREV_ACTIONS: self._prev_action_input, SampleBatch.PREV_REWARDS: self._prev_reward_input, + SampleBatch.CUR_OBS: self._obs_input, }) for k, v in postprocessed_batch.items(): @@ -578,7 +585,8 @@ class DynamicTFPolicy(TFPolicy): for key in batch_for_postproc.accessed_keys: if key not in train_batch.accessed_keys and \ key not in self.model.inference_view_requirements: - self.view_requirements[key].used_for_training = False + if key in self.view_requirements: + self.view_requirements[key].used_for_training = False if key in self._loss_input_dict: del self._loss_input_dict[key] # Remove those not needed at all (leave those that are needed diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 758cfc948..f17d60e06 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -314,12 +314,16 @@ def build_eager_tf_policy(name, self.callbacks.on_learn_on_batch( policy=self, train_batch=postprocessed_batch) - # Get batch ready for RNNs, if applicable. pad_batch_to_sequences_of_same_size( postprocessed_batch, shuffle=False, max_seq_len=self._max_seq_len, - batch_divisibility_req=self.batch_divisibility_req) + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + ) + + self._is_training = True + postprocessed_batch["is_training"] = True return self._learn_on_batch_eager(postprocessed_batch) @convert_eager_inputs @@ -332,12 +336,14 @@ def build_eager_tf_policy(name, @override(Policy) def compute_gradients(self, samples): - # Get batch ready for RNNs, if applicable. pad_batch_to_sequences_of_same_size( samples, shuffle=False, max_seq_len=self._max_seq_len, batch_divisibility_req=self.batch_divisibility_req) + + self._is_training = True + samples["is_training"] = True return self._compute_gradients_eager(samples) @convert_eager_inputs @@ -369,7 +375,7 @@ def build_eager_tf_policy(name, # TODO: remove python side effect to cull sources of bugs. self._is_training = False - self._state_in = state_batches + self._state_in = state_batches or [] if not tf1.executing_eagerly(): tf1.enable_eager_execution() @@ -591,8 +597,6 @@ def build_eager_tf_policy(name, def _compute_gradients(self, samples): """Computes and returns grads as eager tensors.""" - self._is_training = True - with tf.GradientTape(persistent=gradients_fn is not None) as tape: loss = loss_fn(self, self.model, self.dist_class, samples) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index a1e92ac37..4695e366f 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -629,10 +629,9 @@ class Policy(metaclass=ABCMeta): batch_for_postproc.count = self._dummy_batch.count self.exploration.postprocess_trajectory(self, batch_for_postproc) postprocessed_batch = self.postprocess_trajectory(batch_for_postproc) + seq_lens = None if state_outs: B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size] - # TODO: (sven) This hack will not work for attention net traj. - # view setup. i = 0 while "state_in_{}".format(i) in postprocessed_batch: postprocessed_batch["state_in_{}".format(i)] = \ @@ -642,12 +641,11 @@ class Policy(metaclass=ABCMeta): postprocessed_batch["state_out_{}".format(i)][:B] i += 1 seq_len = sample_batch_size // B - postprocessed_batch["seq_lens"] = \ - np.array([seq_len for _ in range(B)], dtype=np.int32) - # Remove the UsageTrackingDict wrap to prep for wrapping the - # train batch with a to-tensor UsageTrackingDict. - train_batch = {k: v for k, v in postprocessed_batch.items()} - train_batch = self._lazy_tensor_dict(train_batch) + seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32) + # Wrap `train_batch` with a to-tensor UsageTrackingDict. + train_batch = self._lazy_tensor_dict(postprocessed_batch) + if seq_lens is not None: + train_batch["seq_lens"] = seq_lens train_batch.count = self._dummy_batch.count # Call the loss function, if it exists. if self._loss is not None: @@ -712,13 +710,33 @@ class Policy(metaclass=ABCMeta): ret[view_col] = \ np.zeros((batch_size, ) + shape[1:], np.float32) else: - if isinstance(view_req.space, gym.spaces.Space): - ret[view_col] = np.zeros_like( - [view_req.space.sample() for _ in range(batch_size)]) + # Range of indices on time-axis, e.g. "-50:-1". + if view_req.shift_from is not None: + ret[view_col] = np.zeros_like([[ + view_req.space.sample() + for _ in range(view_req.shift_to - + view_req.shift_from + 1) + ] for _ in range(batch_size)]) + # Set of (probably non-consecutive) indices. + elif isinstance(view_req.shift, (list, tuple)): + ret[view_col] = np.zeros_like([[ + view_req.space.sample() + for t in range(len(view_req.shift)) + ] for _ in range(batch_size)]) + # Single shift int value. else: - ret[view_col] = [view_req.space for _ in range(batch_size)] + if isinstance(view_req.space, gym.spaces.Space): + ret[view_col] = np.zeros_like([ + view_req.space.sample() for _ in range(batch_size) + ]) + else: + ret[view_col] = [ + view_req.space for _ in range(batch_size) + ] - return SampleBatch(ret) + # Due to different view requirements for the different columns, + # columns in the resulting batch may not all have the same batch size. + return SampleBatch(ret, _dont_check_lens=True) def _update_model_inference_view_requirements_from_init_state(self): """Uses Model's (or this Policy's) init state to add needed ViewReqs. @@ -737,8 +755,13 @@ class Policy(metaclass=ABCMeta): view_reqs = model.inference_view_requirements if model else \ self.view_requirements view_reqs["state_in_{}".format(i)] = ViewRequirement( - "state_out_{}".format(i), shift=-1, space=space) - view_reqs["state_out_{}".format(i)] = ViewRequirement(space=space) + "state_out_{}".format(i), + shift=-1, + batch_repeat_value=self.config.get("model", {}).get( + "max_seq_len", 1), + space=space) + view_reqs["state_out_{}".format(i)] = ViewRequirement( + space=space, used_for_training=True) def clip_action(action, action_space): diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 486bbf0db..1cf3fc4aa 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -19,7 +19,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils.typing import TensorType, ViewRequirementsDict from ray.util import log_once tf1, tf, tfv = try_import_tf() @@ -35,6 +35,7 @@ def pad_batch_to_sequences_of_same_size( shuffle: bool = False, batch_divisibility_req: int = 1, feature_keys: Optional[List[str]] = None, + view_requirements: Optional[ViewRequirementsDict] = None, ): """Applies padding to `batch` so it's choppable into same-size sequences. @@ -55,6 +56,9 @@ def pad_batch_to_sequences_of_same_size( feature_keys (Optional[List[str]]): An optional list of keys to apply sequence-chopping to. If None, use all keys in batch that are not "state_in/out_"-type keys. + view_requirements (Optional[ViewRequirementsDict]): An optional + Policy ViewRequirements dict to be able to infer whether + e.g. dynamic max'ing should be applied over the seq_lens. """ if batch_divisibility_req > 1: meets_divisibility_reqs = ( @@ -64,46 +68,65 @@ def pad_batch_to_sequences_of_same_size( else: meets_divisibility_reqs = True - # RNN-case. + states_already_reduced_to_init = False + + # RNN/attention net case. Figure out whether we should apply dynamic + # max'ing over the list of sequence lengths. if "state_in_0" in batch or "state_out_0" in batch: - dynamic_max = True + # Check, whether the state inputs have already been reduced to their + # init values at the beginning of each max_seq_len chunk. + if batch.seq_lens is not None and \ + len(batch["state_in_0"]) == len(batch.seq_lens): + states_already_reduced_to_init = True + + # RNN (or single timestep state-in): Set the max dynamically. + if view_requirements["state_in_0"].shift_from is None: + dynamic_max = True + # Attention Nets (state inputs are over some range): No dynamic maxing + # possible. + else: + dynamic_max = False # Multi-agent case. elif not meets_divisibility_reqs: max_seq_len = batch_divisibility_req dynamic_max = False - # Simple case: not RNN nor do we need to pad. + # Simple case: No RNN/attention net, nor do we need to pad. else: if shuffle: batch.shuffle() return - # RNN or multi-agent case. + # RNN, attention net, or multi-agent case. state_keys = [] feature_keys_ = feature_keys or [] - for k in batch.keys(): - if "state_in_" in k: + for k, v in batch.items(): + if k.startswith("state_in_"): state_keys.append(k) - elif not feature_keys and "state_out_" not in k and k != "infos": + elif not feature_keys and not k.startswith("state_out_") and \ + k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray): feature_keys_.append(k) feature_sequences, initial_states, seq_lens = \ chop_into_sequences( - batch[SampleBatch.EPS_ID], - batch[SampleBatch.UNROLL_ID], - batch[SampleBatch.AGENT_INDEX], - [batch[k] for k in feature_keys_], - [batch[k] for k in state_keys], - max_seq_len, + feature_columns=[batch[k] for k in feature_keys_], + state_columns=[batch[k] for k in state_keys], + episode_ids=batch.get(SampleBatch.EPS_ID), + unroll_ids=batch.get(SampleBatch.UNROLL_ID), + agent_indices=batch.get(SampleBatch.AGENT_INDEX), + seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")), + max_seq_len=max_seq_len, dynamic_max=dynamic_max, + states_already_reduced_to_init=states_already_reduced_to_init, shuffle=shuffle) + for i, k in enumerate(feature_keys_): batch[k] = feature_sequences[i] for i, k in enumerate(state_keys): batch[k] = initial_states[i] - batch["seq_lens"] = seq_lens + batch["seq_lens"] = np.array(seq_lens) if log_once("rnn_ma_feed_dict"): - logger.info("Padded input for RNN:\n\n{}\n".format( + logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format( summarize({ "features": feature_sequences, "initial_states": initial_states, @@ -157,18 +180,18 @@ def add_time_dimension(padded_inputs: TensorType, return torch.reshape(padded_inputs, new_shape) -# NOTE: This function will be deprecated once chunks already come padded and -# correctly chopped from the _SampleCollector object (in time-major fashion -# or not). It is already no longer user iff `_use_trajectory_view_api` = True. @DeveloperAPI -def chop_into_sequences(episode_ids, - unroll_ids, - agent_indices, +def chop_into_sequences(*, feature_columns, state_columns, max_seq_len, + episode_ids=None, + unroll_ids=None, + agent_indices=None, dynamic_max=True, shuffle=False, + seq_lens=None, + states_already_reduced_to_init=False, _extra_padding=0): """Truncate and pad experiences into fixed-length sequences. @@ -212,23 +235,24 @@ def chop_into_sequences(episode_ids, [2, 3, 1] """ - prev_id = None - seq_lens = [] - seq_len = 0 - unique_ids = np.add( - np.add(episode_ids, agent_indices), - np.array(unroll_ids, dtype=np.int64) << 32) - for uid in unique_ids: - if (prev_id is not None and uid != prev_id) or \ - seq_len >= max_seq_len: + if seq_lens is None or len(seq_lens) == 0: + prev_id = None + seq_lens = [] + seq_len = 0 + unique_ids = np.add( + np.add(episode_ids, agent_indices), + np.array(unroll_ids, dtype=np.int64) << 32) + for uid in unique_ids: + if (prev_id is not None and uid != prev_id) or \ + seq_len >= max_seq_len: + seq_lens.append(seq_len) + seq_len = 0 + seq_len += 1 + prev_id = uid + if seq_len: seq_lens.append(seq_len) - seq_len = 0 - seq_len += 1 - prev_id = uid - if seq_len: - seq_lens.append(seq_len) - assert sum(seq_lens) == len(unique_ids) - seq_lens = np.array(seq_lens, dtype=np.int32) + seq_lens = np.array(seq_lens, dtype=np.int32) + assert sum(seq_lens) == len(feature_columns[0]) # Dynamically shrink max len as needed to optimize memory usage if dynamic_max: @@ -252,18 +276,23 @@ def chop_into_sequences(episode_ids, f_pad[seq_base + seq_offset] = f[i] i += 1 seq_base += max_seq_len - assert i == len(unique_ids), f + assert i == len(f), f feature_sequences.append(f_pad) - initial_states = [] - for s in state_columns: - s = np.array(s) - s_init = [] - i = 0 - for len_ in seq_lens: - s_init.append(s[i]) - i += len_ - initial_states.append(np.array(s_init)) + if states_already_reduced_to_init: + initial_states = state_columns + else: + initial_states = [] + for s in state_columns: + # Skip unnecessary copy. + if not isinstance(s, np.ndarray): + s = np.array(s) + s_init = [] + i = 0 + for len_ in seq_lens: + s_init.append(s[i]) + i += len_ + initial_states.append(np.array(s_init)) if shuffle: permutation = np.random.permutation(len(seq_lens)) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index a2934fdb9..a1b4c43bc 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -61,6 +61,7 @@ class SampleBatch: # Possible seq_lens (TxB or BxT) setup. self.time_major = kwargs.pop("_time_major", None) self.seq_lens = kwargs.pop("_seq_lens", None) + self.dont_check_lens = kwargs.pop("_dont_check_lens", False) self.max_seq_len = None if self.seq_lens is not None and len(self.seq_lens) > 0: self.max_seq_len = max(self.seq_lens) @@ -76,8 +77,10 @@ class SampleBatch: self.data[k] = np.array(v) if not lengths: raise ValueError("Empty sample batch") - assert len(set(lengths)) == 1, \ - "Data columns must be same length, but lens are {}".format(lengths) + if not self.dont_check_lens: + assert len(set(lengths)) == 1, \ + "Data columns must be same length, but lens are " \ + "{}".format(lengths) if self.seq_lens is not None and len(self.seq_lens) > 0: self.count = sum(self.seq_lens) else: @@ -117,7 +120,8 @@ class SampleBatch: return SampleBatch( out, _seq_lens=np.array(seq_lens, dtype=np.int32), - _time_major=concat_samples[0].time_major) + _time_major=concat_samples[0].time_major, + _dont_check_lens=True) @PublicAPI def concat(self, other: "SampleBatch") -> "SampleBatch": @@ -248,12 +252,35 @@ class SampleBatch: SampleBatch: A new SampleBatch, which has a slice of this batch's data. """ - if self.time_major is not None: + if self.seq_lens is not None and len(self.seq_lens) > 0: + data = {k: v[start:end] for k, v in self.data.items()} + # Fix state_in_x data. + count = 0 + state_start = None + seq_lens = None + for i, seq_len in enumerate(self.seq_lens): + count += seq_len + if count >= end: + state_idx = 0 + state_key = "state_in_{}".format(state_idx) + while state_key in self.data: + data[state_key] = self.data[state_key][state_start:i + + 1] + state_idx += 1 + state_key = "state_in_{}".format(state_idx) + seq_lens = list(self.seq_lens[state_start:i]) + [ + seq_len - (count - end) + ] + assert sum(seq_lens) == (end - start) + break + elif state_start is None and count > start: + state_start = i + return SampleBatch( - {k: v[:, start:end] - for k, v in self.data.items()}, - _seq_lens=self.seq_lens[start:end], - _time_major=self.time_major) + data, + _seq_lens=np.array(seq_lens, dtype=np.int32), + _time_major=self.time_major, + _dont_check_lens=True) else: return SampleBatch( {k: v[start:end] diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index f6e48dad2..fe6ec900b 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -174,11 +174,6 @@ class TFPolicy(Policy): raise ValueError( "Number of state input and output tensors must match, got: " "{} vs {}".format(self._state_inputs, self._state_outputs)) - if len(self.get_initial_state()) != len(self._state_inputs): - raise ValueError( - "Length of initial state must match number of state inputs, " - "got: {} vs {}".format(self.get_initial_state(), - self._state_inputs)) if self._state_inputs and self._seq_lens is None: raise ValueError( "seq_lens tensor must be given if state inputs are defined") @@ -263,6 +258,11 @@ class TFPolicy(Policy): (name, tf1.placeholders) needed for calculating the loss. """ self._loss_input_dict = dict(loss_inputs) + self._loss_input_dict_no_rnn = { + k: v + for k, v in self._loss_input_dict.items() + if (v not in self._state_inputs and v != self._seq_lens) + } for i, ph in enumerate(self._state_inputs): self._loss_input_dict["state_in_{}".format(i)] = ph @@ -791,11 +791,11 @@ class TFPolicy(Policy): **fetches[LEARNER_STATS_KEY]) return fetches - def _get_loss_inputs_dict(self, batch, shuffle): + def _get_loss_inputs_dict(self, train_batch, shuffle): """Return a feed dict from a batch. Args: - batch (SampleBatch): batch of data to derive inputs from + train_batch (SampleBatch): batch of data to derive inputs from. shuffle (bool): whether to shuffle batch sequences. Shuffle may be done in-place. This only makes sense if you're further applying minibatch SGD after getting the outputs. @@ -806,28 +806,30 @@ class TFPolicy(Policy): # Get batch ready for RNNs, if applicable. pad_batch_to_sequences_of_same_size( - batch, + train_batch, shuffle=shuffle, max_seq_len=self._max_seq_len, batch_divisibility_req=self._batch_divisibility_req, - feature_keys=[ - k for k in self._loss_input_dict.keys() if k != "seq_lens" - ], + feature_keys=list(self._loss_input_dict_no_rnn.keys()), + view_requirements=self.view_requirements, ) - batch["is_training"] = True + + # Mark the batch as "is_training" so the Model can use this + # information. + train_batch["is_training"] = True # Build the feed dict from the batch. feed_dict = {} for key, placeholder in self._loss_input_dict.items(): - feed_dict[placeholder] = batch[key] + feed_dict[placeholder] = train_batch[key] state_keys = [ "state_in_{}".format(i) for i in range(len(self._state_inputs)) ] for key in state_keys: - feed_dict[self._loss_input_dict[key]] = batch[key] + feed_dict[self._loss_input_dict[key]] = train_batch[key] if state_keys: - feed_dict[self._seq_lens] = batch["seq_lens"] + feed_dict[self._seq_lens] = train_batch["seq_lens"] return feed_dict diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index f294b510d..c27a7603d 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -345,12 +345,13 @@ class TorchPolicy(Policy): @DeveloperAPI def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients: - # Get batch ready for RNNs, if applicable. + pad_batch_to_sequences_of_same_size( postprocessed_batch, max_seq_len=self.max_seq_len, shuffle=False, batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, ) train_batch = self._lazy_tensor_dict(postprocessed_batch) diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py index f9c7750d4..25a5e908a 100644 --- a/rllib/policy/view_requirement.py +++ b/rllib/policy/view_requirement.py @@ -1,4 +1,5 @@ import gym +import numpy as np from typing import List, Optional, Union from ray.rllib.utils.framework import try_import_torch @@ -29,8 +30,9 @@ class ViewRequirement: def __init__(self, data_col: Optional[str] = None, space: gym.Space = None, - shift: Union[int, List[int]] = 0, + shift: Union[int, str, List[int]] = 0, index: Optional[int] = None, + batch_repeat_value: int = 1, used_for_training: bool = True): """Initializes a ViewRequirement object. @@ -64,7 +66,19 @@ class ViewRequirement: self.space = space if space is not None else gym.spaces.Box( float("-inf"), float("inf"), shape=()) - self.index = index - self.shift = shift + if isinstance(self.shift, (list, tuple)): + self.shift = np.array(self.shift) + + # Special case: Providing a (probably larger) range of indices, e.g. + # "-100:0" (past 100 timesteps plus current one). + self.shift_from = self.shift_to = None + if isinstance(self.shift, str): + f, t = self.shift.split(":") + self.shift_from = int(f) + self.shift_to = int(t) + + self.index = index + self.batch_repeat_value = batch_repeat_value + self.used_for_training = used_for_training diff --git a/rllib/tests/test_attention_net_learning.py b/rllib/tests/test_attention_net_learning.py index b060651d6..35e5b3b08 100644 --- a/rllib/tests/test_attention_net_learning.py +++ b/rllib/tests/test_attention_net_learning.py @@ -44,7 +44,8 @@ class TestAttentionNetLearning(unittest.TestCase): "num_transformer_units": 1, "attn_dim": 32, "num_heads": 1, - "memory_tau": 5, + "memory_inference": 5, + "memory_training": 5, "head_dim": 32, "ff_hidden_dim": 32, }, @@ -71,7 +72,8 @@ class TestAttentionNetLearning(unittest.TestCase): # "num_transformer_units": 1, # "attn_dim": 64, # "num_heads": 1, - # "memory_tau": 10, + # "memory_inference": 10, + # "memory_training": 10, # "head_dim": 32, # "ff_hidden_dim": 32, # }, diff --git a/rllib/tests/test_lstm.py b/rllib/tests/test_lstm.py index 2685fa942..09b7aef73 100644 --- a/rllib/tests/test_lstm.py +++ b/rllib/tests/test_lstm.py @@ -18,9 +18,13 @@ class TestLSTMUtils(unittest.TestCase): f = [[101, 102, 103, 201, 202, 203, 204, 205], [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, - np.ones_like(eps_ids), - agent_ids, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences( + episode_ids=eps_ids, + unroll_ids=np.ones_like(eps_ids), + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4) self.assertEqual([f.tolist() for f in f_pad], [ [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0], [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0], @@ -35,9 +39,13 @@ class TestLSTMUtils(unittest.TestCase): obs = np.ones((84, 84, 4)) f = [[obs, obs * 2, obs * 3]] s = [[209, 208, 207]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, - np.ones_like(eps_ids), - agent_ids, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences( + episode_ids=eps_ids, + unroll_ids=np.ones_like(eps_ids), + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4) self.assertEqual([f.tolist() for f in f_pad], [ np.array([obs, obs * 2, obs * 3]).tolist(), ]) @@ -51,8 +59,13 @@ class TestLSTMUtils(unittest.TestCase): f = [[101, 102, 103, 201, 202, 203, 204, 205], [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] - _, _, seq_lens = chop_into_sequences(eps_ids, batch_ids, agent_ids, f, - s, 4) + _, _, seq_lens = chop_into_sequences( + episode_ids=eps_ids, + unroll_ids=batch_ids, + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4) self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2]) def test_multi_agent(self): @@ -62,12 +75,12 @@ class TestLSTMUtils(unittest.TestCase): [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] f_pad, s_init, seq_lens = chop_into_sequences( - eps_ids, - np.ones_like(eps_ids), - agent_ids, - f, - s, - 4, + episode_ids=eps_ids, + unroll_ids=np.ones_like(eps_ids), + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4, dynamic_max=False) self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1]) self.assertEqual(len(f_pad[0]), 20) @@ -78,9 +91,13 @@ class TestLSTMUtils(unittest.TestCase): agent_ids = [2, 2, 2] f = [[1, 1, 1]] s = [[1, 1, 1]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, - np.ones_like(eps_ids), - agent_ids, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences( + episode_ids=eps_ids, + unroll_ids=np.ones_like(eps_ids), + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4) self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]]) self.assertEqual([s.tolist() for s in s_init], [[1, 1]]) self.assertEqual(seq_lens.tolist(), [1, 2]) diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index d5576e0fa..b5b72d44d 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -72,18 +72,23 @@ def minibatches(samples, sgd_minibatch_size): i = 0 slices = [] - if samples.seq_lens: - seq_no = 0 - while i < samples.count: - seq_no_end = seq_no - actual_count = 0 - while actual_count < sgd_minibatch_size and len( - samples.seq_lens) > seq_no_end: - actual_count += samples.seq_lens[seq_no_end] - seq_no_end += 1 - slices.append((seq_no, seq_no_end)) - i += actual_count - seq_no = seq_no_end + if samples.seq_lens is not None and len(samples.seq_lens) > 0: + start_pos = 0 + minibatch_size = 0 + idx = 0 + while idx < len(samples.seq_lens): + seq_len = samples.seq_lens[idx] + minibatch_size += seq_len + # Complete minibatch -> Append to slices. + if minibatch_size >= sgd_minibatch_size: + slices.append((start_pos, start_pos + sgd_minibatch_size)) + start_pos += sgd_minibatch_size + if minibatch_size > sgd_minibatch_size: + overhead = minibatch_size - sgd_minibatch_size + start_pos -= (seq_len - overhead) + idx -= 1 + minibatch_size = 0 + idx += 1 else: while i < samples.count: slices.append((i, i + sgd_minibatch_size)) diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py index 592f0424d..3010366fc 100644 --- a/rllib/utils/typing.py +++ b/rllib/utils/typing.py @@ -100,6 +100,9 @@ ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]] # Type of dict returned by get_weights() representing model weights. ModelWeights = dict +# An input dict used for direct ModelV2 calls or `ModelV2.from_batch` calls. +ModelInputDict = Dict[str, TensorType] + # Some kind of sample batch. SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]