"""
|
|
[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar,
|
|
Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017.
|
|
https://arxiv.org/pdf/1706.03762.pdf
|
|
[2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto
|
|
et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf
|
|
[3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
|
|
Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019.
|
|
https://www.aclweb.org/anthology/P19-1285.pdf
|
|
"""
|
|
import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \
    SkipConnection
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf

tf = try_import_tf()


# TODO(sven): Use RLlib's FCNet instead.
class PositionwiseFeedforward(tf.keras.layers.Layer):
    """A 2x linear layer with ReLU activation in between described in [1].

    Each timestep coming from the attention head will be passed through this
    layer separately.
    """

    def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs):
        super().__init__(**kwargs)

        self._hidden_layer = tf.keras.layers.Dense(
            hidden_dim,
            activation=tf.nn.relu,
        )

        self._output_layer = tf.keras.layers.Dense(
            out_dim, activation=output_activation)

    def call(self, inputs, **kwargs):
        del kwargs
        output = self._hidden_layer(inputs)
        return self._output_layer(output)
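

# Illustrative sketch (not part of the original RLlib module): because both
# Dense layers in PositionwiseFeedforward act only on the last axis, the same
# two-layer MLP is applied to every timestep of a [batch, time, dim] tensor
# independently. The function name and shapes below are made up for
# illustration only.
def _example_positionwise_feedforward():
    ffn = PositionwiseFeedforward(out_dim=32, hidden_dim=128)
    x = tf.ones(shape=(4, 10, 32))  # [batch=4, time=10, attn_dim=32]
    y = ffn(x)  # Last axis mapped 32 -> 128 (ReLU) -> 32; time axis untouched.
    return y  # Shape: (4, 10, 32).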


class TrXLNet(RecurrentNetwork):
    """A TrXL net Model described in [1]."""

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name, num_transformer_units, attn_dim,
                 num_heads, head_dim, ff_hidden_dim):
        """Initializes a TrXLNet object.

        Args:
            num_transformer_units (int): The number of Transformer repeats to
                use (denoted L in [2]).
            attn_dim (int): The input and output dimensions of one Transformer
                unit.
            num_heads (int): The number of attention heads to use in parallel.
                Denoted as `H` in [3].
            head_dim (int): The dimension of a single(!) head.
                Denoted as `d` in [3].
            ff_hidden_dim (int): The dimension of the hidden layer within
                the position-wise MLP (after the multi-head attention block
                within one Transformer unit). This is the size of the first
                of the two layers within the PositionwiseFeedforward. The
                second layer always has size=`attn_dim`.
        """

        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        self.num_transformer_units = num_transformer_units
        self.attn_dim = attn_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = observation_space.shape[0]

        # Constant (non-trainable) sinusoid rel. pos. encoding matrix.
        pos_embedding = relative_position_embedding(self.max_seq_len, attn_dim)

        inputs = tf.keras.layers.Input(
            shape=(self.max_seq_len, self.obs_dim), name="inputs")
        E_out = tf.keras.layers.Dense(attn_dim)(inputs)

        for _ in range(self.num_transformer_units):
            MHA_out = SkipConnection(
                RelativeMultiHeadAttention(
                    out_dim=attn_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    rel_pos_encoder=pos_embedding,
                    input_layernorm=False,
                    output_activation=None),
                fan_in_layer=None)(E_out)
            E_out = SkipConnection(
                PositionwiseFeedforward(attn_dim, ff_hidden_dim))(MHA_out)
            E_out = tf.keras.layers.LayerNormalization(axis=-1)(E_out)

        # Postprocess TrXL output with a linear layer to produce the logits.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(E_out)

        self.base_model = tf.keras.models.Model([inputs], [logits])
        self.register_variables(self.base_model.variables)

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        # To make attention work with RLlib's current ModelV2 API:
        # We assume `state` holds the most recent observations (all
        # concatenated into one tensor), append the current inputs at the
        # end and only keep the last `max_seq_len` timesteps. This allows
        # us to deal with timestep-wise inference and full-sequence training
        # within the same logic.
        observations = state[0]
        observations = tf.concat(
            (observations, inputs), axis=1)[:, -self.max_seq_len:]
        logits = self.base_model([observations])
        T = tf.shape(inputs)[1]  # Length of input segment (time).
        logits = logits[:, -T:]

        return logits, [observations]

    @override(RecurrentNetwork)
    def get_initial_state(self):
        # State is the last `max_seq_len` observations, concat'd together
        # into one Tensor.
        return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)]
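

# Illustrative sketch (not part of the original RLlib module): the rolling
# observation window maintained by TrXLNet.forward_rnn above, shown with
# NumPy for concreteness. Names and shapes here are made up for illustration.
def _example_observation_window(max_seq_len=5):
    history = np.zeros((1, max_seq_len, 3), np.float32)  # initial state
    new_obs = np.ones((1, 2, 3), np.float32)  # T=2 new timesteps
    # Append the new observations and keep only the last `max_seq_len` steps,
    # exactly like the tf.concat(...)[:, -max_seq_len:] slice above.
    window = np.concatenate((history, new_obs), axis=1)[:, -max_seq_len:]
    return window  # Shape: (1, 5, 3); the last 2 rows are the new inputs.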


class GTrXLNet(RecurrentNetwork):
    """A GTrXL net Model described in [2].

    This is still in an experimental phase.
    Can be used as a drop-in replacement for LSTMs in PPO and IMPALA.
    For an example script, see: `ray/rllib/examples/attention_net.py`.

    To use this network as a replacement for an RNN, configure your Trainer
    as follows:

    Examples:
        >> config["model"]["custom_model"] = GTrXLNet
        >> config["model"]["max_seq_len"] = 10
        >> config["model"]["custom_options"] = {
        >>     "num_transformer_units": 1,
        >>     "attn_dim": 32,
        >>     "num_heads": 2,
        >>     "memory_tau": 50,
        >>     etc..
        >> }
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_transformer_units,
                 attn_dim,
                 num_heads,
                 memory_tau,
                 head_dim,
                 ff_hidden_dim,
                 init_gate_bias=2.0):
        """Initializes a GTrXLNet.

        Args:
            num_transformer_units (int): The number of Transformer repeats to
                use (denoted L in [2]).
            attn_dim (int): The input and output dimensions of one Transformer
                unit.
            num_heads (int): The number of attention heads to use in parallel.
                Denoted as `H` in [3].
            memory_tau (int): The number of timesteps to store in each
                transformer block's memory M (concat'd over time and fed into
                next transformer block as input).
            head_dim (int): The dimension of a single(!) head.
                Denoted as `d` in [3].
            ff_hidden_dim (int): The dimension of the hidden layer within
                the position-wise MLP (after the multi-head attention block
                within one Transformer unit). This is the size of the first
                of the two layers within the PositionwiseFeedforward. The
                second layer always has size=`attn_dim`.
            init_gate_bias (float): Initial bias values for the GRU gates (two
                GRUs per Transformer unit, one after the MHA, one after the
                position-wise MLP).
        """

        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        self.num_transformer_units = num_transformer_units
        self.attn_dim = attn_dim
        self.num_heads = num_heads
        self.memory_tau = memory_tau
        self.head_dim = head_dim
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = observation_space.shape[0]

        # Constant (non-trainable) sinusoid rel. pos. encoding matrix.
        Phi = relative_position_embedding(self.max_seq_len + self.memory_tau,
                                          self.attn_dim)

        # Raw observation input.
        input_layer = tf.keras.layers.Input(
            shape=(self.max_seq_len, self.obs_dim), name="inputs")
        memory_ins = [
            tf.keras.layers.Input(
                shape=(self.memory_tau, self.attn_dim),
                dtype=tf.float32,
                name="memory_in_{}".format(i))
            for i in range(self.num_transformer_units)
        ]

        # Map observation dim to input/output transformer (attention) dim.
        E_out = tf.keras.layers.Dense(self.attn_dim)(input_layer)
        # Outputs, collected and concat'd over time to build the internal,
        # tau-len memory units used for additional contextual information.
        memory_outs = [E_out]

        # Create L Transformer blocks according to [2].
        for i in range(self.num_transformer_units):
            # RelativeMultiHeadAttention part.
            MHA_out = SkipConnection(
                RelativeMultiHeadAttention(
                    out_dim=self.attn_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    rel_pos_encoder=Phi,
                    input_layernorm=True,
                    output_activation=tf.nn.relu),
                fan_in_layer=GRUGate(init_gate_bias),
                name="mha_{}".format(i + 1))(
                    E_out, memory=memory_ins[i])
            # Position-wise MLP part.
            E_out = SkipConnection(
                tf.keras.Sequential(
                    (tf.keras.layers.LayerNormalization(axis=-1),
                     PositionwiseFeedforward(
                         out_dim=self.attn_dim,
                         hidden_dim=ff_hidden_dim,
                         output_activation=tf.nn.relu))),
                fan_in_layer=GRUGate(init_gate_bias),
                name="pos_wise_mlp_{}".format(i + 1))(MHA_out)
            # Output of position-wise MLP == E(l-1), which is concat'd
            # to the current Mem block (M(l-1)) to yield E~(l-1), which is then
            # used by the next transformer block.
            memory_outs.append(E_out)

        # Postprocess TrXL output with linear layers to produce the logits
        # and the value outputs.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(E_out)

        self._value_out = None
        values_out = tf.keras.layers.Dense(
            1, activation=None, name="values")(E_out)

        self.trxl_model = tf.keras.Model(
            inputs=[input_layer] + memory_ins,
            outputs=[logits, values_out] + memory_outs[:-1])

        self.register_variables(self.trxl_model.variables)
        self.trxl_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        # To make attention work with RLlib's current ModelV2 API:
        # We assume state[0] holds the most recent observations (all
        # concatenated into one tensor), append the current inputs at the
        # end and only keep the last `max_seq_len` timesteps. The remaining
        # state entries hold each Transformer unit's memory (M(l)). This
        # allows us to deal with timestep-wise inference and full-sequence
        # training within the same logic.
        observations = state[0]
        memory = state[1:]

        observations = tf.concat(
            (observations, inputs), axis=1)[:, -self.max_seq_len:]
        all_out = self.trxl_model([observations] + memory)
        logits, self._value_out = all_out[0], all_out[1]
        memory_outs = all_out[2:]
        # If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
        if self.memory_tau > self.max_seq_len:
            memory_outs = [
                tf.concat(
                    [memory[i][:, -(self.memory_tau - self.max_seq_len):], m],
                    axis=1) for i, m in enumerate(memory_outs)
            ]
        else:
            memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]

        T = tf.shape(inputs)[1]  # Length of input segment (time).
        logits = logits[:, -T:]
        self._value_out = self._value_out[:, -T:]

        return logits, [observations] + memory_outs

    @override(RecurrentNetwork)
    def get_initial_state(self):
        # State is the `max_seq_len` last observations concat'd together into
        # one Tensor, plus all Transformer blocks' E(l) outputs concat'd
        # together (up to `memory_tau` timesteps each).
        return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \
               [np.zeros((self.memory_tau, self.attn_dim), np.float32)
                for _ in range(self.num_transformer_units)]

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
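

# Illustrative sketch (not part of the original RLlib module): the memory
# update performed in GTrXLNet.forward_rnn above, shown with NumPy. With
# memory_tau=5 and max_seq_len=3, the new memory keeps the last
# (memory_tau - max_seq_len)=2 rows of the old memory and appends the 3 fresh
# per-block outputs, so it always stays `memory_tau` timesteps long. Names
# and shapes below are made up for illustration.
def _example_memory_update(memory_tau=5, max_seq_len=3, attn_dim=4):
    old_memory = np.zeros((1, memory_tau, attn_dim), np.float32)
    new_block_out = np.ones((1, max_seq_len, attn_dim), np.float32)
    overlap = old_memory[:, -(memory_tau - max_seq_len):]  # Last 2 old rows.
    new_memory = np.concatenate([overlap, new_block_out], axis=1)
    return new_memory  # Shape: (1, memory_tau, attn_dim) == (1, 5, 4).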


def relative_position_embedding(seq_length, out_dim):
    """Creates a [seq_length x out_dim] matrix for rel. pos encoding.

    Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding
    matrix.

    Args:
        seq_length (int): The max. sequence length (time axis).
        out_dim (int): The number of nodes to go into the first Transformer
            layer with.

    Returns:
        tf.Tensor: The encoding matrix Phi.
    """
    inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
    pos_offsets = tf.range(seq_length - 1., -1., -1.)
    inputs = pos_offsets[:, None] * inverse_freq[None, :]
    return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)
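

# Illustrative sketch (not part of the original RLlib module): a NumPy
# re-implementation of relative_position_embedding to make the shape and
# ordering explicit. Positions run from (seq_length - 1) down to 0, and the
# sin and cos halves are concatenated along the last axis, yielding a
# [seq_length, out_dim] matrix (for even out_dim). The function name and
# default shapes are made up for illustration.
def _example_position_embedding_numpy(seq_length=4, out_dim=6):
    inverse_freq = 1.0 / (10000 ** (np.arange(0, out_dim, 2.0) / out_dim))
    pos_offsets = np.arange(seq_length - 1.0, -1.0, -1.0)  # [3., 2., 1., 0.]
    angles = pos_offsets[:, None] * inverse_freq[None, :]
    phi = np.concatenate((np.sin(angles), np.cos(angles)), axis=-1)
    return phi  # Shape: (4, 6).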