Mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)

Commit b2bcab711d (parent e715ade2d1): [RLlib] Attention Nets: tf (#12753)
26 changed files with 567 additions and 561 deletions

@@ -1074,13 +1074,6 @@ py_test(
 # Tag: models
 # --------------------------------------------------------------------
 
-py_test(
-    name = "test_attention_nets",
-    tags = ["models"],
-    size = "small",
-    srcs = ["models/tests/test_attention_nets.py"]
-)
-
 py_test(
     name = "test_convtranspose2d_stack",
     tags = ["models"],
@@ -191,8 +191,8 @@ class DefaultCallbacks:
                           **kwargs) -> None:
         """Called at the beginning of Policy.learn_on_batch().
 
-        Note: This is called before the Model's `preprocess_train_batch()`
-        is called.
+        Note: This is called before 0-padding via
+        `pad_batch_to_sequences_of_same_size`.
 
         Args:
             policy (Policy): Reference to the current Policy object.
@@ -198,7 +198,8 @@ def postprocess_ppo_gae(
     # input_dict.
     if policy.config["_use_trajectory_view_api"]:
         # Create an input dict according to the Model's requirements.
-        input_dict = policy.model.get_input_dict(sample_batch, index=-1)
+        input_dict = policy.model.get_input_dict(
+            sample_batch, index="last")
         last_r = policy._value(**input_dict)
     # TODO: (sven) Remove once trajectory view API is all-algo default.
     else:
@@ -1,6 +1,7 @@
 import collections
 from gym.spaces import Space
 import logging
+import math
 import numpy as np
 from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union
 

@@ -34,6 +35,9 @@ def to_float_np_array(v: List[Any]) -> np.ndarray:
     return arr
 
 
+_INIT_COLS = [SampleBatch.OBS]
+
+
 class _AgentCollector:
     """Collects samples for one agent in one trajectory (episode).
 
@@ -45,9 +49,18 @@ class _AgentCollector:
 
     _next_unroll_id = 0  # disambiguates unrolls within a single episode
 
-    def __init__(self, shift_before: int = 0):
-        self.shift_before = max(shift_before, 1)
+    def __init__(self, view_reqs):
+        # Determine the size of the buffer we need for data before the actual
+        # episode starts. This is used for 0-buffering of e.g. prev-actions,
+        # or internal state inputs.
+        self.shift_before = -min(
+            (int(vr.shift.split(":")[0])
+             if isinstance(vr.shift, str) else vr.shift) +
+            (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0)
+            for k, vr in view_reqs.items())
         # The actual data buffers (lists holding each timestep's data).
         self.buffers: Dict[str, List] = {}
+        # The episode ID for the agent for which we collect data.
+        self.episode_id = None
         # The simple timestep count for this agent. Gets increased by one
        # each time a (non-initial!) observation is added.
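
The buffer look-back above can be checked in isolation. Below is a minimal sketch, assuming a namedtuple stand-in for RLlib's ViewRequirement (only the `data_col` and `shift` fields used by the expression); names and values are illustrative, not taken from the commit:

    # Hypothetical stand-in for ViewRequirement; only data_col/shift matter here.
    from collections import namedtuple

    ViewRequirement = namedtuple("ViewRequirement", ["data_col", "shift"])

    _INIT_COLS = ["obs"]
    view_reqs = {
        "obs": ViewRequirement(data_col=None, shift=0),
        "prev_actions": ViewRequirement(data_col="actions", shift=-1),
        # An attention-net style range shift over the last 50 state outputs.
        "state_in_0": ViewRequirement(data_col="state_out_0", shift="-50:-1"),
    }

    shift_before = -min(
        (int(vr.shift.split(":")[0]) if isinstance(vr.shift, str) else vr.shift)
        + (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0)
        for k, vr in view_reqs.items())
    print(shift_before)  # 50: buffers keep 50 zero-padded slots before t=0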
@@ -137,31 +150,88 @@ class _AgentCollector:
             # -> skip.
             if data_col not in self.buffers:
                 continue
 
             # OBS are already shifted by -1 (the initial obs starts one ts
             # before all other data columns).
-            shift = view_req.shift - \
-                (1 if data_col == SampleBatch.OBS else 0)
+            obs_shift = -1 if data_col == SampleBatch.OBS else 0
 
             # Keep an np-array cache so we don't have to regenerate the
             # np-array for different view_cols referring to the same data_col.
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])
-            # Shift is exactly 0: Send trajectory as is.
-            if shift == 0:
-                data = np_data[data_col][self.shift_before:]
-            # Shift is positive: We still need to 0-pad at the end here.
-            elif shift > 0:
-                data = to_float_np_array(
-                    self.buffers[data_col][self.shift_before + shift:] + [
-                        np.zeros(
-                            shape=view_req.space.shape,
-                            dtype=view_req.space.dtype) for _ in range(shift)
-                    ])
-            # Shift is negative: Shift into the already existing and 0-padded
-            # "before" area of our buffers.
-            else:
-                data = np_data[data_col][self.shift_before + shift:shift]
+
+            # Range of indices on time-axis, e.g. "-50:-1". Together with
+            # the `batch_repeat_value`, this determines the data produced.
+            # Example:
+            #  batch_repeat_value=10, shift_from=-3, shift_to=-1
+            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+            #  resulting data=[[-3, -2, -1], [7, 8, 9]]
+            #  Range of 3 consecutive items repeats every 10 timesteps.
+            if view_req.shift_from is not None:
+                if view_req.batch_repeat_value > 1:
+                    count = int(
+                        math.ceil((len(np_data[data_col]) - self.shift_before)
                                  / view_req.batch_repeat_value))
+                    data = np.asarray([
+                        np_data[data_col][self.shift_before +
+                                          (i * view_req.batch_repeat_value) +
+                                          view_req.shift_from +
+                                          obs_shift:self.shift_before +
+                                          (i * view_req.batch_repeat_value) +
+                                          view_req.shift_to + 1 + obs_shift]
+                        for i in range(count)
+                    ])
+                else:
+                    data = np_data[data_col][self.shift_before +
+                                             view_req.shift_from +
+                                             obs_shift:self.shift_before +
+                                             view_req.shift_to + 1 + obs_shift]
+            # Set of (probably non-consecutive) indices.
+            # Example:
+            #  shift=[-3, 0]
+            #  buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+            #  resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
+            elif isinstance(view_req.shift, np.ndarray):
+                data = np_data[data_col][self.shift_before + obs_shift +
+                                         view_req.shift]
+            # Single shift int value. Use the trajectory as-is, and if
+            # `shift` != 0: shifted by that value.
+            else:
+                shift = view_req.shift + obs_shift
+
+                # Batch repeat (only provide a value every n timesteps).
+                if view_req.batch_repeat_value > 1:
+                    count = int(
+                        math.ceil((len(np_data[data_col]) - self.shift_before)
+                                  / view_req.batch_repeat_value))
+                    data = np.asarray([
+                        np_data[data_col][self.shift_before + (
+                            i * view_req.batch_repeat_value) + shift]
+                        for i in range(count)
+                    ])
+                # Shift is exactly 0: Use trajectory as is.
+                elif shift == 0:
+                    data = np_data[data_col][self.shift_before:]
+                # Shift is positive: We still need to 0-pad at the end.
+                elif shift > 0:
+                    data = to_float_np_array(
+                        self.buffers[data_col][self.shift_before + shift:] + [
+                            np.zeros(
+                                shape=view_req.space.shape,
+                                dtype=view_req.space.dtype)
+                            for _ in range(shift)
+                        ])
+                # Shift is negative: Shift into the already existing and
+                # 0-padded "before" area of our buffers.
+                else:
+                    data = np_data[data_col][self.shift_before + shift:shift]
 
             if len(data) > 0:
                 batch_data[view_col] = data
 
-        batch = SampleBatch(batch_data)
+        # Due to possible batch-repeats > 1, columns in the resulting batch
+        # may not all have the same batch size.
+        batch = SampleBatch(batch_data, _dont_check_lens=True)
 
         # Add EPS_ID and UNROLL_ID to batch.
         batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id,
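
The indexing example in the comment above can be reproduced with plain numpy (assumed values mirror the comment; the buffer's first `shift_before=3` entries are pre-episode padding):

    import math
    import numpy as np

    buf = np.arange(-3, 13)  # [-3 ... 12]; entries -3..-1 are "before" data
    shift_before, repeat, shift_from, shift_to = 3, 10, -3, -1

    count = int(math.ceil((len(buf) - shift_before) / repeat))
    data = np.asarray([
        buf[shift_before + i * repeat + shift_from:
            shift_before + i * repeat + shift_to + 1] for i in range(count)
    ])
    print(data)  # [[-3 -2 -1], [7 8 9]] -- one 3-item range every 10 timesteps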
@@ -230,15 +300,22 @@ class _PolicyCollector:
         appended to this policy's buffers.
     """
 
-    def __init__(self):
-        """Initializes a _PolicyCollector instance."""
+    def __init__(self, policy):
+        """Initializes a _PolicyCollector instance.
+
+        Args:
+            policy (Policy): The policy object.
+        """
 
         self.buffers: Dict[str, List] = collections.defaultdict(list)
+        self.policy = policy
         # The total timestep count for all agents that use this policy.
         # NOTE: This is not an env-step count (across n agents). AgentA and
         # agentB, both using this policy, acting in the same episode and both
         # doing n steps would increase the count by 2*n.
         self.agent_steps = 0
+        # Seq-lens list of already added agent batches.
+        self.seq_lens = [] if policy.is_recurrent() else None
 
     def add_postprocessed_batch_for_training(
             self, batch: SampleBatch,
@@ -257,11 +334,18 @@ class _PolicyCollector:
             # 1) If col is not in view_requirements, we must have a direct
             # child of the base Policy that doesn't do auto-view req creation.
             # 2) Col is in view-reqs and needed for training.
-            if view_col not in view_requirements or \
-                    view_requirements[view_col].used_for_training:
+            view_req = view_requirements.get(view_col)
+            if view_req is None or view_req.used_for_training:
                 self.buffers[view_col].extend(data)
         # Add the agent's trajectory length to our count.
         self.agent_steps += batch.count
+        # Adjust the seq-lens array depending on the incoming agent sequences.
+        if self.seq_lens is not None:
+            max_seq_len = self.policy.config["model"]["max_seq_len"]
+            count = batch.count
+            while count > 0:
+                self.seq_lens.append(min(count, max_seq_len))
+                count -= max_seq_len
 
     def build(self):
         """Builds a SampleBatch for this policy from the collected data.
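
The seq-lens bookkeeping above reduces to a simple chopping rule; a self-contained sketch (the `max_seq_len` value here is made up):

    def chop_seq_lens(count: int, max_seq_len: int):
        # Split an agent batch of `count` steps into max_seq_len-sized pieces.
        seq_lens = []
        while count > 0:
            seq_lens.append(min(count, max_seq_len))
            count -= max_seq_len
        return seq_lens

    print(chop_seq_lens(23, 10))  # [10, 10, 3]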
@@ -273,20 +357,22 @@ class _PolicyCollector:
             this policy.
         """
         # Create batch from our buffers.
-        batch = SampleBatch(self.buffers)
-        assert SampleBatch.UNROLL_ID in batch.data
+        batch = SampleBatch(
+            self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
         # Clear buffers for future samples.
         self.buffers.clear()
-        # Reset agent steps to 0.
+        # Reset agent steps to 0 and seq-lens to empty list.
         self.agent_steps = 0
+        if self.seq_lens is not None:
+            self.seq_lens = []
         return batch
 
 
 class _PolicyCollectorGroup:
     def __init__(self, policy_map):
         self.policy_collectors = {
-            pid: _PolicyCollector()
-            for pid in policy_map.keys()
+            pid: _PolicyCollector(policy)
+            for pid, policy in policy_map.items()
         }
+        # Total env-steps (1 env-step=up to N agents stepped).
+        self.env_steps = 0
@@ -396,11 +482,14 @@ class _SimpleListCollector(_SampleCollector):
             self.agent_key_to_policy_id[agent_key] = policy_id
         else:
             assert self.agent_key_to_policy_id[agent_key] == policy_id
+        policy = self.policy_map[policy_id]
+        view_reqs = policy.model.inference_view_requirements if \
+            getattr(policy, "model", None) else policy.view_requirements
 
         # Add initial obs to Trajectory.
         assert agent_key not in self.agent_collectors
-        # TODO: determine exact shift-before based on the view-req shifts.
-        self.agent_collectors[agent_key] = _AgentCollector()
+        self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
         self.agent_collectors[agent_key].add_init_obs(
             episode_id=episode.episode_id,
             agent_index=episode._agent_index(agent_id),
@@ -466,11 +555,19 @@ class _SimpleListCollector(_SampleCollector):
         for view_col, view_req in view_reqs.items():
             # Create the batch of data from the different buffers.
             data_col = view_req.data_col or view_col
-            time_indices = \
-                view_req.shift - (
-                    1 if data_col in [SampleBatch.OBS, "t", "env_id",
-                                      SampleBatch.AGENT_INDEX] else 0)
+            delta = -1 if data_col in [
+                SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID,
+                SampleBatch.AGENT_INDEX
+            ] else 0
+            # Range of shifts, e.g. "-100:0". Note: This includes index 0!
+            if view_req.shift_from is not None:
+                time_indices = (view_req.shift_from + delta,
+                                view_req.shift_to + delta)
+            # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
+            else:
+                time_indices = view_req.shift + delta
             data_list = []
             # Loop through agents and add-up their data (batch).
             for k in keys:
                 if data_col == SampleBatch.EPS_ID:
                     data_list.append(self.agent_collectors[k].episode_id)
@@ -482,7 +579,15 @@ class _SimpleListCollector(_SampleCollector):
                     self.agent_collectors[k]._build_buffers({
                         data_col: fill_value
                     })
-                data_list.append(buffers[k][data_col][time_indices])
+                if isinstance(time_indices, tuple):
+                    if time_indices[1] == -1:
+                        data_list.append(
+                            buffers[k][data_col][time_indices[0]:])
+                    else:
+                        data_list.append(buffers[k][data_col][time_indices[
+                            0]:time_indices[1] + 1])
+                else:
+                    data_list.append(buffers[k][data_col][time_indices])
             input_dict[view_col] = np.array(data_list)
 
         self._reset_inference_calls(policy_id)
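
The tuple branch above exists because a `(from, to)` pair with `to == -1` must mean "to the very end": a naive `buf[from:to + 1]` would become `buf[from:0]` and return nothing. A small numpy illustration with a hypothetical buffer:

    import numpy as np

    buf = np.arange(10)

    def slice_time(buf, time_indices):
        if isinstance(time_indices, tuple):
            lo, hi = time_indices
            return buf[lo:] if hi == -1 else buf[lo:hi + 1]
        return buf[time_indices]

    print(slice_time(buf, (-3, -1)))  # [7 8 9]
    print(slice_time(buf, -1))        # 9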
@@ -50,8 +50,6 @@ def compute_advantages(rollout: SampleBatch,
         processed rewards.
     """
 
-    rollout_size = len(rollout[SampleBatch.ACTIONS])
-
     assert SampleBatch.VF_PREDS in rollout or not use_critic, \
         "use_critic=True but values not found"
     assert use_critic or not use_gae, \

@@ -90,6 +88,4 @@ def compute_advantages(rollout: SampleBatch,
     rollout[Postprocessing.ADVANTAGES] = rollout[
         Postprocessing.ADVANTAGES].astype(np.float32)
 
-    assert all(val.shape[0] == rollout_size for key, val in rollout.items()), \
-        "Rollout stacked incorrectly!"
     return rollout
@@ -39,6 +39,7 @@ if __name__ == "__main__":
 
     config = {
         "env": args.env,
+        # This env_config is only used for the RepeatAfterMeEnv env.
         "env_config": {
             "repeat_delay": 2,
         },

@@ -48,7 +49,7 @@ if __name__ == "__main__":
         "num_workers": 0,
         "num_envs_per_worker": 20,
         "entropy_coeff": 0.001,
-        "num_sgd_iter": 5,
+        "num_sgd_iter": 10,
         "vf_loss_coeff": 1e-5,
         "model": {
             "custom_model": GTrXLNet,

@@ -56,9 +57,10 @@ if __name__ == "__main__":
             "custom_model_config": {
                 "num_transformer_units": 1,
                 "attn_dim": 64,
-                "num_heads": 2,
-                "memory_tau": 50,
+                "memory_inference": 100,
+                "memory_training": 50,
                 "head_dim": 32,
+                "num_heads": 2,
                 "ff_hidden_dim": 32,
             },
         },

@@ -71,7 +73,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
 
-    results = tune.run(args.run, config=config, stop=stop, verbose=1)
+    results = tune.run(args.run, config=config, stop=stop, verbose=2)
 
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)
@@ -59,7 +59,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
 
-    results = tune.run(args.run, config=config, stop=stop, verbose=1)
+    results = tune.run(args.run, config=config, stop=stop, verbose=2)
 
     if args.as_test:
         check_learning_achieved(results, args.stop_reward)
@@ -13,7 +13,8 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
 from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
     TensorType
 from ray.rllib.utils.spaces.repeated import Repeated
-from ray.rllib.utils.typing import ModelConfigDict, TensorStructType
+from ray.rllib.utils.typing import ModelConfigDict, ModelInputDict, \
+    TensorStructType
 
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()

@@ -238,14 +239,14 @@ class ModelV2:
             right input dict, state, and seq len arguments.
         """
 
-        train_batch["is_training"] = is_training
+        input_dict = train_batch.copy()
+        input_dict["is_training"] = is_training
         states = []
         i = 0
-        while "state_in_{}".format(i) in train_batch:
-            states.append(train_batch["state_in_{}".format(i)])
+        while "state_in_{}".format(i) in input_dict:
+            states.append(input_dict["state_in_{}".format(i)])
             i += 1
-        ret = self.__call__(train_batch, states, train_batch.get("seq_lens"))
-        del train_batch["is_training"]
+        ret = self.__call__(input_dict, states, input_dict.get("seq_lens"))
         return ret
 
     def import_from_h5(self, h5_file: str) -> None:
@@ -316,21 +317,57 @@ class ModelV2:
 
     # TODO: (sven) Experimental method.
     def get_input_dict(self, sample_batch,
-                       index: int = -1) -> Dict[str, TensorType]:
-        if index < 0:
-            index = sample_batch.count - 1
+                       index: Union[int, str] = "last") -> ModelInputDict:
+        """Creates single ts input-dict at given index from a SampleBatch.
+
+        Args:
+            sample_batch (SampleBatch): A single-trajectory SampleBatch object
+                to generate the compute_actions input dict from.
+            index (Union[int, str]): An integer index value indicating the
+                position in the trajectory for which to generate the
+                compute_actions input dict. Set to "last" to generate the dict
+                at the very end of the trajectory (e.g. for value estimation).
+                Note that "last" is different from -1, as "last" will use the
+                final NEXT_OBS as observation input.
+
+        Returns:
+            ModelInputDict: The (single-timestep) input dict for ModelV2 calls.
+        """
+        last_mappings = {
+            SampleBatch.OBS: SampleBatch.NEXT_OBS,
+            SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
+            SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
+        }
 
         input_dict = {}
         for view_col, view_req in self.inference_view_requirements.items():
             # Create batches of size 1 (single-agent input-dict).
-
-            # Index range.
-            if isinstance(index, tuple):
-                data = sample_batch[view_col][index[0]:index[1] + 1]
-                input_dict[view_col] = np.array([data])
-            # Single index.
-            else:
-                input_dict[view_col] = sample_batch[view_col][index:index + 1]
+            data_col = view_req.data_col or view_col
+            if index == "last":
+                data_col = last_mappings.get(data_col, data_col)
+                if view_req.shift_from is not None:
+                    data = sample_batch[view_col][-1]
+                    traj_len = len(sample_batch[data_col])
+                    missing_at_end = traj_len % view_req.batch_repeat_value
+                    input_dict[view_col] = np.array([
+                        np.concatenate([
+                            data, sample_batch[data_col][-missing_at_end:]
+                        ])[view_req.shift_from:view_req.shift_to +
+                           1 if view_req.shift_to != -1 else None]
+                    ])
+                else:
+                    data = sample_batch[data_col][-1]
+                    input_dict[view_col] = np.array([data])
+            else:
+                # Index range.
+                if isinstance(index, tuple):
+                    data = sample_batch[data_col][index[0]:index[1] + 1
+                                                  if index[1] != -1 else None]
+                    input_dict[view_col] = np.array([data])
+                # Single index.
+                else:
+                    input_dict[view_col] = sample_batch[data_col][
+                        index:index + 1 if index != -1 else None]
 
         # Add valid `seq_lens`, just in case RNNs need it.
         input_dict["seq_lens"] = np.array([1], dtype=np.int32)
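
The "last" semantics documented above can be shown with a toy batch (plain numpy; "new_obs" stands in for SampleBatch.NEXT_OBS here):

    import numpy as np

    batch = {"obs": np.array([0, 1, 2]), "new_obs": np.array([1, 2, 3])}
    last_mappings = {"obs": "new_obs"}

    def single_ts(batch, col, index):
        if index == "last":
            col = last_mappings.get(col, col)  # swap obs -> next obs
            return np.array([batch[col][-1]])
        return batch[col][index:index + 1 if index != -1 else None]

    print(single_ts(batch, "obs", -1))      # [2]: the final observation
    print(single_ts(batch, "obs", "last"))  # [3]: the obs *after* the final step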
@@ -1,263 +0,0 @@
-import gym
-import numpy as np
-import unittest
-
-from ray.rllib.models.tf.attention_net import relative_position_embedding, \
-    GTrXLNet
-from ray.rllib.models.tf.layers import MultiHeadAttention
-from ray.rllib.models.torch.attention_net import relative_position_embedding \
-    as relative_position_embedding_torch, GTrXLNet as TorchGTrXLNet
-from ray.rllib.models.torch.modules.multi_head_attention import \
-    MultiHeadAttention as TorchMultiHeadAttention
-from ray.rllib.utils.framework import try_import_torch, try_import_tf
-from ray.rllib.utils.test_utils import framework_iterator
-
-torch, nn = try_import_torch()
-tf1, tf, tfv = try_import_tf()
-
-
-class TestAttentionNets(unittest.TestCase):
-    """Tests various torch/modules and tf/layers required for AttentionNet"""
-
-    def train_torch_full_model(self,
-                               model,
-                               inputs,
-                               outputs,
-                               num_epochs=250,
-                               state=None,
-                               seq_lens=None):
-        """Convenience method that trains a Torch model for num_epochs epochs
-        and tests whether loss decreased, as expected.
-
-        Args:
-            model (nn.Module): Torch model to be trained.
-            inputs (torch.Tensor): Training data
-            outputs (torch.Tensor): Training labels
-            num_epochs (int): Number of epochs to train for
-            state (torch.Tensor): Internal state of module
-            seq_lens (torch.Tensor): Tensor of sequence lengths
-        """
-
-        criterion = torch.nn.MSELoss(reduction="sum")
-        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
-
-        # Check that the layer trains correctly
-        for t in range(num_epochs):
-            y_pred = model(inputs, state, seq_lens)
-            loss = criterion(y_pred[0], torch.squeeze(outputs[0]))
-
-            if t % 10 == 1:
-                print(t, loss.item())
-
-            # if t == 0:
-            #     init_loss = loss.item()
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-        # final_loss = loss.item()
-
-        # The final loss has decreased, which tests
-        # that the model is learning from the training data.
-        # self.assertLess(final_loss / init_loss, 0.99)
-
-    def train_torch_layer(self, model, inputs, outputs, num_epochs=250):
-        """Convenience method that trains a Torch model for num_epochs epochs
-        and tests whether loss decreased, as expected.
-
-        Args:
-            model (nn.Module): Torch model to be trained.
-            inputs (torch.Tensor): Training data
-            outputs (torch.Tensor): Training labels
-            num_epochs (int): Number of epochs to train for
-        """
-        criterion = torch.nn.MSELoss(reduction="sum")
-        optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
-
-        # Check that the layer trains correctly
-        for t in range(num_epochs):
-            y_pred = model(inputs)
-            loss = criterion(y_pred, outputs)
-
-            if t == 1:
-                init_loss = loss.item()
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-        final_loss = loss.item()
-
-        # The final loss has decreased by a factor of 2, which tests
-        # that the model is learning from the training data.
-        self.assertLess(final_loss / init_loss, 0.5)
-
-    def train_tf_model(self,
-                       model,
-                       inputs,
-                       outputs,
-                       num_epochs=250,
-                       minibatch_size=32):
-        """Convenience method that trains a Tensorflow model for num_epochs
-        epochs and tests whether loss decreased, as expected.
-
-        Args:
-            model (tf.Model): Torch model to be trained.
-            inputs (np.array): Training data
-            outputs (np.array): Training labels
-            num_epochs (int): Number of training epochs
-            batch_size (int): Number of samples in each minibatch
-        """
-
-        # Configure a model for mean-squared error loss.
-        model.compile(optimizer="SGD", loss="mse", metrics=["mae"])
-
-        hist = model.fit(
-            inputs,
-            outputs,
-            verbose=0,
-            epochs=num_epochs,
-            batch_size=minibatch_size).history
-        init_loss = hist["loss"][0]
-        final_loss = hist["loss"][-1]
-
-        self.assertLess(final_loss / init_loss, 0.5)
-
-    def test_multi_head_attention(self):
-        """Tests the MultiHeadAttention mechanism of Vaswani et al."""
-        # B is batch size
-        B = 1
-        # D_in is attention dim, L is memory_tau
-        L, D_in, D_out = 2, 32, 10
-
-        for fw, sess in framework_iterator(
-                frameworks=("tfe", "torch", "tf"), session=True):
-            # Create a single attention layer with 2 heads.
-            if fw == "torch":
-
-                # Create random Tensors to hold inputs and outputs
-                x = torch.randn(B, L, D_in)
-                y = torch.randn(B, L, D_out)
-
-                model = TorchMultiHeadAttention(
-                    in_dim=D_in, out_dim=D_out, num_heads=2, head_dim=32)
-
-                self.train_torch_layer(model, x, y, num_epochs=500)
-
-            # Framework is tensorflow or tensorflow-eager.
-            else:
-                x = np.random.random((B, L, D_in))
-                y = np.random.random((B, L, D_out))
-
-                inputs = tf.keras.layers.Input(shape=(L, D_in))
-
-                model = tf.keras.Sequential([
-                    inputs,
-                    MultiHeadAttention(
-                        out_dim=D_out, num_heads=2, head_dim=32)
-                ])
-                self.train_tf_model(model, x, y)
-
-    def test_attention_net(self):
-        """Tests the GTrXL.
-
-        Builds a full AttentionNet and checks that it trains in a supervised
-        setting."""
-
-        # Checks that torch and tf embedding matrices are the same
-        with tf1.Session().as_default() as sess:
-            assert np.allclose(
-                relative_position_embedding(20, 15).eval(session=sess),
-                relative_position_embedding_torch(20, 15).numpy())
-
-        # B is batch size
-        B = 32
-        # D_in is attention dim, L is memory_tau
-        L, D_in, D_out = 2, 16, 2
-
-        for fw, sess in framework_iterator(session=True):
-
-            # Create a single attention layer with 2 heads
-            if fw == "torch":
-                # Create random Tensors to hold inputs and outputs
-                x = torch.randn(B, L, D_in)
-                y = torch.randn(B, L, D_out)
-
-                value_labels = torch.randn(B, L, D_in)
-                memory_labels = torch.randn(B, L, D_out)
-
-                attention_net = TorchGTrXLNet(
-                    observation_space=gym.spaces.Box(
-                        low=float("-inf"), high=float("inf"), shape=(D_in, )),
-                    action_space=gym.spaces.Discrete(D_out),
-                    num_outputs=D_out,
-                    model_config={"max_seq_len": 2},
-                    name="TestTorchAttentionNet",
-                    num_transformer_units=2,
-                    attn_dim=D_in,
-                    num_heads=2,
-                    memory_tau=L,
-                    head_dim=D_out,
-                    ff_hidden_dim=16,
-                    init_gate_bias=2.0)
-
-                init_state = attention_net.get_initial_state()
-
-                # Get initial state and add a batch dimension.
-                init_state = [np.expand_dims(s, 0) for s in init_state]
-                seq_lens_init = torch.full(
-                    size=(B, ), fill_value=L, dtype=torch.int32)
-
-                # Torch implementation expects a formatted input_dict instead
-                # of a numpy array as input.
-                input_dict = {"obs": x}
-                self.train_torch_full_model(
-                    attention_net,
-                    input_dict, [y, value_labels, memory_labels],
-                    num_epochs=250,
-                    state=init_state,
-                    seq_lens=seq_lens_init)
-            # Framework is tensorflow or tensorflow-eager.
-            else:
-                x = np.random.random((B, L, D_in))
-                y = np.random.random((B, L, D_out))
-
-                value_labels = np.random.random((B, L, 1))
-                memory_labels = np.random.random((B, L, D_in))
-
-                # We need to create (N-1) MLP labels for N transformer units
-                mlp_labels = np.random.random((B, L, D_in))
-
-                attention_net = GTrXLNet(
-                    observation_space=gym.spaces.Box(
-                        low=float("-inf"), high=float("inf"), shape=(D_in, )),
-                    action_space=gym.spaces.Discrete(D_out),
-                    num_outputs=D_out,
-                    model_config={"max_seq_len": 2},
-                    name="TestTFAttentionNet",
-                    num_transformer_units=2,
-                    attn_dim=D_in,
-                    num_heads=2,
-                    memory_tau=L,
-                    head_dim=D_out,
-                    ff_hidden_dim=16,
-                    init_gate_bias=2.0)
-                model = attention_net.trxl_model
-
-                # Get initial state and add a batch dimension.
-                init_state = attention_net.get_initial_state()
-                init_state = [np.tile(s, (B, 1, 1)) for s in init_state]
-
-                self.train_tf_model(
-                    model, [x] + init_state,
-                    [y, value_labels, memory_labels, mlp_labels],
-                    num_epochs=200,
-                    minibatch_size=B)
-
-
-if __name__ == "__main__":
-    import pytest
-    import sys
-
-    sys.exit(pytest.main(["-v", __file__]))
@@ -8,14 +8,17 @@
     Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019.
     https://www.aclweb.org/anthology/P19-1285.pdf
 """
+from gym.spaces import Box
 import numpy as np
 import gym
-from typing import Optional, Any
+from typing import Any, Optional
 
+from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
     SkipConnection
 from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
@@ -60,7 +63,7 @@ class TrXLNet(RecurrentNetwork):
                  model_config: ModelConfigDict, name: str,
                  num_transformer_units: int, attn_dim: int, num_heads: int,
                  head_dim: int, ff_hidden_dim: int):
-        """Initializes a TfXLNet object.
+        """Initializes a TrXLNet object.
 
         Args:
             num_transformer_units (int): The number of Transformer repeats to
@@ -88,8 +91,6 @@ class TrXLNet(RecurrentNetwork):
         self.max_seq_len = model_config["max_seq_len"]
         self.obs_dim = observation_space.shape[0]
 
-        pos_embedding = relative_position_embedding(self.max_seq_len, attn_dim)
-
         inputs = tf.keras.layers.Input(
             shape=(self.max_seq_len, self.obs_dim), name="inputs")
         E_out = tf.keras.layers.Dense(attn_dim)(inputs)
@@ -100,7 +101,6 @@ class TrXLNet(RecurrentNetwork):
                     out_dim=attn_dim,
                     num_heads=num_heads,
                     head_dim=head_dim,
-                    rel_pos_encoder=pos_embedding,
                     input_layernorm=False,
                     output_activation=None),
                 fan_in_layer=None)(E_out)
@@ -160,7 +160,8 @@ class GTrXLNet(RecurrentNetwork):
        >> num_transformer_units=1,
        >> attn_dim=32,
        >> num_heads=2,
-       >> memory_tau=50,
+       >> memory_inference=100,
+       >> memory_training=50,
        >> etc..
        >> }
    """
@@ -174,11 +175,12 @@ class GTrXLNet(RecurrentNetwork):
                  num_transformer_units: int,
                  attn_dim: int,
                  num_heads: int,
-                 memory_tau: int,
+                 memory_inference: int,
+                 memory_training: int,
                  head_dim: int,
                  ff_hidden_dim: int,
                  init_gate_bias: float = 2.0):
-        """Initializes a GTrXLNet.
+        """Initializes a GTrXLNet instance.
 
         Args:
             num_transformer_units (int): The number of Transformer repeats to
@@ -187,9 +189,15 @@ class GTrXLNet(RecurrentNetwork):
                 unit.
             num_heads (int): The number of attention heads to use in parallel.
                 Denoted as `H` in [3].
-            memory_tau (int): The number of timesteps to store in each
-                transformer block's memory M (concat'd over time and fed into
-                next transformer block as input).
+            memory_inference (int): The number of timesteps to concat (time
+                axis) and feed into the next transformer unit as inference
+                input. The first transformer unit will receive this number of
+                past observations (plus the current one), instead.
+            memory_training (int): The number of timesteps to concat (time
+                axis) and feed into the next transformer unit as training
+                input (plus the actual input sequence of len=max_seq_len).
+                The first transformer unit will receive this number of
+                past observations (plus the input sequence), instead.
             head_dim (int): The dimension of a single(!) head.
                 Denoted as `d` in [3].
             ff_hidden_dim (int): The dimension of the hidden layer within
@@ -208,21 +216,18 @@ class GTrXLNet(RecurrentNetwork):
         self.num_transformer_units = num_transformer_units
         self.attn_dim = attn_dim
         self.num_heads = num_heads
-        self.memory_tau = memory_tau
+        self.memory_inference = memory_inference
+        self.memory_training = memory_training
         self.head_dim = head_dim
         self.max_seq_len = model_config["max_seq_len"]
         self.obs_dim = observation_space.shape[0]
 
-        # Constant (non-trainable) sinusoid rel pos encoding matrix.
-        Phi = relative_position_embedding(self.max_seq_len + self.memory_tau,
-                                          self.attn_dim)
-
-        # Raw observation input.
+        # Raw observation input (plus (None) time axis).
         input_layer = tf.keras.layers.Input(
-            shape=(self.max_seq_len, self.obs_dim), name="inputs")
+            shape=(None, self.obs_dim), name="inputs")
         memory_ins = [
             tf.keras.layers.Input(
-                shape=(self.memory_tau, self.attn_dim),
+                shape=(None, self.attn_dim),
                 dtype=tf.float32,
                 name="memory_in_{}".format(i))
             for i in range(self.num_transformer_units)
@@ -242,7 +247,6 @@ class GTrXLNet(RecurrentNetwork):
                     out_dim=self.attn_dim,
                     num_heads=num_heads,
                     head_dim=head_dim,
-                    rel_pos_encoder=Phi,
                     input_layernorm=True,
                     output_activation=tf.nn.relu),
                 fan_in_layer=GRUGate(init_gate_bias),
@@ -280,69 +284,52 @@ class GTrXLNet(RecurrentNetwork):
         self.register_variables(self.trxl_model.variables)
         self.trxl_model.summary()
 
-    @override(RecurrentNetwork)
-    def forward_rnn(self, inputs: TensorType, state: List[TensorType],
-                    seq_lens: TensorType) -> (TensorType, List[TensorType]):
-        # To make Attention work with current RLlib's ModelV2 API:
-        # We assume `state` is the history of L recent observations (all
-        # concatenated into one tensor) and append the current inputs to the
-        # end and only keep the most recent (up to `max_seq_len`). This allows
-        # us to deal with timestep-wise inference and full sequence training
-        # within the same logic.
-        observations = state[0]
-        memory = state[1:]
-
-        observations = tf.concat(
-            (observations, inputs), axis=1)[:, -self.max_seq_len:]
-        all_out = self.trxl_model([observations] + memory)
-        logits, self._value_out = all_out[0], all_out[1]
-        memory_outs = all_out[2:]
-        # If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
-        if self.memory_tau > self.max_seq_len:
-            memory_outs = [
-                tf.concat(
-                    [memory[i][:, -(self.memory_tau - self.max_seq_len):], m],
-                    axis=1) for i, m in enumerate(memory_outs)
-            ]
-        else:
-            memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]
-
-        T = tf.shape(inputs)[1]  # Length of input segment (time).
-        logits = logits[:, -T:]
-        self._value_out = self._value_out[:, -T:]
-
-        return logits, [observations] + memory_outs
+        # Setup inference view (`memory-inference` x past observations +
+        # current one (0)).
+        # 1 to `num_transformer_units`: Memory data (one per transformer unit).
+        for i in range(self.num_transformer_units):
+            space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
+            self.inference_view_requirements["state_in_{}".format(i)] = \
+                ViewRequirement(
+                    "state_out_{}".format(i),
+                    shift="-{}:-1".format(self.memory_inference),
+                    # Repeat the incoming state every max-seq-len times.
+                    batch_repeat_value=self.max_seq_len,
+                    space=space)
+            self.inference_view_requirements["state_out_{}".format(i)] = \
+                ViewRequirement(
+                    space=space,
+                    used_for_training=False)
+
+    @override(ModelV2)
+    def forward(self, input_dict, state: List[TensorType],
+                seq_lens: TensorType) -> (TensorType, List[TensorType]):
+        assert seq_lens is not None
+
+        # Add the time dim to observations.
+        B = tf.shape(seq_lens)[0]
+        observations = input_dict[SampleBatch.OBS]
+
+        shape = tf.shape(observations)
+        T = shape[0] // B
+        observations = tf.reshape(observations,
+                                  tf.concat([[-1, T], shape[1:]], axis=0))
+
+        all_out = self.trxl_model([observations] + state)
+
+        logits = all_out[0]
+        self._value_out = all_out[1]
+        memory_outs = all_out[2:]
+
+        return tf.reshape(logits, [-1, self.num_outputs]), [
+            tf.reshape(m, [-1, self.attn_dim]) for m in memory_outs
+        ]
 
     # TODO: (sven) Deprecate this once trajectory view API has fully matured.
     @override(RecurrentNetwork)
     def get_initial_state(self) -> List[np.ndarray]:
-        # State is the T last observations concat'd together into one Tensor.
-        # Plus all Transformer blocks' E(l) outputs concat'd together (up to
-        # tau timesteps).
-        return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \
-            [np.zeros((self.memory_tau, self.attn_dim), np.float32)
-             for _ in range(self.num_transformer_units)]
+        return []
 
     @override(ModelV2)
     def value_function(self) -> TensorType:
         return tf.reshape(self._value_out, [-1])
 
 
 def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType:
     """Creates a [seq_length x out_dim] matrix for rel. pos encoding.
 
     Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
     matrix.
 
     Args:
         seq_length (int): The max. sequence length (time axis).
         out_dim (int): The number of nodes to go into the first Transformer
             layer with.
 
     Returns:
         tf.Tensor: The encoding matrix Phi.
     """
     inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
     pos_offsets = tf.range(seq_length - 1., -1., -1.)
     inputs = pos_offsets[:, None] * inverse_freq[None, :]
     return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)
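
For reference, the sinusoid matrix built by `relative_position_embedding` can be mirrored in plain numpy (a sketch to show shape and structure, not a drop-in replacement for the tf version):

    import numpy as np

    def rel_pos_embedding_np(seq_length, out_dim):
        inverse_freq = 1 / (10000 ** (np.arange(0, out_dim, 2.0) / out_dim))
        pos_offsets = np.arange(seq_length - 1.0, -1.0, -1.0)
        inputs = pos_offsets[:, None] * inverse_freq[None, :]
        return np.concatenate((np.sin(inputs), np.cos(inputs)), axis=-1)

    print(rel_pos_embedding_np(5, 4).shape)  # (5, 4): seq_length x out_dim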
@@ -1,11 +1,11 @@
 from ray.rllib.models.tf.layers.gru_gate import GRUGate
 from ray.rllib.models.tf.layers.noisy_layer import NoisyLayer
 from ray.rllib.models.tf.layers.relative_multi_head_attention import \
-    RelativeMultiHeadAttention
+    PositionalEmbedding, RelativeMultiHeadAttention
 from ray.rllib.models.tf.layers.skip_connection import SkipConnection
 from ray.rllib.models.tf.layers.multi_head_attention import MultiHeadAttention
 
 __all__ = [
-    "GRUGate", "MultiHeadAttention", "NoisyLayer",
+    "GRUGate", "MultiHeadAttention", "NoisyLayer", "PositionalEmbedding",
     "RelativeMultiHeadAttention", "SkipConnection"
 ]
@@ -1,4 +1,4 @@
-from typing import Optional, Any
+from typing import Optional
 
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.typing import TensorType
@@ -16,9 +16,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
                  out_dim: int,
                  num_heads: int,
                  head_dim: int,
-                 rel_pos_encoder: Any,
                  input_layernorm: bool = False,
-                 output_activation: Optional[Any] = None,
+                 output_activation: Optional["tf.nn.activation"] = None,
                  **kwargs):
         """Initializes a RelativeMultiHeadAttention keras Layer object.
 
@@ -28,7 +27,6 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
                 Denoted `H` in [2].
             head_dim (int): The dimension of a single(!) attention head
                 Denoted `D` in [2].
-            rel_pos_encoder (:
             input_layernorm (bool): Whether to prepend a LayerNorm before
                 everything else. Should be True for building a GTrXL.
             output_activation (Optional[tf.nn.activation]): Optional tf.nn
@@ -50,9 +48,14 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
         self._uvar = self.add_weight(shape=(num_heads, head_dim))
         self._vvar = self.add_weight(shape=(num_heads, head_dim))
 
+        # Constant (non-trainable) sinusoid rel pos encoding matrix, which
+        # depends on this incoming time dimension.
+        # For inference, we prepend the memory to the current timestep's
+        # input: Tau + 1. For training, we prepend the memory to the input
+        # sequence: Tau + T.
+        self._pos_embedding = PositionalEmbedding(out_dim)
         self._pos_proj = tf.keras.layers.Dense(
             num_heads * head_dim, use_bias=False)
-        self._rel_pos_encoder = rel_pos_encoder
 
         self._input_layernorm = None
         if input_layernorm:
@@ -66,9 +69,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
 
         # Add previous memory chunk (as const, w/o gradient) to input.
         # Tau (number of (prev) time slices in each memory chunk).
-        Tau = memory.shape.as_list()[1] if memory is not None else 0
         if memory is not None:
-            inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1)
+            Tau = tf.shape(memory)[1]
+            inputs = tf.concat([tf.stop_gradient(memory), inputs], axis=1)
 
         # Apply the Layer-Norm.
         if self._input_layernorm is not None:
@@ -77,15 +79,17 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
         qkv = self._qkv_layer(inputs)
 
         queries, keys, values = tf.split(qkv, 3, -1)
-        # Cut out Tau memory timesteps from query.
+        # Cut out memory timesteps from query.
         queries = queries[:, -T:]
 
         # Splitting up queries into per-head dims (d).
         queries = tf.reshape(queries, [-1, T, H, d])
-        keys = tf.reshape(keys, [-1, T + Tau, H, d])
-        values = tf.reshape(values, [-1, T + Tau, H, d])
+        keys = tf.reshape(keys, [-1, Tau + T, H, d])
+        values = tf.reshape(values, [-1, Tau + T, H, d])
 
-        R = self._pos_proj(self._rel_pos_encoder)
-        R = tf.reshape(R, [T + Tau, H, d])
+        R = self._pos_embedding(Tau + T)
+        R = self._pos_proj(R)
+        R = tf.reshape(R, [Tau + T, H, d])
 
         # b=batch
         # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
@@ -96,9 +100,9 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
         score = score + self.rel_shift(pos_score)
         score = score / d**0.5
 
-        # causal mask of the same length as the sequence
+        # Causal mask of the same length as the sequence.
         mask = tf.sequence_mask(
-            tf.range(Tau + 1, T + Tau + 1), dtype=score.dtype)
+            tf.range(Tau + 1, Tau + T + 1), dtype=score.dtype)
         mask = mask[None, :, :, None]
 
         masked_score = score * mask + 1e30 * (mask - 1.)
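
The mask construction above is easiest to see on concrete numbers. With assumed sizes Tau=2 (memory) and T=3 (input), each query row may attend to all memory slots plus inputs up to and including itself:

    import tensorflow as tf

    Tau, T = 2, 3
    mask = tf.sequence_mask(tf.range(Tau + 1, Tau + T + 1), dtype=tf.float32)
    print(mask.numpy())
    # [[1. 1. 1. 0. 0.]
    #  [1. 1. 1. 1. 0.]
    #  [1. 1. 1. 1. 1.]]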
@@ -121,3 +125,14 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
         x = tf.reshape(x, x_size)
 
         return x
+
+
+class PositionalEmbedding(tf.keras.layers.Layer if tf else object):
+    def __init__(self, out_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
+
+    def call(self, seq_length):
+        pos_offsets = tf.cast(tf.range(seq_length - 1, -1, -1), tf.float32)
+        inputs = pos_offsets[:, None] * self.inverse_freq[None, :]
+        return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)
@@ -16,7 +16,6 @@ class SkipConnection(tf.keras.layers.Layer if tf else object):
     def __init__(self,
                  layer: Any,
                  fan_in_layer: Optional[Any] = None,
-                 add_memory: bool = False,
                  **kwargs):
         """Initializes a SkipConnection keras layer object.
 
@@ -15,7 +15,6 @@ class SkipConnection(nn.Module):
     def __init__(self,
                  layer: nn.Module,
                  fan_in_layer: Optional[nn.Module] = None,
-                 add_memory: bool = False,
                  **kwargs):
         """Initializes a SkipConnection nn Module object.
 
@@ -183,11 +183,12 @@ class DynamicTFPolicy(TFPolicy):
         else:
             if self.config["_use_trajectory_view_api"]:
                 self._state_inputs = [
-                    tf1.placeholder(
-                        shape=(None, ) + vr.space.shape, dtype=vr.space.dtype)
-                    for k, vr in
+                    get_placeholder(
+                        space=vr.space,
+                        time_axis=not isinstance(vr.shift, int),
+                    ) for k, vr in
                     self.model.inference_view_requirements.items()
-                    if k[:9] == "state_in_"
+                    if k.startswith("state_in_")
                 ]
             else:
                 self._state_inputs = [
@@ -423,9 +424,14 @@ class DynamicTFPolicy(TFPolicy):
                 input_dict[view_col] = existing_inputs[view_col]
             # All others.
             else:
+                time_axis = not isinstance(view_req.shift, int)
                 if view_req.used_for_training:
+                    # Create a +time-axis placeholder if the shift is not an
+                    # int (range or list of ints).
                     input_dict[view_col] = get_placeholder(
-                        space=view_req.space, name=view_col)
+                        space=view_req.space,
+                        name=view_col,
+                        time_axis=time_axis)
         dummy_batch = self._get_dummy_batch_from_view_requirements(
             batch_size=32)
 
@@ -490,10 +496,10 @@ class DynamicTFPolicy(TFPolicy):
         dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
         for k, v in self.extra_compute_action_fetches().items():
             dummy_batch[k] = fake_array(v)
+        dummy_batch = SampleBatch(dummy_batch)
 
-        sb = SampleBatch(dummy_batch)
-        batch_for_postproc = UsageTrackingDict(sb)
-        batch_for_postproc.count = sb.count
+        batch_for_postproc = UsageTrackingDict(dummy_batch)
+        batch_for_postproc.count = dummy_batch.count
         logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
         self.exploration.postprocess_trajectory(self, batch_for_postproc,
                                                 self._sess)
@@ -519,6 +525,7 @@ class DynamicTFPolicy(TFPolicy):
         train_batch.update({
             SampleBatch.PREV_ACTIONS: self._prev_action_input,
             SampleBatch.PREV_REWARDS: self._prev_reward_input,
+            SampleBatch.CUR_OBS: self._obs_input,
         })
 
         for k, v in postprocessed_batch.items():
@@ -578,7 +585,8 @@ class DynamicTFPolicy(TFPolicy):
         for key in batch_for_postproc.accessed_keys:
             if key not in train_batch.accessed_keys and \
                     key not in self.model.inference_view_requirements:
-                self.view_requirements[key].used_for_training = False
+                if key in self.view_requirements:
+                    self.view_requirements[key].used_for_training = False
                 if key in self._loss_input_dict:
                     del self._loss_input_dict[key]
         # Remove those not needed at all (leave those that are needed
@@ -314,12 +314,16 @@ def build_eager_tf_policy(name,
             self.callbacks.on_learn_on_batch(
                 policy=self, train_batch=postprocessed_batch)
 
             # Get batch ready for RNNs, if applicable.
             pad_batch_to_sequences_of_same_size(
                 postprocessed_batch,
                 shuffle=False,
                 max_seq_len=self._max_seq_len,
-                batch_divisibility_req=self.batch_divisibility_req)
+                batch_divisibility_req=self.batch_divisibility_req,
+                view_requirements=self.view_requirements,
+            )
 
             self._is_training = True
             postprocessed_batch["is_training"] = True
             return self._learn_on_batch_eager(postprocessed_batch)
@@ -332,12 +336,14 @@ def build_eager_tf_policy(name,
 
         @override(Policy)
         def compute_gradients(self, samples):
             # Get batch ready for RNNs, if applicable.
             pad_batch_to_sequences_of_same_size(
                 samples,
                 shuffle=False,
                 max_seq_len=self._max_seq_len,
                 batch_divisibility_req=self.batch_divisibility_req)
 
+            self._is_training = True
+            samples["is_training"] = True
             return self._compute_gradients_eager(samples)
 
         @convert_eager_inputs
@@ -369,7 +375,7 @@ def build_eager_tf_policy(name,
 
             # TODO: remove python side effect to cull sources of bugs.
             self._is_training = False
-            self._state_in = state_batches
+            self._state_in = state_batches or []
 
             if not tf1.executing_eagerly():
                 tf1.enable_eager_execution()
@@ -591,8 +597,6 @@ def build_eager_tf_policy(name,
         def _compute_gradients(self, samples):
             """Computes and returns grads as eager tensors."""
 
-            self._is_training = True
-
             with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                 loss = loss_fn(self, self.model, self.dist_class, samples)
 
@@ -629,10 +629,9 @@ class Policy(metaclass=ABCMeta):
         batch_for_postproc.count = self._dummy_batch.count
         self.exploration.postprocess_trajectory(self, batch_for_postproc)
         postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
+        seq_lens = None
         if state_outs:
             B = 4  # For RNNs, have B=4, T=[depends on sample_batch_size]
-            # TODO: (sven) This hack will not work for attention net traj.
-            # view setup.
             i = 0
             while "state_in_{}".format(i) in postprocessed_batch:
                 postprocessed_batch["state_in_{}".format(i)] = \
@@ -642,12 +641,11 @@ class Policy(metaclass=ABCMeta):
                     postprocessed_batch["state_out_{}".format(i)][:B]
                 i += 1
             seq_len = sample_batch_size // B
-            postprocessed_batch["seq_lens"] = \
-                np.array([seq_len for _ in range(B)], dtype=np.int32)
-        # Remove the UsageTrackingDict wrap to prep for wrapping the
-        # train batch with a to-tensor UsageTrackingDict.
-        train_batch = {k: v for k, v in postprocessed_batch.items()}
-        train_batch = self._lazy_tensor_dict(train_batch)
+            seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
+        # Wrap `train_batch` with a to-tensor UsageTrackingDict.
+        train_batch = self._lazy_tensor_dict(postprocessed_batch)
+        if seq_lens is not None:
+            train_batch["seq_lens"] = seq_lens
         train_batch.count = self._dummy_batch.count
         # Call the loss function, if it exists.
         if self._loss is not None:
@@ -712,13 +710,33 @@ class Policy(metaclass=ABCMeta):
                 ret[view_col] = \
                     np.zeros((batch_size, ) + shape[1:], np.float32)
             else:
-                if isinstance(view_req.space, gym.spaces.Space):
-                    ret[view_col] = np.zeros_like(
-                        [view_req.space.sample() for _ in range(batch_size)])
+                # Range of indices on time-axis, e.g. "-50:-1".
+                if view_req.shift_from is not None:
+                    ret[view_col] = np.zeros_like([[
+                        view_req.space.sample()
+                        for _ in range(view_req.shift_to -
+                                       view_req.shift_from + 1)
+                    ] for _ in range(batch_size)])
+                # Set of (probably non-consecutive) indices.
+                elif isinstance(view_req.shift, (list, tuple)):
+                    ret[view_col] = np.zeros_like([[
+                        view_req.space.sample()
+                        for t in range(len(view_req.shift))
+                    ] for _ in range(batch_size)])
+                # Single shift int value.
                 else:
-                    ret[view_col] = [view_req.space for _ in range(batch_size)]
+                    if isinstance(view_req.space, gym.spaces.Space):
+                        ret[view_col] = np.zeros_like([
+                            view_req.space.sample() for _ in range(batch_size)
+                        ])
+                    else:
+                        ret[view_col] = [
+                            view_req.space for _ in range(batch_size)
+                        ]
 
-        return SampleBatch(ret)
+        # Due to different view requirements for the different columns,
+        # columns in the resulting batch may not all have the same batch size.
+        return SampleBatch(ret, _dont_check_lens=True)
 
     def _update_model_inference_view_requirements_from_init_state(self):
         """Uses Model's (or this Policy's) init state to add needed ViewReqs.
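
The ranged-shift branch above produces, per column, a zero block of shape (batch_size, range-length) + space-shape; a quick check with gym and numpy (the space and shift values here are hypothetical):

    import gym
    import numpy as np

    space = gym.spaces.Box(-1.0, 1.0, shape=(4, ))
    batch_size, shift_from, shift_to = 2, -3, -1

    dummy = np.zeros_like([[space.sample()
                            for _ in range(shift_to - shift_from + 1)]
                           for _ in range(batch_size)])
    print(dummy.shape)  # (2, 3, 4)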
@@ -737,8 +755,13 @@ class Policy(metaclass=ABCMeta):
             view_reqs = model.inference_view_requirements if model else \
                 self.view_requirements
             view_reqs["state_in_{}".format(i)] = ViewRequirement(
-                "state_out_{}".format(i), shift=-1, space=space)
-            view_reqs["state_out_{}".format(i)] = ViewRequirement(space=space)
+                "state_out_{}".format(i),
+                shift=-1,
+                batch_repeat_value=self.config.get("model", {}).get(
+                    "max_seq_len", 1),
+                space=space)
+            view_reqs["state_out_{}".format(i)] = ViewRequirement(
+                space=space, used_for_training=True)
 
 
 def clip_action(action, action_space):
@@ -19,7 +19,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.utils.typing import TensorType
+from ray.rllib.utils.typing import TensorType, ViewRequirementsDict
 from ray.util import log_once
 
 tf1, tf, tfv = try_import_tf()
@@ -35,6 +35,7 @@ def pad_batch_to_sequences_of_same_size(
         shuffle: bool = False,
         batch_divisibility_req: int = 1,
         feature_keys: Optional[List[str]] = None,
+        view_requirements: Optional[ViewRequirementsDict] = None,
 ):
     """Applies padding to `batch` so it's choppable into same-size sequences.
 
@@ -55,6 +56,9 @@ def pad_batch_to_sequences_of_same_size(
         feature_keys (Optional[List[str]]): An optional list of keys to apply
             sequence-chopping to. If None, use all keys in batch that are not
             "state_in/out_"-type keys.
+        view_requirements (Optional[ViewRequirementsDict]): An optional
+            Policy ViewRequirements dict to be able to infer whether
+            e.g. dynamic max'ing should be applied over the seq_lens.
     """
     if batch_divisibility_req > 1:
         meets_divisibility_reqs = (
@@ -64,46 +68,65 @@ def pad_batch_to_sequences_of_same_size(
     else:
         meets_divisibility_reqs = True
 
-    # RNN-case.
+    states_already_reduced_to_init = False
+
+    # RNN/attention net case. Figure out whether we should apply dynamic
+    # max'ing over the list of sequence lengths.
     if "state_in_0" in batch or "state_out_0" in batch:
-        dynamic_max = True
+        # Check, whether the state inputs have already been reduced to their
+        # init values at the beginning of each max_seq_len chunk.
+        if batch.seq_lens is not None and \
+                len(batch["state_in_0"]) == len(batch.seq_lens):
+            states_already_reduced_to_init = True
+
+        # RNN (or single timestep state-in): Set the max dynamically.
+        if view_requirements["state_in_0"].shift_from is None:
+            dynamic_max = True
+        # Attention Nets (state inputs are over some range): No dynamic maxing
+        # possible.
+        else:
+            dynamic_max = False
     # Multi-agent case.
     elif not meets_divisibility_reqs:
         max_seq_len = batch_divisibility_req
         dynamic_max = False
-    # Simple case: not RNN nor do we need to pad.
+    # Simple case: No RNN/attention net, nor do we need to pad.
     else:
         if shuffle:
             batch.shuffle()
         return
 
-    # RNN or multi-agent case.
+    # RNN, attention net, or multi-agent case.
     state_keys = []
     feature_keys_ = feature_keys or []
-    for k in batch.keys():
-        if "state_in_" in k:
+    for k, v in batch.items():
+        if k.startswith("state_in_"):
             state_keys.append(k)
-        elif not feature_keys and "state_out_" not in k and k != "infos":
+        elif not feature_keys and not k.startswith("state_out_") and \
+                k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray):
            feature_keys_.append(k)
 
     feature_sequences, initial_states, seq_lens = \
         chop_into_sequences(
-            batch[SampleBatch.EPS_ID],
-            batch[SampleBatch.UNROLL_ID],
-            batch[SampleBatch.AGENT_INDEX],
-            [batch[k] for k in feature_keys_],
-            [batch[k] for k in state_keys],
-            max_seq_len,
+            feature_columns=[batch[k] for k in feature_keys_],
+            state_columns=[batch[k] for k in state_keys],
+            episode_ids=batch.get(SampleBatch.EPS_ID),
+            unroll_ids=batch.get(SampleBatch.UNROLL_ID),
+            agent_indices=batch.get(SampleBatch.AGENT_INDEX),
+            seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
+            max_seq_len=max_seq_len,
             dynamic_max=dynamic_max,
+            states_already_reduced_to_init=states_already_reduced_to_init,
             shuffle=shuffle)
 
     for i, k in enumerate(feature_keys_):
         batch[k] = feature_sequences[i]
     for i, k in enumerate(state_keys):
         batch[k] = initial_states[i]
-    batch["seq_lens"] = seq_lens
+    batch["seq_lens"] = np.array(seq_lens)
 
     if log_once("rnn_ma_feed_dict"):
-        logger.info("Padded input for RNN:\n\n{}\n".format(
+        logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
             summarize({
                 "features": feature_sequences,
                 "initial_states": initial_states,
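
The effect of dynamic max'ing, reduced to its core (the numbers here are made up): instead of padding every chunk to the configured max_seq_len, pad only to the longest sequence actually present:

    seq_lens = [3, 5, 2]
    configured_max_seq_len = 20

    dynamic_max = True
    pad_to = max(seq_lens) if dynamic_max else configured_max_seq_len
    print(pad_to)  # 5 rather than 20 -- a much smaller padded train batch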
@@ -157,18 +180,18 @@ def add_time_dimension(padded_inputs: TensorType,
         return torch.reshape(padded_inputs, new_shape)
 
 
 # NOTE: This function will be deprecated once chunks already come padded and
 # correctly chopped from the _SampleCollector object (in time-major fashion
 # or not). It is already no longer used iff `_use_trajectory_view_api` = True.
 @DeveloperAPI
-def chop_into_sequences(episode_ids,
-                        unroll_ids,
-                        agent_indices,
+def chop_into_sequences(*,
                         feature_columns,
                         state_columns,
                         max_seq_len,
+                        episode_ids=None,
+                        unroll_ids=None,
+                        agent_indices=None,
                         dynamic_max=True,
                         shuffle=False,
                         seq_lens=None,
+                        states_already_reduced_to_init=False,
                         _extra_padding=0):
     """Truncate and pad experiences into fixed-length sequences.
 
@@ -212,23 +235,24 @@ def chop_into_sequences(episode_ids,
[2, 3, 1]
"""

prev_id = None
seq_lens = []
seq_len = 0
unique_ids = np.add(
np.add(episode_ids, agent_indices),
np.array(unroll_ids, dtype=np.int64) << 32)
for uid in unique_ids:
if (prev_id is not None and uid != prev_id) or \
seq_len >= max_seq_len:
if seq_lens is None or len(seq_lens) == 0:
prev_id = None
seq_lens = []
seq_len = 0
unique_ids = np.add(
np.add(episode_ids, agent_indices),
np.array(unroll_ids, dtype=np.int64) << 32)
for uid in unique_ids:
if (prev_id is not None and uid != prev_id) or \
seq_len >= max_seq_len:
seq_lens.append(seq_len)
seq_len = 0
seq_len += 1
prev_id = uid
if seq_len:
seq_lens.append(seq_len)
seq_len = 0
seq_len += 1
prev_id = uid
if seq_len:
seq_lens.append(seq_len)
assert sum(seq_lens) == len(unique_ids)
seq_lens = np.array(seq_lens, dtype=np.int32)
seq_lens = np.array(seq_lens, dtype=np.int32)
assert sum(seq_lens) == len(feature_columns[0])

# Dynamically shrink max len as needed to optimize memory usage
if dynamic_max:
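
A quick worked example of the `unique_ids` trick above, which packs episode, agent, and unroll IDs into a single comparable key (unroll IDs shifted into the high 32 bits) so that a new sequence starts whenever any of the three changes or `max_seq_len` is reached (toy numbers, assuming small integer IDs):

import numpy as np

eps_ids = np.array([5, 5, 5, 5, 6])  # episode changes at the last step
agent_ids = np.zeros(5, dtype=np.int64)
unroll_ids = np.ones(5, dtype=np.int64)
unique_ids = np.add(np.add(eps_ids, agent_ids), unroll_ids << 32)
# Running the loop above with max_seq_len=3 yields seq_lens == [3, 1, 1]:
# the first episode is cut at length 3, its remainder gets length 1, and
# the new episode contributes the final length-1 sequence.
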
@@ -252,18 +276,23 @@ def chop_into_sequences(episode_ids,
f_pad[seq_base + seq_offset] = f[i]
i += 1
seq_base += max_seq_len
assert i == len(unique_ids), f
assert i == len(f), f
feature_sequences.append(f_pad)

initial_states = []
for s in state_columns:
s = np.array(s)
s_init = []
i = 0
for len_ in seq_lens:
s_init.append(s[i])
i += len_
initial_states.append(np.array(s_init))
if states_already_reduced_to_init:
initial_states = state_columns
else:
initial_states = []
for s in state_columns:
# Skip unnecessary copy.
if not isinstance(s, np.ndarray):
s = np.array(s)
s_init = []
i = 0
for len_ in seq_lens:
s_init.append(s[i])
i += len_
initial_states.append(np.array(s_init))

if shuffle:
permutation = np.random.permutation(len(seq_lens))
@@ -61,6 +61,7 @@ class SampleBatch:
# Possible seq_lens (TxB or BxT) setup.
self.time_major = kwargs.pop("_time_major", None)
self.seq_lens = kwargs.pop("_seq_lens", None)
self.dont_check_lens = kwargs.pop("_dont_check_lens", False)
self.max_seq_len = None
if self.seq_lens is not None and len(self.seq_lens) > 0:
self.max_seq_len = max(self.seq_lens)
@@ -76,8 +77,10 @@ class SampleBatch:
self.data[k] = np.array(v)
if not lengths:
raise ValueError("Empty sample batch")
assert len(set(lengths)) == 1, \
"Data columns must be same length, but lens are {}".format(lengths)
if not self.dont_check_lens:
assert len(set(lengths)) == 1, \
"Data columns must be same length, but lens are " \
"{}".format(lengths)
if self.seq_lens is not None and len(self.seq_lens) > 0:
self.count = sum(self.seq_lens)
else:
@@ -117,7 +120,8 @@ class SampleBatch:
return SampleBatch(
out,
_seq_lens=np.array(seq_lens, dtype=np.int32),
_time_major=concat_samples[0].time_major)
_time_major=concat_samples[0].time_major,
_dont_check_lens=True)

@PublicAPI
def concat(self, other: "SampleBatch") -> "SampleBatch":
@@ -248,12 +252,35 @@ class SampleBatch:
SampleBatch: A new SampleBatch, which has a slice of this batch's
data.
"""
if self.time_major is not None:
if self.seq_lens is not None and len(self.seq_lens) > 0:
data = {k: v[start:end] for k, v in self.data.items()}
# Fix state_in_x data.
count = 0
state_start = None
seq_lens = None
for i, seq_len in enumerate(self.seq_lens):
count += seq_len
if count >= end:
state_idx = 0
state_key = "state_in_{}".format(state_idx)
while state_key in self.data:
data[state_key] = self.data[state_key][state_start:i + 1]
state_idx += 1
state_key = "state_in_{}".format(state_idx)
seq_lens = list(self.seq_lens[state_start:i]) + [
seq_len - (count - end)
]
assert sum(seq_lens) == (end - start)
break
elif state_start is None and count > start:
state_start = i

return SampleBatch(
{k: v[:, start:end]
for k, v in self.data.items()},
_seq_lens=self.seq_lens[start:end],
_time_major=self.time_major)
data,
_seq_lens=np.array(seq_lens, dtype=np.int32),
_time_major=self.time_major,
_dont_check_lens=True)
else:
return SampleBatch(
{k: v[start:end]
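
To see what the new seq_lens-aware branch of slicing does, consider a hypothetical two-sequence batch (toy data; assumes the (start, end) slicing method shown above):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch(
    {
        "obs": np.arange(8),
        # One initial-state row per sequence.
        "state_in_0": np.array([[0.0], [1.0]]),
    },
    _seq_lens=np.array([4, 4], dtype=np.int32),
    _dont_check_lens=True)
sliced = batch.slice(0, 6)
# sliced["obs"]        -> [0 1 2 3 4 5]
# sliced.seq_lens      -> [4, 2]  (second sequence truncated to 2 steps)
# sliced["state_in_0"] -> both rows kept, one per (partial) sequence.
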
@@ -174,11 +174,6 @@ class TFPolicy(Policy):
raise ValueError(
"Number of state input and output tensors must match, got: "
"{} vs {}".format(self._state_inputs, self._state_outputs))
if len(self.get_initial_state()) != len(self._state_inputs):
raise ValueError(
"Length of initial state must match number of state inputs, "
"got: {} vs {}".format(self.get_initial_state(),
self._state_inputs))
if self._state_inputs and self._seq_lens is None:
raise ValueError(
"seq_lens tensor must be given if state inputs are defined")
@@ -263,6 +258,11 @@ class TFPolicy(Policy):
(name, tf1.placeholders) needed for calculating the loss.
"""
self._loss_input_dict = dict(loss_inputs)
self._loss_input_dict_no_rnn = {
k: v
for k, v in self._loss_input_dict.items()
if (v not in self._state_inputs and v != self._seq_lens)
}
for i, ph in enumerate(self._state_inputs):
self._loss_input_dict["state_in_{}".format(i)] = ph
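
The `_loss_input_dict_no_rnn` filter above can be reproduced with plain dicts; the strings below are toy stand-ins for tf placeholders:

loss_input_dict = {"obs": "obs_ph", "actions": "act_ph",
                   "state_in_0": "state_ph", "seq_lens": "seq_ph"}
state_inputs, seq_lens_ph = ["state_ph"], "seq_ph"
loss_input_dict_no_rnn = {
    k: v
    for k, v in loss_input_dict.items()
    if v not in state_inputs and v != seq_lens_ph
}
print(list(loss_input_dict_no_rnn))  # -> ['obs', 'actions']
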
@@ -791,11 +791,11 @@ class TFPolicy(Policy):
**fetches[LEARNER_STATS_KEY])
return fetches

def _get_loss_inputs_dict(self, batch, shuffle):
def _get_loss_inputs_dict(self, train_batch, shuffle):
"""Return a feed dict from a batch.

Args:
batch (SampleBatch): batch of data to derive inputs from
train_batch (SampleBatch): batch of data to derive inputs from.
shuffle (bool): whether to shuffle batch sequences. Shuffle may
be done in-place. This only makes sense if you're further
applying minibatch SGD after getting the outputs.
@@ -806,28 +806,30 @@ class TFPolicy(Policy):

# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
batch,
train_batch,
shuffle=shuffle,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self._batch_divisibility_req,
feature_keys=[
k for k in self._loss_input_dict.keys() if k != "seq_lens"
],
feature_keys=list(self._loss_input_dict_no_rnn.keys()),
view_requirements=self.view_requirements,
)
batch["is_training"] = True

# Mark the batch as "is_training" so the Model can use this
# information.
train_batch["is_training"] = True

# Build the feed dict from the batch.
feed_dict = {}
for key, placeholder in self._loss_input_dict.items():
feed_dict[placeholder] = batch[key]
feed_dict[placeholder] = train_batch[key]

state_keys = [
"state_in_{}".format(i) for i in range(len(self._state_inputs))
]
for key in state_keys:
feed_dict[self._loss_input_dict[key]] = batch[key]
feed_dict[self._loss_input_dict[key]] = train_batch[key]
if state_keys:
feed_dict[self._seq_lens] = batch["seq_lens"]
feed_dict[self._seq_lens] = train_batch["seq_lens"]

return feed_dict
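
End-to-end, `_get_loss_inputs_dict` pads the batch and then maps every placeholder (features, state-ins, seq_lens) to arrays with consistent row counts. A toy sketch of that mapping (stand-in names, no real tf involved):

import numpy as np

train_batch = {"obs": np.zeros((8, 4)),          # 8 padded timesteps
               "state_in_0": np.zeros((2, 16)),  # one state row per sequence
               "seq_lens": np.array([4, 4], dtype=np.int32)}
loss_input_dict = {"obs": "obs_ph", "state_in_0": "state_ph"}
feed_dict = {ph: train_batch[k] for k, ph in loss_input_dict.items()}
feed_dict["seq_lens_ph"] = train_batch["seq_lens"]
# Here both sequences are full length, so the 8 feature rows equal
# sum(seq_lens); state_in_0 always has exactly one row per sequence.
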
@@ -345,12 +345,13 @@ class TorchPolicy(Policy):
@DeveloperAPI
def compute_gradients(self,
postprocessed_batch: SampleBatch) -> ModelGradients:
# Get batch ready for RNNs, if applicable.

pad_batch_to_sequences_of_same_size(
postprocessed_batch,
max_seq_len=self.max_seq_len,
shuffle=False,
batch_divisibility_req=self.batch_divisibility_req,
view_requirements=self.view_requirements,
)

train_batch = self._lazy_tensor_dict(postprocessed_batch)
@@ -1,4 +1,5 @@
import gym
import numpy as np
from typing import List, Optional, Union

from ray.rllib.utils.framework import try_import_torch
@@ -29,8 +30,9 @@ class ViewRequirement:
def __init__(self,
data_col: Optional[str] = None,
space: gym.Space = None,
shift: Union[int, List[int]] = 0,
shift: Union[int, str, List[int]] = 0,
index: Optional[int] = None,
batch_repeat_value: int = 1,
used_for_training: bool = True):
"""Initializes a ViewRequirement object.

@@ -64,7 +66,19 @@ class ViewRequirement:
self.space = space if space is not None else gym.spaces.Box(
float("-inf"), float("inf"), shape=())

self.index = index

self.shift = shift
if isinstance(self.shift, (list, tuple)):
self.shift = np.array(self.shift)

# Special case: Providing a (probably larger) range of indices, e.g.
# "-100:0" (past 100 timesteps plus current one).
self.shift_from = self.shift_to = None
if isinstance(self.shift, str):
f, t = self.shift.split(":")
self.shift_from = int(f)
self.shift_to = int(t)

self.index = index
self.batch_repeat_value = batch_repeat_value

self.used_for_training = used_for_training
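
The string form of `shift` introduced above encodes an inclusive index range, which attention nets use to request a trailing window of past outputs. A small usage sketch (the data_col/shift values are illustrative):

from ray.rllib.policy.view_requirement import ViewRequirement

# View the last 50 attention-memory outputs (timesteps t-50 .. t-1).
vr = ViewRequirement(data_col="state_out_0", shift="-50:-1")
assert vr.shift_from == -50 and vr.shift_to == -1
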
@@ -44,7 +44,8 @@ class TestAttentionNetLearning(unittest.TestCase):
"num_transformer_units": 1,
"attn_dim": 32,
"num_heads": 1,
"memory_tau": 5,
"memory_inference": 5,
"memory_training": 5,
"head_dim": 32,
"ff_hidden_dim": 32,
},
@@ -71,7 +72,8 @@ class TestAttentionNetLearning(unittest.TestCase):
# "num_transformer_units": 1,
# "attn_dim": 64,
# "num_heads": 1,
# "memory_tau": 10,
# "memory_inference": 10,
# "memory_training": 10,
# "head_dim": 32,
# "ff_hidden_dim": 32,
# },
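
Note the config change in the test above: the single `memory_tau` setting is split into `memory_inference` and `memory_training`, so the attention memory length can differ between acting and learning. A hedged sketch of such a model config (the keys come from the diff; the surrounding nesting is illustrative only):

config = {
    "model": {
        "custom_model": "attention_net",  # hypothetical registered name
        "custom_model_config": {
            "num_transformer_units": 1,
            "attn_dim": 32,
            "num_heads": 1,
            "memory_inference": 5,  # memory length used when acting
            "memory_training": 5,   # memory length used on the train batch
            "head_dim": 32,
            "ff_hidden_dim": 32,
        },
    },
}
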
@@ -18,9 +18,13 @@ class TestLSTMUtils(unittest.TestCase):
f = [[101, 102, 103, 201, 202, 203, 204, 205],
[[101], [102], [103], [201], [202], [203], [204], [205]]]
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
np.ones_like(eps_ids),
agent_ids, f, s, 4)
f_pad, s_init, seq_lens = chop_into_sequences(
episode_ids=eps_ids,
unroll_ids=np.ones_like(eps_ids),
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4)
self.assertEqual([f.tolist() for f in f_pad], [
[101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
[[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
@@ -35,9 +39,13 @@ class TestLSTMUtils(unittest.TestCase):
obs = np.ones((84, 84, 4))
f = [[obs, obs * 2, obs * 3]]
s = [[209, 208, 207]]
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
np.ones_like(eps_ids),
agent_ids, f, s, 4)
f_pad, s_init, seq_lens = chop_into_sequences(
episode_ids=eps_ids,
unroll_ids=np.ones_like(eps_ids),
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4)
self.assertEqual([f.tolist() for f in f_pad], [
np.array([obs, obs * 2, obs * 3]).tolist(),
])
@@ -51,8 +59,13 @@ class TestLSTMUtils(unittest.TestCase):
f = [[101, 102, 103, 201, 202, 203, 204, 205],
[[101], [102], [103], [201], [202], [203], [204], [205]]]
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
_, _, seq_lens = chop_into_sequences(eps_ids, batch_ids, agent_ids, f,
s, 4)
_, _, seq_lens = chop_into_sequences(
episode_ids=eps_ids,
unroll_ids=batch_ids,
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4)
self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2])

def test_multi_agent(self):
@@ -62,12 +75,12 @@ class TestLSTMUtils(unittest.TestCase):
[[101], [102], [103], [201], [202], [203], [204], [205]]]
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
f_pad, s_init, seq_lens = chop_into_sequences(
eps_ids,
np.ones_like(eps_ids),
agent_ids,
f,
s,
4,
episode_ids=eps_ids,
unroll_ids=np.ones_like(eps_ids),
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4,
dynamic_max=False)
self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
self.assertEqual(len(f_pad[0]), 20)
@@ -78,9 +91,13 @@ class TestLSTMUtils(unittest.TestCase):
agent_ids = [2, 2, 2]
f = [[1, 1, 1]]
s = [[1, 1, 1]]
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
np.ones_like(eps_ids),
agent_ids, f, s, 4)
f_pad, s_init, seq_lens = chop_into_sequences(
episode_ids=eps_ids,
unroll_ids=np.ones_like(eps_ids),
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4)
self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
self.assertEqual(seq_lens.tolist(), [1, 2])
@@ -72,18 +72,23 @@ def minibatches(samples, sgd_minibatch_size):

i = 0
slices = []
if samples.seq_lens:
seq_no = 0
while i < samples.count:
seq_no_end = seq_no
actual_count = 0
while actual_count < sgd_minibatch_size and len(
samples.seq_lens) > seq_no_end:
actual_count += samples.seq_lens[seq_no_end]
seq_no_end += 1
slices.append((seq_no, seq_no_end))
i += actual_count
seq_no = seq_no_end
if samples.seq_lens is not None and len(samples.seq_lens) > 0:
start_pos = 0
minibatch_size = 0
idx = 0
while idx < len(samples.seq_lens):
seq_len = samples.seq_lens[idx]
minibatch_size += seq_len
# Complete minibatch -> Append to slices.
if minibatch_size >= sgd_minibatch_size:
slices.append((start_pos, start_pos + sgd_minibatch_size))
start_pos += sgd_minibatch_size
if minibatch_size > sgd_minibatch_size:
overhead = minibatch_size - sgd_minibatch_size
start_pos -= (seq_len - overhead)
idx -= 1
minibatch_size = 0
idx += 1
else:
while i < samples.count:
slices.append((i, i + sgd_minibatch_size))
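
The rewritten loop slices at exact `sgd_minibatch_size` boundaries and, when a sequence straddles a boundary, rewinds to that sequence's start so the next slice re-reads it whole. A standalone trace of the same logic under assumed inputs:

seq_lens = [3, 3, 3]
sgd_minibatch_size = 4
slices, start_pos, minibatch_size, idx = [], 0, 0, 0
while idx < len(seq_lens):
    seq_len = seq_lens[idx]
    minibatch_size += seq_len
    if minibatch_size >= sgd_minibatch_size:
        slices.append((start_pos, start_pos + sgd_minibatch_size))
        start_pos += sgd_minibatch_size
        if minibatch_size > sgd_minibatch_size:
            overhead = minibatch_size - sgd_minibatch_size
            start_pos -= (seq_len - overhead)
            idx -= 1  # reconsider the straddling sequence
        minibatch_size = 0
    idx += 1
print(slices)  # -> [(0, 4), (3, 7)]: slice 2 restarts at sequence 2 (ts 3).
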
@@ -100,6 +100,9 @@ ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
# Type of dict returned by get_weights() representing model weights.
ModelWeights = dict

# An input dict used for direct ModelV2 calls or `ModelV2.from_batch` calls.
ModelInputDict = Dict[str, TensorType]

# Some kind of sample batch.
SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]