"""
|
|
[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar,
|
|
Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017.
|
|
https://arxiv.org/pdf/1706.03762.pdf
|
|
[2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto
|
|
et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf
|
|
[3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
|
|
Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019.
|
|
https://www.aclweb.org/anthology/P19-1285.pdf
|
|
"""
|
|
import numpy as np
import gym
from gym.spaces import Box

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.modules import GRUGate, \
    RelativeMultiHeadAttention, SkipConnection
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List

torch, nn = try_import_torch()


class GTrXLNet(RecurrentNetwork, nn.Module):
    """A GTrXL net Model described in [2].

    This is still in an experimental phase. It can be used as a drop-in
    replacement for LSTMs in PPO and IMPALA.
    For an example script, see: `ray/rllib/examples/attention_net.py`.

    To use this network as a replacement for an RNN, configure your Trainer
    as follows:

    Examples:
        >> config["model"]["custom_model"] = GTrXLNet
        >> config["model"]["max_seq_len"] = 10
        >> config["model"]["custom_model_config"] = {
        >>     "num_transformer_units": 1,
        >>     "attn_dim": 32,
        >>     "num_heads": 2,
        >>     "memory_inference": 50,
        >>     "memory_training": 50,
        >>     "head_dim": 16,
        >>     "ff_hidden_dim": 32,
        >> }
    """

    def __init__(self,
                 observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 name: str,
                 num_transformer_units: int,
                 attn_dim: int,
                 num_heads: int,
                 memory_inference: int,
                 memory_training: int,
                 head_dim: int,
                 ff_hidden_dim: int,
                 init_gate_bias: float = 2.0):
        """Initializes a GTrXLNet.

        Args:
            num_transformer_units (int): The number of Transformer repeats to
                use (denoted L in [2]).
            attn_dim (int): The input and output dimensions of one Transformer
                unit.
            num_heads (int): The number of attention heads to use in parallel.
                Denoted as `H` in [3].
            memory_inference (int): The number of timesteps to concat (time
                axis) and feed into the next transformer unit as inference
                input. The first transformer unit will receive this number of
                past observations (plus the current one), instead.
            memory_training (int): The number of timesteps to concat (time
                axis) and feed into the next transformer unit as training
                input (plus the actual input sequence of len=max_seq_len).
                The first transformer unit will receive this number of
                past observations (plus the input sequence), instead.
            head_dim (int): The dimension of a single(!) head.
                Denoted as `d` in [3].
            ff_hidden_dim (int): The dimension of the hidden layer within
                the position-wise MLP (after the multi-head attention block
                within one Transformer unit). This is the size of the first
                of the two layers within the PositionwiseFeedforward. The
                second layer always has size=`attn_dim`.
            init_gate_bias (float): Initial bias values for the GRU gates
                (two GRUs per Transformer unit, one after the MHA, one after
                the position-wise MLP).
        """

        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        nn.Module.__init__(self)

        self.num_transformer_units = num_transformer_units
        self.attn_dim = attn_dim
        self.num_heads = num_heads
        self.memory_inference = memory_inference
        self.memory_training = memory_training
        self.head_dim = head_dim
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = observation_space.shape[0]

        # 1) Map the (flat) observations into the Transformers' input/output
        # dimension (attn_dim).
        self.linear_layer = SlimFC(
            in_size=self.obs_dim, out_size=self.attn_dim)

        self.layers = [self.linear_layer]

        attention_layers = []
        # 2) Create L Transformer blocks according to [2].
        for i in range(self.num_transformer_units):
            # RelativeMultiHeadAttention part.
            MHA_layer = SkipConnection(
                RelativeMultiHeadAttention(
                    in_dim=self.attn_dim,
                    out_dim=self.attn_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    input_layernorm=True,
                    output_activation=nn.ReLU),
                fan_in_layer=GRUGate(self.attn_dim, init_gate_bias))

            # Position-wise MultiLayerPerceptron part.
            E_layer = SkipConnection(
                nn.Sequential(
                    torch.nn.LayerNorm(self.attn_dim),
                    SlimFC(
                        in_size=self.attn_dim,
                        out_size=ff_hidden_dim,
                        use_bias=False,
                        activation_fn=nn.ReLU),
                    SlimFC(
                        in_size=ff_hidden_dim,
                        out_size=self.attn_dim,
                        use_bias=False,
                        activation_fn=nn.ReLU)),
                fan_in_layer=GRUGate(self.attn_dim, init_gate_bias))

            # Build a list of all attention layers in order.
            attention_layers.extend([MHA_layer, E_layer])

        # Create a Sequential such that all parameters inside the attention
        # layers are automatically registered with this top-level model.
        self.attention_layers = nn.Sequential(*attention_layers)
        self.layers.extend(attention_layers)
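
        # The resulting ordering of self.layers for L Transformer units is:
        # [linear_layer, MHA_1, MLP_1, ..., MHA_L, MLP_L]. forward() below
        # relies on this alternation: odd indices are the MHA layers, which
        # take the per-unit memory inputs.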

        # Postprocess GTrXL output with another hidden layer.
        self.logits = SlimFC(
            in_size=self.attn_dim,
            out_size=self.num_outputs,
            activation_fn=nn.ReLU)

        # Value function used by all RLlib Torch RL implementations.
        self._value_out = None
        self.values_out = SlimFC(
            in_size=self.attn_dim, out_size=1, activation_fn=None)

        # Setup inference views (`memory_inference` x past observations +
        # the current one (0)).
        # 1 to `num_transformer_units`: Memory data (one per transformer
        # unit).
        for i in range(self.num_transformer_units):
            space = Box(-1.0, 1.0, shape=(self.attn_dim, ))
            self.inference_view_requirements["state_in_{}".format(i)] = \
                ViewRequirement(
                    "state_out_{}".format(i),
                    shift="-{}:-1".format(self.memory_inference),
                    # Repeat the incoming state every max-seq-len times.
                    batch_repeat_value=self.max_seq_len,
                    space=space)
            self.inference_view_requirements["state_out_{}".format(i)] = \
                ViewRequirement(
                    space=space,
                    used_for_training=False)
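
        # As a concrete illustration of the views above (the values are
        # examples, not defaults): with memory_inference=50 and
        # max_seq_len=10, each "state_in_{i}" is built from the last 50
        # "state_out_{i}" rows (shift "-50:-1") and repeated once every 10
        # timesteps, so each transformer unit attends over its own past
        # activations as memory.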

    @override(ModelV2)
    def forward(self, input_dict, state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        assert seq_lens is not None

        observations = input_dict[SampleBatch.OBS]
        # Add the time dim to observations: fold the flat batch
        # (B*T, obs_dim) into (B, T, obs_dim).
        B = len(seq_lens)
        T = observations.shape[0] // B
        observations = torch.reshape(observations,
                                     [-1, T] + list(observations.shape[1:]))

        all_out = observations
        memory_outs = []
        for i in range(len(self.layers)):
            # MHA layers, which need memory passed in.
            if i % 2 == 1:
                all_out = self.layers[i](all_out, memory=state[i // 2])
            # Either self.linear_layer (initial obs -> attn. dim layer) or
            # MultiLayerPerceptrons. The output of these layers is always the
            # memory for the next forward pass.
            else:
                all_out = self.layers[i](all_out)
                memory_outs.append(all_out)

        # Discard the last output (not needed as a memory, since it's the
        # output of the final layer).
        memory_outs = memory_outs[:-1]

        logits = self.logits(all_out)
        self._value_out = self.values_out(all_out)

        # Return logits folded back to (B*T, num_outputs), plus one memory
        # tensor per transformer unit.
        return torch.reshape(logits, [-1, self.num_outputs]), [
            torch.reshape(m, [-1, self.attn_dim]) for m in memory_outs
        ]

    # TODO: (sven) Deprecate this once trajectory view API has fully matured.
    @override(RecurrentNetwork)
    def get_initial_state(self) -> List[np.ndarray]:
        return []

    @override(ModelV2)
    def value_function(self) -> TensorType:
        return torch.reshape(self._value_out, [-1])
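

# Minimal usage sketch (only runs when executing this file directly). It
# follows the configuration example from the class docstring above; the env
# choice and all hyperparameter values below are illustrative assumptions,
# not defaults of this model.
if __name__ == "__main__":
    import ray
    from ray.rllib.agents import ppo

    ray.init()
    config = ppo.DEFAULT_CONFIG.copy()
    config["framework"] = "torch"
    config["model"]["custom_model"] = GTrXLNet
    config["model"]["max_seq_len"] = 10
    # Each custom_model_config key is passed as a kwarg into
    # GTrXLNet.__init__ (see the constructor signature above).
    config["model"]["custom_model_config"] = {
        "num_transformer_units": 1,
        "attn_dim": 32,
        "num_heads": 2,
        "memory_inference": 50,
        "memory_training": 50,
        "head_dim": 16,
        "ff_hidden_dim": 32,
    }
    trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
    print(trainer.train())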