[RLlib] Attention Nets: tf (#12753)

Sven Mika 2020-12-21 02:22:32 +01:00 committed by GitHub
parent e715ade2d1
commit b2bcab711d
26 changed files with 567 additions and 561 deletions


@ -1074,13 +1074,6 @@ py_test(
# Tag: models
# --------------------------------------------------------------------
py_test(
name = "test_attention_nets",
tags = ["models"],
size = "small",
srcs = ["models/tests/test_attention_nets.py"]
)
py_test(
name = "test_convtranspose2d_stack",
tags = ["models"],


@ -191,8 +191,8 @@ class DefaultCallbacks:
**kwargs) -> None:
"""Called at the beginning of Policy.learn_on_batch().
Note: This is called before the Model's `preprocess_train_batch()`
is called.
Note: This is called before 0-padding via
`pad_batch_to_sequences_of_same_size`.
Args:
policy (Policy): Reference to the current Policy object.


@ -198,7 +198,8 @@ def postprocess_ppo_gae(
# input_dict.
if policy.config["_use_trajectory_view_api"]:
# Create an input dict according to the Model's requirements.
input_dict = policy.model.get_input_dict(sample_batch, index=-1)
input_dict = policy.model.get_input_dict(
sample_batch, index="last")
last_r = policy._value(**input_dict)
# TODO: (sven) Remove once trajectory view API is all-algo default.
else:
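To see why "last" differs from -1 here, a toy sketch (illustrative data, not the RLlib API): the value bootstrap needs the observation reached after the final step, which lives in NEXT_OBS rather than OBS.

import numpy as np

# Toy 3-step trajectory: obs holds o_0..o_2, new_obs holds o_1..o_3.
sample_batch = {
    "obs": np.array([[0.0], [1.0], [2.0]]),
    "new_obs": np.array([[1.0], [2.0], [3.0]]),
}
# index=-1 would feed o_2 (the last OBS); index="last" feeds o_3 (the final
# NEXT_OBS), i.e. the state whose value gets bootstrapped into `last_r`.
assert sample_batch["new_obs"][-1] == np.array([3.0])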


@ -1,6 +1,7 @@
import collections
from gym.spaces import Space
import logging
import math
import numpy as np
from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union
@ -34,6 +35,9 @@ def to_float_np_array(v: List[Any]) -> np.ndarray:
return arr
_INIT_COLS = [SampleBatch.OBS]
class _AgentCollector:
"""Collects samples for one agent in one trajectory (episode).
@ -45,9 +49,18 @@ class _AgentCollector:
_next_unroll_id = 0 # disambiguates unrolls within a single episode
def __init__(self, shift_before: int = 0):
self.shift_before = max(shift_before, 1)
def __init__(self, view_reqs):
# Determine the size of the buffer we need for data before the actual
# episode starts. This is used for 0-buffering of e.g. prev-actions,
# or internal state inputs.
self.shift_before = -min(
(int(vr.shift.split(":")[0])
if isinstance(vr.shift, str) else vr.shift) +
(-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0)
for k, vr in view_reqs.items())
# The actual data buffers (lists holding each timestep's data).
self.buffers: Dict[str, List] = {}
# The episode ID for the agent for which we collect data.
self.episode_id = None
# The simple timestep count for this agent. Gets increased by one
# each time a (non-initial!) observation is added.
@ -137,31 +150,88 @@ class _AgentCollector:
# -> skip.
if data_col not in self.buffers:
continue
# OBS are already shifted by -1 (the initial obs starts one ts
# before all other data columns).
shift = view_req.shift - \
(1 if data_col == SampleBatch.OBS else 0)
obs_shift = -1 if data_col == SampleBatch.OBS else 0
# Keep an np-array cache so we don't have to regenerate the
# np-array for different view_cols mapping to the same data_col.
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
# Shift is exactly 0: Send trajectory as is.
if shift == 0:
data = np_data[data_col][self.shift_before:]
# Shift is positive: We still need to 0-pad at the end here.
elif shift > 0:
data = to_float_np_array(
self.buffers[data_col][self.shift_before + shift:] + [
np.zeros(
shape=view_req.space.shape,
dtype=view_req.space.dtype) for _ in range(shift)
# Range of indices on time-axis, e.g. "-50:-1". Together with
# the `batch_repeat_value`, this determines the data produced.
# Example:
# batch_repeat_value=10, shift_from=-3, shift_to=-1
# buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# resulting data=[[-3, -2, -1], [7, 8, 9]]
# Range of 3 consecutive items repeats every 10 timesteps.
if view_req.shift_from is not None:
if view_req.batch_repeat_value > 1:
count = int(
math.ceil((len(np_data[data_col]) - self.shift_before)
/ view_req.batch_repeat_value))
data = np.asarray([
np_data[data_col][self.shift_before +
(i * view_req.batch_repeat_value) +
view_req.shift_from +
obs_shift:self.shift_before +
(i * view_req.batch_repeat_value) +
view_req.shift_to + 1 + obs_shift]
for i in range(count)
])
# Shift is negative: Shift into the already existing and 0-padded
# "before" area of our buffers.
else:
data = np_data[data_col][self.shift_before +
view_req.shift_from +
obs_shift:self.shift_before +
view_req.shift_to + 1 + obs_shift]
# Set of (probably non-consecutive) indices.
# Example:
# shift=[-3, 0]
# buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...]
elif isinstance(view_req.shift, np.ndarray):
data = np_data[data_col][self.shift_before + obs_shift +
view_req.shift]
# Single shift int value. Use the trajectory as-is, and if
# `shift` != 0: shifted by that value.
else:
data = np_data[data_col][self.shift_before + shift:shift]
shift = view_req.shift + obs_shift
# Batch repeat (only provide a value every n timesteps).
if view_req.batch_repeat_value > 1:
count = int(
math.ceil((len(np_data[data_col]) - self.shift_before)
/ view_req.batch_repeat_value))
data = np.asarray([
np_data[data_col][self.shift_before + (
i * view_req.batch_repeat_value) + shift]
for i in range(count)
])
# Shift is exactly 0: Use trajectory as is.
elif shift == 0:
data = np_data[data_col][self.shift_before:]
# Shift is positive: We still need to 0-pad at the end.
elif shift > 0:
data = to_float_np_array(
self.buffers[data_col][self.shift_before + shift:] + [
np.zeros(
shape=view_req.space.shape,
dtype=view_req.space.dtype)
for _ in range(shift)
])
# Shift is negative: Shift into the already existing and
# 0-padded "before" area of our buffers.
else:
data = np_data[data_col][self.shift_before + shift:shift]
if len(data) > 0:
batch_data[view_col] = data
batch = SampleBatch(batch_data)
# Due to possible batch-repeats > 1, columns in the resulting batch
# may not all have the same batch size.
batch = SampleBatch(batch_data, _dont_check_lens=True)
# Add EPS_ID and UNROLL_ID to batch.
batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id,
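For intuition, a standalone NumPy sketch of the ranged-shift slicing above, reusing the numbers from the in-code example (illustrative only, no RLlib calls):

import math
import numpy as np

buffer_ = np.array([-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
shift_before = 3                  # 0-padded "before" area covers ts -3..-1
shift_from, shift_to = -3, -1
batch_repeat_value = 10

count = int(math.ceil((len(buffer_) - shift_before) / batch_repeat_value))
data = np.asarray([
    buffer_[shift_before + i * batch_repeat_value + shift_from:
            shift_before + i * batch_repeat_value + shift_to + 1]
    for i in range(count)
])
print(data)                       # [[-3 -2 -1], [7 8 9]]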
@ -230,15 +300,22 @@ class _PolicyCollector:
appended to this policy's buffers.
"""
def __init__(self):
"""Initializes a _PolicyCollector instance."""
def __init__(self, policy):
"""Initializes a _PolicyCollector instance.
Args:
policy (Policy): The policy object.
"""
self.buffers: Dict[str, List] = collections.defaultdict(list)
self.policy = policy
# The total timestep count for all agents that use this policy.
# NOTE: This is not an env-step count (across n agents). AgentA and
# agentB, both using this policy, acting in the same episode and both
# doing n steps would increase the count by 2*n.
self.agent_steps = 0
# Seq-lens list of already added agent batches.
self.seq_lens = [] if policy.is_recurrent() else None
def add_postprocessed_batch_for_training(
self, batch: SampleBatch,
@ -257,11 +334,18 @@ class _PolicyCollector:
# 1) If col is not in view_requirements, we must have a direct
# child of the base Policy that doesn't do auto-view req creation.
# 2) Col is in view-reqs and needed for training.
if view_col not in view_requirements or \
view_requirements[view_col].used_for_training:
view_req = view_requirements.get(view_col)
if view_req is None or view_req.used_for_training:
self.buffers[view_col].extend(data)
# Add the agent's trajectory length to our count.
self.agent_steps += batch.count
# Adjust the seq-lens array depending on the incoming agent sequences.
if self.seq_lens is not None:
max_seq_len = self.policy.config["model"]["max_seq_len"]
count = batch.count
while count > 0:
self.seq_lens.append(min(count, max_seq_len))
count -= max_seq_len
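A minimal sketch of this chopping, assuming for example a 23-step agent batch and max_seq_len=10:

def chop_seq_lens(count, max_seq_len):
    # Same loop as above, extracted for illustration.
    seq_lens = []
    while count > 0:
        seq_lens.append(min(count, max_seq_len))
        count -= max_seq_len
    return seq_lens

assert chop_seq_lens(23, 10) == [10, 10, 3]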
def build(self):
"""Builds a SampleBatch for this policy from the collected data.
@ -273,20 +357,22 @@ class _PolicyCollector:
this policy.
"""
# Create batch from our buffers.
batch = SampleBatch(self.buffers)
assert SampleBatch.UNROLL_ID in batch.data
batch = SampleBatch(
self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
# Clear buffers for future samples.
self.buffers.clear()
# Reset agent steps to 0.
# Reset agent steps to 0 and seq-lens to empty list.
self.agent_steps = 0
if self.seq_lens is not None:
self.seq_lens = []
return batch
class _PolicyCollectorGroup:
def __init__(self, policy_map):
self.policy_collectors = {
pid: _PolicyCollector()
for pid in policy_map.keys()
pid: _PolicyCollector(policy)
for pid, policy in policy_map.items()
}
# Total env-steps (1 env-step=up to N agents stepped).
self.env_steps = 0
@ -396,11 +482,14 @@ class _SimpleListCollector(_SampleCollector):
self.agent_key_to_policy_id[agent_key] = policy_id
else:
assert self.agent_key_to_policy_id[agent_key] == policy_id
policy = self.policy_map[policy_id]
view_reqs = policy.model.inference_view_requirements if \
getattr(policy, "model", None) else policy.view_requirements
# Add initial obs to Trajectory.
assert agent_key not in self.agent_collectors
# TODO: determine exact shift-before based on the view-req shifts.
self.agent_collectors[agent_key] = _AgentCollector()
self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
self.agent_collectors[agent_key].add_init_obs(
episode_id=episode.episode_id,
agent_index=episode._agent_index(agent_id),
@ -466,11 +555,19 @@ class _SimpleListCollector(_SampleCollector):
for view_col, view_req in view_reqs.items():
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
time_indices = \
view_req.shift - (
1 if data_col in [SampleBatch.OBS, "t", "env_id",
SampleBatch.AGENT_INDEX] else 0)
delta = -1 if data_col in [
SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX
] else 0
# Range of shifts, e.g. "-100:0". Note: This includes index 0!
if view_req.shift_from is not None:
time_indices = (view_req.shift_from + delta,
view_req.shift_to + delta)
# Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0].
else:
time_indices = view_req.shift + delta
data_list = []
# Loop through agents and add-up their data (batch).
for k in keys:
if data_col == SampleBatch.EPS_ID:
data_list.append(self.agent_collectors[k].episode_id)
@ -482,7 +579,15 @@ class _SimpleListCollector(_SampleCollector):
self.agent_collectors[k]._build_buffers({
data_col: fill_value
})
data_list.append(buffers[k][data_col][time_indices])
if isinstance(time_indices, tuple):
if time_indices[1] == -1:
data_list.append(
buffers[k][data_col][time_indices[0]:])
else:
data_list.append(buffers[k][data_col][time_indices[
0]:time_indices[1] + 1])
else:
data_list.append(buffers[k][data_col][time_indices])
input_dict[view_col] = np.array(data_list)
self._reset_inference_calls(policy_id)


@ -50,8 +50,6 @@ def compute_advantages(rollout: SampleBatch,
processed rewards.
"""
rollout_size = len(rollout[SampleBatch.ACTIONS])
assert SampleBatch.VF_PREDS in rollout or not use_critic, \
"use_critic=True but values not found"
assert use_critic or not use_gae, \
@ -90,6 +88,4 @@ def compute_advantages(rollout: SampleBatch,
rollout[Postprocessing.ADVANTAGES] = rollout[
Postprocessing.ADVANTAGES].astype(np.float32)
assert all(val.shape[0] == rollout_size for key, val in rollout.items()), \
"Rollout stacked incorrectly!"
return rollout


@ -39,6 +39,7 @@ if __name__ == "__main__":
config = {
"env": args.env,
# This env_config is only used for the RepeatAfterMeEnv env.
"env_config": {
"repeat_delay": 2,
},
@ -48,7 +49,7 @@ if __name__ == "__main__":
"num_workers": 0,
"num_envs_per_worker": 20,
"entropy_coeff": 0.001,
"num_sgd_iter": 5,
"num_sgd_iter": 10,
"vf_loss_coeff": 1e-5,
"model": {
"custom_model": GTrXLNet,
@ -56,9 +57,10 @@ if __name__ == "__main__":
"custom_model_config": {
"num_transformer_units": 1,
"attn_dim": 64,
"num_heads": 2,
"memory_tau": 50,
"memory_inference": 100,
"memory_training": 50,
"head_dim": 32,
"num_heads": 2,
"ff_hidden_dim": 32,
},
},
@ -71,7 +73,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, config=config, stop=stop, verbose=1)
results = tune.run(args.run, config=config, stop=stop, verbose=2)
if args.as_test:
check_learning_achieved(results, args.stop_reward)


@ -59,7 +59,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, config=config, stop=stop, verbose=1)
results = tune.run(args.run, config=config, stop=stop, verbose=2)
if args.as_test:
check_learning_achieved(results, args.stop_reward)


@ -13,7 +13,8 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
TensorType
from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.utils.typing import ModelConfigDict, TensorStructType
from ray.rllib.utils.typing import ModelConfigDict, ModelInputDict, \
TensorStructType
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@ -238,14 +239,14 @@ class ModelV2:
right input dict, state, and seq len arguments.
"""
train_batch["is_training"] = is_training
input_dict = train_batch.copy()
input_dict["is_training"] = is_training
states = []
i = 0
while "state_in_{}".format(i) in train_batch:
states.append(train_batch["state_in_{}".format(i)])
while "state_in_{}".format(i) in input_dict:
states.append(input_dict["state_in_{}".format(i)])
i += 1
ret = self.__call__(train_batch, states, train_batch.get("seq_lens"))
del train_batch["is_training"]
ret = self.__call__(input_dict, states, input_dict.get("seq_lens"))
return ret
def import_from_h5(self, h5_file: str) -> None:
@ -316,21 +317,57 @@ class ModelV2:
# TODO: (sven) Experimental method.
def get_input_dict(self, sample_batch,
index: int = -1) -> Dict[str, TensorType]:
if index < 0:
index = sample_batch.count - 1
index: Union[int, str] = "last") -> ModelInputDict:
"""Creates single ts input-dict at given index from a SampleBatch.
Args:
sample_batch (SampleBatch): A single-trajectory SampleBatch object
to generate the compute_actions input dict from.
index (Union[int, str]): An integer index value indicating the
position in the trajectory for which to generate the
compute_actions input dict. Set to "last" to generate the dict
at the very end of the trajectory (e.g. for value estimation).
Note that "last" is different from -1, as "last" will use the
final NEXT_OBS as observation input.
Returns:
ModelInputDict: The (single-timestep) input dict for ModelV2 calls.
"""
last_mappings = {
SampleBatch.OBS: SampleBatch.NEXT_OBS,
SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
}
input_dict = {}
for view_col, view_req in self.inference_view_requirements.items():
# Create batches of size 1 (single-agent input-dict).
# Index range.
if isinstance(index, tuple):
data = sample_batch[view_col][index[0]:index[1] + 1]
input_dict[view_col] = np.array([data])
# Single index.
data_col = view_req.data_col or view_col
if index == "last":
data_col = last_mappings.get(data_col, data_col)
if view_req.shift_from is not None:
data = sample_batch[view_col][-1]
traj_len = len(sample_batch[data_col])
missing_at_end = traj_len % view_req.batch_repeat_value
input_dict[view_col] = np.array([
np.concatenate([
data, sample_batch[data_col][-missing_at_end:]
])[view_req.shift_from:view_req.shift_to +
1 if view_req.shift_to != -1 else None]
])
else:
data = sample_batch[data_col][-1]
input_dict[view_col] = np.array([data])
else:
input_dict[view_col] = sample_batch[view_col][index:index + 1]
# Index range.
if isinstance(index, tuple):
data = sample_batch[data_col][index[0]:index[1] + 1
if index[1] != -1 else None]
input_dict[view_col] = np.array([data])
# Single index.
else:
input_dict[view_col] = sample_batch[data_col][
index:index + 1 if index != -1 else None]
# Add valid `seq_lens`, just in case RNNs need it.
input_dict["seq_lens"] = np.array([1], dtype=np.int32)


@ -1,263 +0,0 @@
import gym
import numpy as np
import unittest
from ray.rllib.models.tf.attention_net import relative_position_embedding, \
GTrXLNet
from ray.rllib.models.tf.layers import MultiHeadAttention
from ray.rllib.models.torch.attention_net import relative_position_embedding \
as relative_position_embedding_torch, GTrXLNet as TorchGTrXLNet
from ray.rllib.models.torch.modules.multi_head_attention import \
MultiHeadAttention as TorchMultiHeadAttention
from ray.rllib.utils.framework import try_import_torch, try_import_tf
from ray.rllib.utils.test_utils import framework_iterator
torch, nn = try_import_torch()
tf1, tf, tfv = try_import_tf()
class TestAttentionNets(unittest.TestCase):
"""Tests various torch/modules and tf/layers required for AttentionNet"""
def train_torch_full_model(self,
model,
inputs,
outputs,
num_epochs=250,
state=None,
seq_lens=None):
"""Convenience method that trains a Torch model for num_epochs epochs
and tests whether loss decreased, as expected.
Args:
model (nn.Module): Torch model to be trained.
inputs (torch.Tensor): Training data
outputs (torch.Tensor): Training labels
num_epochs (int): Number of epochs to train for
state (torch.Tensor): Internal state of module
seq_lens (torch.Tensor): Tensor of sequence lengths
"""
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Check that the layer trains correctly
for t in range(num_epochs):
y_pred = model(inputs, state, seq_lens)
loss = criterion(y_pred[0], torch.squeeze(outputs[0]))
if t % 10 == 1:
print(t, loss.item())
# if t == 0:
# init_loss = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
# final_loss = loss.item()
# The final loss has decreased, which tests
# that the model is learning from the training data.
# self.assertLess(final_loss / init_loss, 0.99)
def train_torch_layer(self, model, inputs, outputs, num_epochs=250):
"""Convenience method that trains a Torch model for num_epochs epochs
and tests whether loss decreased, as expected.
Args:
model (nn.Module): Torch model to be trained.
inputs (torch.Tensor): Training data
outputs (torch.Tensor): Training labels
num_epochs (int): Number of epochs to train for
"""
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# Check that the layer trains correctly
for t in range(num_epochs):
y_pred = model(inputs)
loss = criterion(y_pred, outputs)
if t == 1:
init_loss = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
final_loss = loss.item()
# The final loss has decreased by a factor of 2, which tests
# that the model is learning from the training data.
self.assertLess(final_loss / init_loss, 0.5)
def train_tf_model(self,
model,
inputs,
outputs,
num_epochs=250,
minibatch_size=32):
"""Convenience method that trains a Tensorflow model for num_epochs
epochs and tests whether loss decreased, as expected.
Args:
model (tf.Model): Tensorflow model to be trained.
inputs (np.array): Training data
outputs (np.array): Training labels
num_epochs (int): Number of training epochs
minibatch_size (int): Number of samples in each minibatch
"""
# Configure a model for mean-squared error loss.
model.compile(optimizer="SGD", loss="mse", metrics=["mae"])
hist = model.fit(
inputs,
outputs,
verbose=0,
epochs=num_epochs,
batch_size=minibatch_size).history
init_loss = hist["loss"][0]
final_loss = hist["loss"][-1]
self.assertLess(final_loss / init_loss, 0.5)
def test_multi_head_attention(self):
"""Tests the MultiHeadAttention mechanism of Vaswani et al."""
# B is batch size
B = 1
# D_in is attention dim, L is memory_tau
L, D_in, D_out = 2, 32, 10
for fw, sess in framework_iterator(
frameworks=("tfe", "torch", "tf"), session=True):
# Create a single attention layer with 2 heads.
if fw == "torch":
# Create random Tensors to hold inputs and outputs
x = torch.randn(B, L, D_in)
y = torch.randn(B, L, D_out)
model = TorchMultiHeadAttention(
in_dim=D_in, out_dim=D_out, num_heads=2, head_dim=32)
self.train_torch_layer(model, x, y, num_epochs=500)
# Framework is tensorflow or tensorflow-eager.
else:
x = np.random.random((B, L, D_in))
y = np.random.random((B, L, D_out))
inputs = tf.keras.layers.Input(shape=(L, D_in))
model = tf.keras.Sequential([
inputs,
MultiHeadAttention(
out_dim=D_out, num_heads=2, head_dim=32)
])
self.train_tf_model(model, x, y)
def test_attention_net(self):
"""Tests the GTrXL.
Builds a full AttentionNet and checks that it trains in a supervised
setting."""
# Checks that torch and tf embedding matrices are the same
with tf1.Session().as_default() as sess:
assert np.allclose(
relative_position_embedding(20, 15).eval(session=sess),
relative_position_embedding_torch(20, 15).numpy())
# B is batch size
B = 32
# D_in is attention dim, L is memory_tau
L, D_in, D_out = 2, 16, 2
for fw, sess in framework_iterator(session=True):
# Create a single attention layer with 2 heads
if fw == "torch":
# Create random Tensors to hold inputs and outputs
x = torch.randn(B, L, D_in)
y = torch.randn(B, L, D_out)
value_labels = torch.randn(B, L, D_in)
memory_labels = torch.randn(B, L, D_out)
attention_net = TorchGTrXLNet(
observation_space=gym.spaces.Box(
low=float("-inf"), high=float("inf"), shape=(D_in, )),
action_space=gym.spaces.Discrete(D_out),
num_outputs=D_out,
model_config={"max_seq_len": 2},
name="TestTorchAttentionNet",
num_transformer_units=2,
attn_dim=D_in,
num_heads=2,
memory_tau=L,
head_dim=D_out,
ff_hidden_dim=16,
init_gate_bias=2.0)
init_state = attention_net.get_initial_state()
# Get initial state and add a batch dimension.
init_state = [np.expand_dims(s, 0) for s in init_state]
seq_lens_init = torch.full(
size=(B, ), fill_value=L, dtype=torch.int32)
# Torch implementation expects a formatted input_dict instead
# of a numpy array as input.
input_dict = {"obs": x}
self.train_torch_full_model(
attention_net,
input_dict, [y, value_labels, memory_labels],
num_epochs=250,
state=init_state,
seq_lens=seq_lens_init)
# Framework is tensorflow or tensorflow-eager.
else:
x = np.random.random((B, L, D_in))
y = np.random.random((B, L, D_out))
value_labels = np.random.random((B, L, 1))
memory_labels = np.random.random((B, L, D_in))
# We need to create (N-1) MLP labels for N transformer units
mlp_labels = np.random.random((B, L, D_in))
attention_net = GTrXLNet(
observation_space=gym.spaces.Box(
low=float("-inf"), high=float("inf"), shape=(D_in, )),
action_space=gym.spaces.Discrete(D_out),
num_outputs=D_out,
model_config={"max_seq_len": 2},
name="TestTFAttentionNet",
num_transformer_units=2,
attn_dim=D_in,
num_heads=2,
memory_tau=L,
head_dim=D_out,
ff_hidden_dim=16,
init_gate_bias=2.0)
model = attention_net.trxl_model
# Get initial state and add a batch dimension.
init_state = attention_net.get_initial_state()
init_state = [np.tile(s, (B, 1, 1)) for s in init_state]
self.train_tf_model(
model, [x] + init_state,
[y, value_labels, memory_labels, mlp_labels],
num_epochs=200,
minibatch_size=B)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))


@ -8,14 +8,17 @@
Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019.
https://www.aclweb.org/anthology/P19-1285.pdf
"""
from gym.spaces import Box
import numpy as np
import gym
from typing import Optional, Any
from typing import Any, Optional
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
SkipConnection
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
@ -60,7 +63,7 @@ class TrXLNet(RecurrentNetwork):
model_config: ModelConfigDict, name: str,
num_transformer_units: int, attn_dim: int, num_heads: int,
head_dim: int, ff_hidden_dim: int):
"""Initializes a TfXLNet object.
"""Initializes a TrXLNet object.
Args:
num_transformer_units (int): The number of Transformer repeats to
@ -88,8 +91,6 @@ class TrXLNet(RecurrentNetwork):
self.max_seq_len = model_config["max_seq_len"]
self.obs_dim = observation_space.shape[0]
pos_embedding = relative_position_embedding(self.max_seq_len, attn_dim)
inputs = tf.keras.layers.Input(
shape=(self.max_seq_len, self.obs_dim), name="inputs")
E_out = tf.keras.layers.Dense(attn_dim)(inputs)
@ -100,7 +101,6 @@ class TrXLNet(RecurrentNetwork):
out_dim=attn_dim,
num_heads=num_heads,
head_dim=head_dim,
rel_pos_encoder=pos_embedding,
input_layernorm=False,
output_activation=None),
fan_in_layer=None)(E_out)
@ -160,7 +160,8 @@ class GTrXLNet(RecurrentNetwork):
>> num_transformer_units=1,
>> attn_dim=32,
>> num_heads=2,
>> memory_tau=50,
>> memory_inference=100,
>> memory_training=50,
>> etc..
>> }
"""
@ -174,11 +175,12 @@ class GTrXLNet(RecurrentNetwork):
num_transformer_units: int,
attn_dim: int,
num_heads: int,
memory_tau: int,
memory_inference: int,
memory_training: int,
head_dim: int,
ff_hidden_dim: int,
init_gate_bias: float = 2.0):
"""Initializes a GTrXLNet.
"""Initializes a GTrXLNet instance.
Args:
num_transformer_units (int): The number of Transformer repeats to
@ -187,9 +189,15 @@ class GTrXLNet(RecurrentNetwork):
unit.
num_heads (int): The number of attention heads to use in parallel.
Denoted as `H` in [3].
memory_tau (int): The number of timesteps to store in each
transformer block's memory M (concat'd over time and fed into
next transformer block as input).
memory_inference (int): The number of timesteps to concat (time
axis) and feed into the next transformer unit as inference
input. The first transformer unit will receive this number of
past observations (plus the current one), instead.
memory_training (int): The number of timesteps to concat (time
axis) and feed into the next transformer unit as training
input (plus the actual input sequence of len=max_seq_len).
The first transformer unit will receive this number of
past observations (plus the input sequence), instead.
head_dim (int): The dimension of a single(!) head.
Denoted as `d` in [3].
ff_hidden_dim (int): The dimension of the hidden layer within
@ -208,21 +216,18 @@ class GTrXLNet(RecurrentNetwork):
self.num_transformer_units = num_transformer_units
self.attn_dim = attn_dim
self.num_heads = num_heads
self.memory_tau = memory_tau
self.memory_inference = memory_inference
self.memory_training = memory_training
self.head_dim = head_dim
self.max_seq_len = model_config["max_seq_len"]
self.obs_dim = observation_space.shape[0]
# Constant (non-trainable) sinusoid rel pos encoding matrix.
Phi = relative_position_embedding(self.max_seq_len + self.memory_tau,
self.attn_dim)
# Raw observation input.
# Raw observation input (plus (None) time axis).
input_layer = tf.keras.layers.Input(
shape=(self.max_seq_len, self.obs_dim), name="inputs")
shape=(None, self.obs_dim), name="inputs")
memory_ins = [
tf.keras.layers.Input(
shape=(self.memory_tau, self.attn_dim),
shape=(None, self.attn_dim),
dtype=tf.float32,
name="memory_in_{}".format(i))
for i in range(self.num_transformer_units)
@ -242,7 +247,6 @@ class GTrXLNet(RecurrentNetwork):
out_dim=self.attn_dim,
num_heads=num_heads,
head_dim=head_dim,
rel_pos_encoder=Phi,
input_layernorm=True,
output_activation=tf.nn.relu),
fan_in_layer=GRUGate(init_gate_bias),
@ -280,69 +284,52 @@ class GTrXLNet(RecurrentNetwork):
self.register_variables(self.trxl_model.variables)
self.trxl_model.summary()
@override(RecurrentNetwork)
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
# To make Attention work with current RLlib's ModelV2 API:
# We assume `state` is the history of L recent observations (all
# concatenated into one tensor) and append the current inputs to the
# end and only keep the most recent (up to `max_seq_len`). This allows
# us to deal with timestep-wise inference and full sequence training
# within the same logic.
observations = state[0]
memory = state[1:]
# Setup inference view (`memory-inference` x past observations +
# current one (0))
# 1 to `num_transformer_units`: Memory data (one per transformer unit).
for i in range(self.num_transformer_units):
space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
self.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift="-{}:-1".format(self.memory_inference),
# Repeat the incoming state every max-seq-len times.
batch_repeat_value=self.max_seq_len,
space=space)
self.inference_view_requirements["state_out_{}".format(i)] = \
ViewRequirement(
space=space,
used_for_training=False)
observations = tf.concat(
(observations, inputs), axis=1)[:, -self.max_seq_len:]
all_out = self.trxl_model([observations] + memory)
logits, self._value_out = all_out[0], all_out[1]
@override(ModelV2)
def forward(self, input_dict, state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
assert seq_lens is not None
# Add the time dim to observations.
B = tf.shape(seq_lens)[0]
observations = input_dict[SampleBatch.OBS]
shape = tf.shape(observations)
T = shape[0] // B
observations = tf.reshape(observations,
tf.concat([[-1, T], shape[1:]], axis=0))
all_out = self.trxl_model([observations] + state)
logits = all_out[0]
self._value_out = all_out[1]
memory_outs = all_out[2:]
# If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
if self.memory_tau > self.max_seq_len:
memory_outs = [
tf.concat(
[memory[i][:, -(self.memory_tau - self.max_seq_len):], m],
axis=1) for i, m in enumerate(memory_outs)
]
else:
memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]
T = tf.shape(inputs)[1] # Length of input segment (time).
logits = logits[:, -T:]
self._value_out = self._value_out[:, -T:]
return logits, [observations] + memory_outs
return tf.reshape(logits, [-1, self.num_outputs]), [
tf.reshape(m, [-1, self.attn_dim]) for m in memory_outs
]
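A NumPy sketch of the time-axis fold performed in the new `forward()` (illustrative shapes only):

import numpy as np

flat = np.arange(12).reshape(6, 2)    # [B*T, obs_dim], here B*T=6, obs_dim=2
B = 2                                 # batch size, as read from seq_lens
T = flat.shape[0] // B                # inferred time dimension: 3
folded = flat.reshape((-1, T) + flat.shape[1:])
assert folded.shape == (2, 3, 2)      # [B, T, obs_dim]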
# TODO: (sven) Deprecate this once trajectory view API has fully matured.
@override(RecurrentNetwork)
def get_initial_state(self) -> List[np.ndarray]:
# State is the T last observations concat'd together into one Tensor.
# Plus all Transformer blocks' E(l) outputs concat'd together (up to
# tau timesteps).
return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \
[np.zeros((self.memory_tau, self.attn_dim), np.float32)
for _ in range(self.num_transformer_units)]
return []
@override(ModelV2)
def value_function(self) -> TensorType:
return tf.reshape(self._value_out, [-1])
def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType:
"""Creates a [seq_length x seq_length] matrix for rel. pos encoding.
Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
matrix.
Args:
seq_length (int): The max. sequence length (time axis).
out_dim (int): The number of nodes to go into the first Transformer
layer with.
Returns:
tf.Tensor: The encoding matrix Phi.
"""
inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
pos_offsets = tf.range(seq_length - 1., -1., -1.)
inputs = pos_offsets[:, None] * inverse_freq[None, :]
return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)
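A self-contained NumPy re-statement of the same sinusoid matrix, as a sketch for checking shapes:

import numpy as np

def rel_pos_embedding_np(seq_length, out_dim):
    inverse_freq = 1.0 / (10000 ** (np.arange(0, out_dim, 2.0) / out_dim))
    pos_offsets = np.arange(seq_length - 1.0, -1.0, -1.0)
    inputs = pos_offsets[:, None] * inverse_freq[None, :]
    return np.concatenate([np.sin(inputs), np.cos(inputs)], axis=-1)

# One row per position offset (ordered seq_length-1 .. 0), out_dim columns.
assert rel_pos_embedding_np(5, 8).shape == (5, 8)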


@ -1,11 +1,11 @@
from ray.rllib.models.tf.layers.gru_gate import GRUGate
from ray.rllib.models.tf.layers.noisy_layer import NoisyLayer
from ray.rllib.models.tf.layers.relative_multi_head_attention import \
RelativeMultiHeadAttention
PositionalEmbedding, RelativeMultiHeadAttention
from ray.rllib.models.tf.layers.skip_connection import SkipConnection
from ray.rllib.models.tf.layers.multi_head_attention import MultiHeadAttention
__all__ = [
"GRUGate", "MultiHeadAttention", "NoisyLayer",
"GRUGate", "MultiHeadAttention", "NoisyLayer", "PositionalEmbedding",
"RelativeMultiHeadAttention", "SkipConnection"
]


@ -1,4 +1,4 @@
from typing import Optional, Any
from typing import Optional
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
@ -16,9 +16,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
out_dim: int,
num_heads: int,
head_dim: int,
rel_pos_encoder: Any,
input_layernorm: bool = False,
output_activation: Optional[Any] = None,
output_activation: Optional["tf.nn.activation"] = None,
**kwargs):
"""Initializes a RelativeMultiHeadAttention keras Layer object.
@ -28,7 +27,6 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
Denoted `H` in [2].
head_dim (int): The dimension of a single(!) attention head
Denoted `D` in [2].
rel_pos_encoder (:
input_layernorm (bool): Whether to prepend a LayerNorm before
everything else. Should be True for building a GTrXL.
output_activation (Optional[tf.nn.activation]): Optional tf.nn
@ -50,9 +48,14 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
self._uvar = self.add_weight(shape=(num_heads, head_dim))
self._vvar = self.add_weight(shape=(num_heads, head_dim))
# Constant (non-trainable) sinusoid rel pos encoding matrix, which
# depends on this incoming time dimension.
# For inference, we prepend the memory to the current timestep's
# input: Tau + 1. For training, we prepend the memory to the input
# sequence: Tau + T.
self._pos_embedding = PositionalEmbedding(out_dim)
self._pos_proj = tf.keras.layers.Dense(
num_heads * head_dim, use_bias=False)
self._rel_pos_encoder = rel_pos_encoder
self._input_layernorm = None
if input_layernorm:
@ -66,9 +69,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
# Add previous memory chunk (as const, w/o gradient) to input.
# Tau (number of (prev) time slices in each memory chunk).
Tau = memory.shape.as_list()[1] if memory is not None else 0
if memory is not None:
inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1)
Tau = tf.shape(memory)[1]
inputs = tf.concat([tf.stop_gradient(memory), inputs], axis=1)
# Apply the Layer-Norm.
if self._input_layernorm is not None:
@ -77,15 +79,17 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
qkv = self._qkv_layer(inputs)
queries, keys, values = tf.split(qkv, 3, -1)
# Cut out Tau memory timesteps from query.
# Cut out memory timesteps from query.
queries = queries[:, -T:]
# Splitting up queries into per-head dims (d).
queries = tf.reshape(queries, [-1, T, H, d])
keys = tf.reshape(keys, [-1, T + Tau, H, d])
values = tf.reshape(values, [-1, T + Tau, H, d])
keys = tf.reshape(keys, [-1, Tau + T, H, d])
values = tf.reshape(values, [-1, Tau + T, H, d])
R = self._pos_proj(self._rel_pos_encoder)
R = tf.reshape(R, [T + Tau, H, d])
R = self._pos_embedding(Tau + T)
R = self._pos_proj(R)
R = tf.reshape(R, [Tau + T, H, d])
# b=batch
# i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
@ -96,9 +100,9 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
score = score + self.rel_shift(pos_score)
score = score / d**0.5
# causal mask of the same length as the sequence
# Causal mask of the same length as the sequence.
mask = tf.sequence_mask(
tf.range(Tau + 1, T + Tau + 1), dtype=score.dtype)
tf.range(Tau + 1, Tau + T + 1), dtype=score.dtype)
mask = mask[None, :, :, None]
masked_score = score * mask + 1e30 * (mask - 1.)
@ -121,3 +125,14 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object):
x = tf.reshape(x, x_size)
return x
class PositionalEmbedding(tf.keras.layers.Layer if tf else object):
def __init__(self, out_dim, **kwargs):
super().__init__(**kwargs)
self.inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
def call(self, seq_length):
pos_offsets = tf.cast(tf.range(seq_length - 1, -1, -1), tf.float32)
inputs = pos_offsets[:, None] * self.inverse_freq[None, :]
return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)
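For intuition, a NumPy sketch of the causal mask built in `call()` above, with Tau=2 memory timesteps and T=3 query timesteps (assuming tf.sequence_mask's default maxlen, the largest length):

import numpy as np

Tau, T = 2, 3
lengths = np.arange(Tau + 1, Tau + T + 1)    # [3 4 5]
mask = (np.arange(Tau + T)[None, :] < lengths[:, None]).astype(np.float32)
print(mask)
# [[1. 1. 1. 0. 0.]   query step 0: all memory steps plus itself
#  [1. 1. 1. 1. 0.]   query step 1: ... plus step 1
#  [1. 1. 1. 1. 1.]]  query step 2: everything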


@ -16,7 +16,6 @@ class SkipConnection(tf.keras.layers.Layer if tf else object):
def __init__(self,
layer: Any,
fan_in_layer: Optional[Any] = None,
add_memory: bool = False,
**kwargs):
"""Initializes a SkipConnection keras layer object.


@ -15,7 +15,6 @@ class SkipConnection(nn.Module):
def __init__(self,
layer: nn.Module,
fan_in_layer: Optional[nn.Module] = None,
add_memory: bool = False,
**kwargs):
"""Initializes a SkipConnection nn Module object.


@ -183,11 +183,12 @@ class DynamicTFPolicy(TFPolicy):
else:
if self.config["_use_trajectory_view_api"]:
self._state_inputs = [
tf1.placeholder(
shape=(None, ) + vr.space.shape, dtype=vr.space.dtype)
for k, vr in
get_placeholder(
space=vr.space,
time_axis=not isinstance(vr.shift, int),
) for k, vr in
self.model.inference_view_requirements.items()
if k[:9] == "state_in_"
if k.startswith("state_in_")
]
else:
self._state_inputs = [
@ -423,9 +424,14 @@ class DynamicTFPolicy(TFPolicy):
input_dict[view_col] = existing_inputs[view_col]
# All others.
else:
time_axis = not isinstance(view_req.shift, int)
if view_req.used_for_training:
# Create a +time-axis placeholder if the shift is not an
# int (range or list of ints).
input_dict[view_col] = get_placeholder(
space=view_req.space, name=view_col)
space=view_req.space,
name=view_col,
time_axis=time_axis)
dummy_batch = self._get_dummy_batch_from_view_requirements(
batch_size=32)
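A hedged sketch of the time-axis rule (hypothetical helper, not RLlib's `get_placeholder`; TF1 graph mode assumed): a ranged or listed shift gets a [batch, time, ...] placeholder, a plain int shift only [batch, ...].

import tensorflow as tf

def make_placeholder(space_shape, shift, name):
    if isinstance(shift, int):
        shape = (None, ) + space_shape        # [batch, ...]
    else:
        shape = (None, None) + space_shape    # [batch, time, ...]
    return tf.compat.v1.placeholder(tf.float32, shape=shape, name=name)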
@ -490,10 +496,10 @@ class DynamicTFPolicy(TFPolicy):
dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
for k, v in self.extra_compute_action_fetches().items():
dummy_batch[k] = fake_array(v)
dummy_batch = SampleBatch(dummy_batch)
sb = SampleBatch(dummy_batch)
batch_for_postproc = UsageTrackingDict(sb)
batch_for_postproc.count = sb.count
batch_for_postproc = UsageTrackingDict(dummy_batch)
batch_for_postproc.count = dummy_batch.count
logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
self.exploration.postprocess_trajectory(self, batch_for_postproc,
self._sess)
@ -519,6 +525,7 @@ class DynamicTFPolicy(TFPolicy):
train_batch.update({
SampleBatch.PREV_ACTIONS: self._prev_action_input,
SampleBatch.PREV_REWARDS: self._prev_reward_input,
SampleBatch.CUR_OBS: self._obs_input,
})
for k, v in postprocessed_batch.items():
@ -578,7 +585,8 @@ class DynamicTFPolicy(TFPolicy):
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys and \
key not in self.model.inference_view_requirements:
self.view_requirements[key].used_for_training = False
if key in self.view_requirements:
self.view_requirements[key].used_for_training = False
if key in self._loss_input_dict:
del self._loss_input_dict[key]
# Remove those not needed at all (leave those that are needed


@ -314,12 +314,16 @@ def build_eager_tf_policy(name,
self.callbacks.on_learn_on_batch(
policy=self, train_batch=postprocessed_batch)
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
shuffle=False,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self.batch_divisibility_req)
batch_divisibility_req=self.batch_divisibility_req,
view_requirements=self.view_requirements,
)
self._is_training = True
postprocessed_batch["is_training"] = True
return self._learn_on_batch_eager(postprocessed_batch)
@convert_eager_inputs
@ -332,12 +336,14 @@ def build_eager_tf_policy(name,
@override(Policy)
def compute_gradients(self, samples):
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
samples,
shuffle=False,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self.batch_divisibility_req)
self._is_training = True
samples["is_training"] = True
return self._compute_gradients_eager(samples)
@convert_eager_inputs
@ -369,7 +375,7 @@ def build_eager_tf_policy(name,
# TODO: remove python side effect to cull sources of bugs.
self._is_training = False
self._state_in = state_batches
self._state_in = state_batches or []
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
@ -591,8 +597,6 @@ def build_eager_tf_policy(name,
def _compute_gradients(self, samples):
"""Computes and returns grads as eager tensors."""
self._is_training = True
with tf.GradientTape(persistent=gradients_fn is not None) as tape:
loss = loss_fn(self, self.model, self.dist_class, samples)


@ -629,10 +629,9 @@ class Policy(metaclass=ABCMeta):
batch_for_postproc.count = self._dummy_batch.count
self.exploration.postprocess_trajectory(self, batch_for_postproc)
postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
seq_lens = None
if state_outs:
B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size]
# TODO: (sven) This hack will not work for attention net traj.
# view setup.
i = 0
while "state_in_{}".format(i) in postprocessed_batch:
postprocessed_batch["state_in_{}".format(i)] = \
@ -642,12 +641,11 @@ class Policy(metaclass=ABCMeta):
postprocessed_batch["state_out_{}".format(i)][:B]
i += 1
seq_len = sample_batch_size // B
postprocessed_batch["seq_lens"] = \
np.array([seq_len for _ in range(B)], dtype=np.int32)
# Remove the UsageTrackingDict wrap to prep for wrapping the
# train batch with a to-tensor UsageTrackingDict.
train_batch = {k: v for k, v in postprocessed_batch.items()}
train_batch = self._lazy_tensor_dict(train_batch)
seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32)
# Wrap `train_batch` with a to-tensor UsageTrackingDict.
train_batch = self._lazy_tensor_dict(postprocessed_batch)
if seq_lens is not None:
train_batch["seq_lens"] = seq_lens
train_batch.count = self._dummy_batch.count
# Call the loss function, if it exists.
if self._loss is not None:
@ -712,13 +710,33 @@ class Policy(metaclass=ABCMeta):
ret[view_col] = \
np.zeros((batch_size, ) + shape[1:], np.float32)
else:
if isinstance(view_req.space, gym.spaces.Space):
ret[view_col] = np.zeros_like(
[view_req.space.sample() for _ in range(batch_size)])
# Range of indices on time-axis, e.g. "-50:-1".
if view_req.shift_from is not None:
ret[view_col] = np.zeros_like([[
view_req.space.sample()
for _ in range(view_req.shift_to -
view_req.shift_from + 1)
] for _ in range(batch_size)])
# Set of (probably non-consecutive) indices.
elif isinstance(view_req.shift, (list, tuple)):
ret[view_col] = np.zeros_like([[
view_req.space.sample()
for t in range(len(view_req.shift))
] for _ in range(batch_size)])
# Single shift int value.
else:
ret[view_col] = [view_req.space for _ in range(batch_size)]
if isinstance(view_req.space, gym.spaces.Space):
ret[view_col] = np.zeros_like([
view_req.space.sample() for _ in range(batch_size)
])
else:
ret[view_col] = [
view_req.space for _ in range(batch_size)
]
return SampleBatch(ret)
# Due to different view requirements for the different columns,
# columns in the resulting batch may not all have the same batch size.
return SampleBatch(ret, _dont_check_lens=True)
def _update_model_inference_view_requirements_from_init_state(self):
"""Uses Model's (or this Policy's) init state to add needed ViewReqs.
@ -737,8 +755,13 @@ class Policy(metaclass=ABCMeta):
view_reqs = model.inference_view_requirements if model else \
self.view_requirements
view_reqs["state_in_{}".format(i)] = ViewRequirement(
"state_out_{}".format(i), shift=-1, space=space)
view_reqs["state_out_{}".format(i)] = ViewRequirement(space=space)
"state_out_{}".format(i),
shift=-1,
batch_repeat_value=self.config.get("model", {}).get(
"max_seq_len", 1),
space=space)
view_reqs["state_out_{}".format(i)] = ViewRequirement(
space=space, used_for_training=True)
def clip_action(action, action_space):


@ -19,7 +19,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.typing import TensorType, ViewRequirementsDict
from ray.util import log_once
tf1, tf, tfv = try_import_tf()
@ -35,6 +35,7 @@ def pad_batch_to_sequences_of_same_size(
shuffle: bool = False,
batch_divisibility_req: int = 1,
feature_keys: Optional[List[str]] = None,
view_requirements: Optional[ViewRequirementsDict] = None,
):
"""Applies padding to `batch` so it's choppable into same-size sequences.
@ -55,6 +56,9 @@ def pad_batch_to_sequences_of_same_size(
feature_keys (Optional[List[str]]): An optional list of keys to apply
sequence-chopping to. If None, use all keys in batch that are not
"state_in/out_"-type keys.
view_requirements (Optional[ViewRequirementsDict]): An optional
Policy ViewRequirements dict to be able to infer whether
e.g. dynamic max'ing should be applied over the seq_lens.
"""
if batch_divisibility_req > 1:
meets_divisibility_reqs = (
@ -64,46 +68,65 @@ def pad_batch_to_sequences_of_same_size(
else:
meets_divisibility_reqs = True
# RNN-case.
states_already_reduced_to_init = False
# RNN/attention net case. Figure out whether we should apply dynamic
# max'ing over the list of sequence lengths.
if "state_in_0" in batch or "state_out_0" in batch:
dynamic_max = True
# Check, whether the state inputs have already been reduced to their
# init values at the beginning of each max_seq_len chunk.
if batch.seq_lens is not None and \
len(batch["state_in_0"]) == len(batch.seq_lens):
states_already_reduced_to_init = True
# RNN (or single timestep state-in): Set the max dynamically.
if view_requirements["state_in_0"].shift_from is None:
dynamic_max = True
# Attention Nets (state inputs are over some range): No dynamic maxing
# possible.
else:
dynamic_max = False
# Multi-agent case.
elif not meets_divisibility_reqs:
max_seq_len = batch_divisibility_req
dynamic_max = False
# Simple case: not RNN nor do we need to pad.
# Simple case: No RNN/attention net, nor do we need to pad.
else:
if shuffle:
batch.shuffle()
return
# RNN or multi-agent case.
# RNN, attention net, or multi-agent case.
state_keys = []
feature_keys_ = feature_keys or []
for k in batch.keys():
if "state_in_" in k:
for k, v in batch.items():
if k.startswith("state_in_"):
state_keys.append(k)
elif not feature_keys and "state_out_" not in k and k != "infos":
elif not feature_keys and not k.startswith("state_out_") and \
k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray):
feature_keys_.append(k)
feature_sequences, initial_states, seq_lens = \
chop_into_sequences(
batch[SampleBatch.EPS_ID],
batch[SampleBatch.UNROLL_ID],
batch[SampleBatch.AGENT_INDEX],
[batch[k] for k in feature_keys_],
[batch[k] for k in state_keys],
max_seq_len,
feature_columns=[batch[k] for k in feature_keys_],
state_columns=[batch[k] for k in state_keys],
episode_ids=batch.get(SampleBatch.EPS_ID),
unroll_ids=batch.get(SampleBatch.UNROLL_ID),
agent_indices=batch.get(SampleBatch.AGENT_INDEX),
seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
max_seq_len=max_seq_len,
dynamic_max=dynamic_max,
states_already_reduced_to_init=states_already_reduced_to_init,
shuffle=shuffle)
for i, k in enumerate(feature_keys_):
batch[k] = feature_sequences[i]
for i, k in enumerate(state_keys):
batch[k] = initial_states[i]
batch["seq_lens"] = seq_lens
batch["seq_lens"] = np.array(seq_lens)
if log_once("rnn_ma_feed_dict"):
logger.info("Padded input for RNN:\n\n{}\n".format(
logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
summarize({
"features": feature_sequences,
"initial_states": initial_states,
@ -157,18 +180,18 @@ def add_time_dimension(padded_inputs: TensorType,
return torch.reshape(padded_inputs, new_shape)
# NOTE: This function will be deprecated once chunks already come padded and
# correctly chopped from the _SampleCollector object (in time-major fashion
# or not). It is already no longer used if `_use_trajectory_view_api` = True.
@DeveloperAPI
def chop_into_sequences(episode_ids,
unroll_ids,
agent_indices,
def chop_into_sequences(*,
feature_columns,
state_columns,
max_seq_len,
episode_ids=None,
unroll_ids=None,
agent_indices=None,
dynamic_max=True,
shuffle=False,
seq_lens=None,
states_already_reduced_to_init=False,
_extra_padding=0):
"""Truncate and pad experiences into fixed-length sequences.
@ -212,23 +235,24 @@ def chop_into_sequences(episode_ids,
[2, 3, 1]
"""
prev_id = None
seq_lens = []
seq_len = 0
unique_ids = np.add(
np.add(episode_ids, agent_indices),
np.array(unroll_ids, dtype=np.int64) << 32)
for uid in unique_ids:
if (prev_id is not None and uid != prev_id) or \
seq_len >= max_seq_len:
if seq_lens is None or len(seq_lens) == 0:
prev_id = None
seq_lens = []
seq_len = 0
unique_ids = np.add(
np.add(episode_ids, agent_indices),
np.array(unroll_ids, dtype=np.int64) << 32)
for uid in unique_ids:
if (prev_id is not None and uid != prev_id) or \
seq_len >= max_seq_len:
seq_lens.append(seq_len)
seq_len = 0
seq_len += 1
prev_id = uid
if seq_len:
seq_lens.append(seq_len)
seq_len = 0
seq_len += 1
prev_id = uid
if seq_len:
seq_lens.append(seq_len)
assert sum(seq_lens) == len(unique_ids)
seq_lens = np.array(seq_lens, dtype=np.int32)
seq_lens = np.array(seq_lens, dtype=np.int32)
assert sum(seq_lens) == len(feature_columns[0])
# Dynamically shrink max len as needed to optimize memory usage
if dynamic_max:
@ -252,18 +276,23 @@ def chop_into_sequences(episode_ids,
f_pad[seq_base + seq_offset] = f[i]
i += 1
seq_base += max_seq_len
assert i == len(unique_ids), f
assert i == len(f), f
feature_sequences.append(f_pad)
initial_states = []
for s in state_columns:
s = np.array(s)
s_init = []
i = 0
for len_ in seq_lens:
s_init.append(s[i])
i += len_
initial_states.append(np.array(s_init))
if states_already_reduced_to_init:
initial_states = state_columns
else:
initial_states = []
for s in state_columns:
# Skip unnecessary copy.
if not isinstance(s, np.ndarray):
s = np.array(s)
s_init = []
i = 0
for len_ in seq_lens:
s_init.append(s[i])
i += len_
initial_states.append(np.array(s_init))
if shuffle:
permutation = np.random.permutation(len(seq_lens))


@ -61,6 +61,7 @@ class SampleBatch:
# Possible seq_lens (TxB or BxT) setup.
self.time_major = kwargs.pop("_time_major", None)
self.seq_lens = kwargs.pop("_seq_lens", None)
self.dont_check_lens = kwargs.pop("_dont_check_lens", False)
self.max_seq_len = None
if self.seq_lens is not None and len(self.seq_lens) > 0:
self.max_seq_len = max(self.seq_lens)
@ -76,8 +77,10 @@ class SampleBatch:
self.data[k] = np.array(v)
if not lengths:
raise ValueError("Empty sample batch")
assert len(set(lengths)) == 1, \
"Data columns must be same length, but lens are {}".format(lengths)
if not self.dont_check_lens:
assert len(set(lengths)) == 1, \
"Data columns must be same length, but lens are " \
"{}".format(lengths)
if self.seq_lens is not None and len(self.seq_lens) > 0:
self.count = sum(self.seq_lens)
else:
@ -117,7 +120,8 @@ class SampleBatch:
return SampleBatch(
out,
_seq_lens=np.array(seq_lens, dtype=np.int32),
_time_major=concat_samples[0].time_major)
_time_major=concat_samples[0].time_major,
_dont_check_lens=True)
@PublicAPI
def concat(self, other: "SampleBatch") -> "SampleBatch":
@ -248,12 +252,35 @@ class SampleBatch:
SampleBatch: A new SampleBatch, which has a slice of this batch's
data.
"""
if self.time_major is not None:
if self.seq_lens is not None and len(self.seq_lens) > 0:
data = {k: v[start:end] for k, v in self.data.items()}
# Fix state_in_x data.
count = 0
state_start = None
seq_lens = None
for i, seq_len in enumerate(self.seq_lens):
count += seq_len
if count >= end:
state_idx = 0
state_key = "state_in_{}".format(state_idx)
while state_key in self.data:
data[state_key] = self.data[state_key][state_start:i +
1]
state_idx += 1
state_key = "state_in_{}".format(state_idx)
seq_lens = list(self.seq_lens[state_start:i]) + [
seq_len - (count - end)
]
assert sum(seq_lens) == (end - start)
break
elif state_start is None and count > start:
state_start = i
return SampleBatch(
{k: v[:, start:end]
for k, v in self.data.items()},
_seq_lens=self.seq_lens[start:end],
_time_major=self.time_major)
data,
_seq_lens=np.array(seq_lens, dtype=np.int32),
_time_major=self.time_major,
_dont_check_lens=True)
else:
return SampleBatch(
{k: v[start:end]


@ -174,11 +174,6 @@ class TFPolicy(Policy):
raise ValueError(
"Number of state input and output tensors must match, got: "
"{} vs {}".format(self._state_inputs, self._state_outputs))
if len(self.get_initial_state()) != len(self._state_inputs):
raise ValueError(
"Length of initial state must match number of state inputs, "
"got: {} vs {}".format(self.get_initial_state(),
self._state_inputs))
if self._state_inputs and self._seq_lens is None:
raise ValueError(
"seq_lens tensor must be given if state inputs are defined")
@ -263,6 +258,11 @@ class TFPolicy(Policy):
(name, tf1.placeholders) needed for calculating the loss.
"""
self._loss_input_dict = dict(loss_inputs)
self._loss_input_dict_no_rnn = {
k: v
for k, v in self._loss_input_dict.items()
if (v not in self._state_inputs and v != self._seq_lens)
}
for i, ph in enumerate(self._state_inputs):
self._loss_input_dict["state_in_{}".format(i)] = ph
@ -791,11 +791,11 @@ class TFPolicy(Policy):
**fetches[LEARNER_STATS_KEY])
return fetches
def _get_loss_inputs_dict(self, batch, shuffle):
def _get_loss_inputs_dict(self, train_batch, shuffle):
"""Return a feed dict from a batch.
Args:
batch (SampleBatch): batch of data to derive inputs from
train_batch (SampleBatch): batch of data to derive inputs from.
shuffle (bool): whether to shuffle batch sequences. Shuffle may
be done in-place. This only makes sense if you're further
applying minibatch SGD after getting the outputs.
@ -806,28 +806,30 @@ class TFPolicy(Policy):
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
batch,
train_batch,
shuffle=shuffle,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self._batch_divisibility_req,
feature_keys=[
k for k in self._loss_input_dict.keys() if k != "seq_lens"
],
feature_keys=list(self._loss_input_dict_no_rnn.keys()),
view_requirements=self.view_requirements,
)
batch["is_training"] = True
# Mark the batch as "is_training" so the Model can use this
# information.
train_batch["is_training"] = True
# Build the feed dict from the batch.
feed_dict = {}
for key, placeholder in self._loss_input_dict.items():
feed_dict[placeholder] = batch[key]
feed_dict[placeholder] = train_batch[key]
state_keys = [
"state_in_{}".format(i) for i in range(len(self._state_inputs))
]
for key in state_keys:
feed_dict[self._loss_input_dict[key]] = batch[key]
feed_dict[self._loss_input_dict[key]] = train_batch[key]
if state_keys:
feed_dict[self._seq_lens] = batch["seq_lens"]
feed_dict[self._seq_lens] = train_batch["seq_lens"]
return feed_dict


@ -345,12 +345,13 @@ class TorchPolicy(Policy):
@DeveloperAPI
def compute_gradients(self,
postprocessed_batch: SampleBatch) -> ModelGradients:
# Get batch ready for RNNs, if applicable.
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
max_seq_len=self.max_seq_len,
shuffle=False,
batch_divisibility_req=self.batch_divisibility_req,
view_requirements=self.view_requirements,
)
train_batch = self._lazy_tensor_dict(postprocessed_batch)


@ -1,4 +1,5 @@
import gym
import numpy as np
from typing import List, Optional, Union
from ray.rllib.utils.framework import try_import_torch
@ -29,8 +30,9 @@ class ViewRequirement:
def __init__(self,
data_col: Optional[str] = None,
space: gym.Space = None,
shift: Union[int, List[int]] = 0,
shift: Union[int, str, List[int]] = 0,
index: Optional[int] = None,
batch_repeat_value: int = 1,
used_for_training: bool = True):
"""Initializes a ViewRequirement object.
@ -64,7 +66,19 @@ class ViewRequirement:
self.space = space if space is not None else gym.spaces.Box(
float("-inf"), float("inf"), shape=())
self.index = index
self.shift = shift
if isinstance(self.shift, (list, tuple)):
self.shift = np.array(self.shift)
# Special case: Providing a (probably larger) range of indices, e.g.
# "-100:0" (past 100 timesteps plus current one).
self.shift_from = self.shift_to = None
if isinstance(self.shift, str):
f, t = self.shift.split(":")
self.shift_from = int(f)
self.shift_to = int(t)
self.index = index
self.batch_repeat_value = batch_repeat_value
self.used_for_training = used_for_training
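Usage sketch of the three accepted `shift` forms after this change (assuming the usual ray.rllib.policy.view_requirement module path):

from ray.rllib.policy.view_requirement import ViewRequirement

vr_int = ViewRequirement(shift=-1)              # single relative index
vr_list = ViewRequirement(shift=[-4, -1, 0])    # index set, becomes np.ndarray
vr_range = ViewRequirement(shift="-100:0")      # inclusive index range
assert vr_range.shift_from == -100 and vr_range.shift_to == 0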


@ -44,7 +44,8 @@ class TestAttentionNetLearning(unittest.TestCase):
"num_transformer_units": 1,
"attn_dim": 32,
"num_heads": 1,
"memory_tau": 5,
"memory_inference": 5,
"memory_training": 5,
"head_dim": 32,
"ff_hidden_dim": 32,
},
@ -71,7 +72,8 @@ class TestAttentionNetLearning(unittest.TestCase):
# "num_transformer_units": 1,
# "attn_dim": 64,
# "num_heads": 1,
# "memory_tau": 10,
# "memory_inference": 10,
# "memory_training": 10,
# "head_dim": 32,
# "ff_hidden_dim": 32,
# },


@ -18,9 +18,13 @@ class TestLSTMUtils(unittest.TestCase):
f = [[101, 102, 103, 201, 202, 203, 204, 205],
[[101], [102], [103], [201], [202], [203], [204], [205]]]
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
np.ones_like(eps_ids),
agent_ids, f, s, 4)
f_pad, s_init, seq_lens = chop_into_sequences(
episode_ids=eps_ids,
unroll_ids=np.ones_like(eps_ids),
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4)
self.assertEqual([f.tolist() for f in f_pad], [
[101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
[[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
@ -35,9 +39,13 @@ class TestLSTMUtils(unittest.TestCase):
obs = np.ones((84, 84, 4))
f = [[obs, obs * 2, obs * 3]]
s = [[209, 208, 207]]
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
np.ones_like(eps_ids),
agent_ids, f, s, 4)
f_pad, s_init, seq_lens = chop_into_sequences(
episode_ids=eps_ids,
unroll_ids=np.ones_like(eps_ids),
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4)
self.assertEqual([f.tolist() for f in f_pad], [
np.array([obs, obs * 2, obs * 3]).tolist(),
])
@ -51,8 +59,13 @@ class TestLSTMUtils(unittest.TestCase):
f = [[101, 102, 103, 201, 202, 203, 204, 205],
[[101], [102], [103], [201], [202], [203], [204], [205]]]
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
_, _, seq_lens = chop_into_sequences(eps_ids, batch_ids, agent_ids, f,
s, 4)
_, _, seq_lens = chop_into_sequences(
episode_ids=eps_ids,
unroll_ids=batch_ids,
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4)
self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2])
def test_multi_agent(self):
@ -62,12 +75,12 @@ class TestLSTMUtils(unittest.TestCase):
[[101], [102], [103], [201], [202], [203], [204], [205]]]
s = [[209, 208, 207, 109, 108, 107, 106, 105]]
f_pad, s_init, seq_lens = chop_into_sequences(
eps_ids,
np.ones_like(eps_ids),
agent_ids,
f,
s,
4,
episode_ids=eps_ids,
unroll_ids=np.ones_like(eps_ids),
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4,
dynamic_max=False)
self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
self.assertEqual(len(f_pad[0]), 20)
@ -78,9 +91,13 @@ class TestLSTMUtils(unittest.TestCase):
agent_ids = [2, 2, 2]
f = [[1, 1, 1]]
s = [[1, 1, 1]]
f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
np.ones_like(eps_ids),
agent_ids, f, s, 4)
f_pad, s_init, seq_lens = chop_into_sequences(
episode_ids=eps_ids,
unroll_ids=np.ones_like(eps_ids),
agent_indices=agent_ids,
feature_columns=f,
state_columns=s,
max_seq_len=4)
self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
self.assertEqual(seq_lens.tolist(), [1, 2])


@ -72,18 +72,23 @@ def minibatches(samples, sgd_minibatch_size):
i = 0
slices = []
if samples.seq_lens:
seq_no = 0
while i < samples.count:
seq_no_end = seq_no
actual_count = 0
while actual_count < sgd_minibatch_size and len(
samples.seq_lens) > seq_no_end:
actual_count += samples.seq_lens[seq_no_end]
seq_no_end += 1
slices.append((seq_no, seq_no_end))
i += actual_count
seq_no = seq_no_end
if samples.seq_lens is not None and len(samples.seq_lens) > 0:
start_pos = 0
minibatch_size = 0
idx = 0
while idx < len(samples.seq_lens):
seq_len = samples.seq_lens[idx]
minibatch_size += seq_len
# Complete minibatch -> Append to slices.
if minibatch_size >= sgd_minibatch_size:
slices.append((start_pos, start_pos + sgd_minibatch_size))
start_pos += sgd_minibatch_size
if minibatch_size > sgd_minibatch_size:
overhead = minibatch_size - sgd_minibatch_size
start_pos -= (seq_len - overhead)
idx -= 1
minibatch_size = 0
idx += 1
else:
while i < samples.count:
slices.append((i, i + sgd_minibatch_size))
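A standalone trace of the new sequence-aware slicing above (illustrative): every slice starts at a sequence boundary, a sequence cut by a minibatch border is re-read from its start by the next slice, and a trailing remainder that never fills a minibatch is dropped.

def seq_minibatch_slices(seq_lens, sgd_minibatch_size):
    # Same loop as above, extracted for illustration.
    slices, start_pos, minibatch_size, idx = [], 0, 0, 0
    while idx < len(seq_lens):
        seq_len = seq_lens[idx]
        minibatch_size += seq_len
        if minibatch_size >= sgd_minibatch_size:
            slices.append((start_pos, start_pos + sgd_minibatch_size))
            start_pos += sgd_minibatch_size
            if minibatch_size > sgd_minibatch_size:
                overhead = minibatch_size - sgd_minibatch_size
                start_pos -= (seq_len - overhead)
                idx -= 1
            minibatch_size = 0
        idx += 1
    return slices

print(seq_minibatch_slices([4, 4, 4], 6))    # [(0, 6), (4, 10)]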


@ -100,6 +100,9 @@ ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
# Type of dict returned by get_weights() representing model weights.
ModelWeights = dict
# An input dict used for direct ModelV2 calls or `ModelV2.from_batch` calls.
ModelInputDict = Dict[str, TensorType]
# Some kind of sample batch.
SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]