Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Added TransformerXL and "stabilized for RL" variant, GTrXL (#6470)
parent 2c599dbf05
commit 7f14fb577d
3 changed files with 541 additions and 0 deletions
rllib/examples/rl_attention.py (new file)
@@ -0,0 +1,173 @@
import argparse

import gym
import numpy as np

import ray
from ray import tune
from ray.tune import registry
from ray.rllib import models
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf import attention
from ray.rllib.models.tf import recurrent_tf_modelv2
from ray.rllib.examples.custom_keras_rnn_model import RepeatAfterMeEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)

class OneHot(gym.Wrapper):
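    """Wrapper that one-hot encodes a Discrete observation space.

    E.g. (illustrative): with Discrete(5) observations, obs=3 is encoded as
    [0., 0., 0., 1., 0.].
    """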

    def __init__(self, env):
        super(OneHot, self).__init__(env)
        self.observation_space = gym.spaces.Box(0., 1.,
                                                (env.observation_space.n,))

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        return self._encode_obs(obs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self._encode_obs(obs), reward, done, info

    def _encode_obs(self, obs):
        # Start from zeros so that only the observed index is hot.
        new_obs = np.zeros(self.env.observation_space.n)
        new_obs[obs] = 1.0
        return new_obs

class LookAndPush(gym.Env):
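    """Tiny memory environment.

    A hidden binary case is drawn at reset; "looking" (action 0) from state 2
    reveals it as the next observation, and the payoff for reaching state 4
    depends on it (see step() below).
    """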

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(5)
        self._state = None
        self._case = None

    def reset(self):
        self._state = 2
        self._case = np.random.choice(2)
        return self._state

    def step(self, action):
        assert self.action_space.contains(action)

        if self._state == 4:
            if action and self._case:
                return self._state, 10., True, {}
            else:
                return self._state, -10, True, {}
        else:
            if action:
                if self._state == 0:
                    self._state = 2
                else:
                    self._state += 1
            elif self._state == 2:
                self._state = self._case

        return self._state, -1, False, {}

class GRUTrXL(recurrent_tf_modelv2.RecurrentTFModelV2):
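    """RLlib recurrent model wrapping the GRU-gated Transformer-XL (GTrXL).

    The RNN state holds the most recent observations, which are fed to the
    attention stack as a fixed-length window of max_seq_len timesteps.
    """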

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(GRUTrXL, self).__init__(obs_space, action_space, num_outputs,
                                      model_config, name)
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = obs_space.shape[0]
        input_layer = tf.keras.layers.Input(
            shape=(self.max_seq_len, obs_space.shape[0]),
            name="inputs",
        )

        trxl_out = attention.make_GRU_TrXL(
            seq_length=model_config["max_seq_len"],
            num_layers=model_config["custom_options"]["num_layers"],
            attn_dim=model_config["custom_options"]["attn_dim"],
            num_heads=model_config["custom_options"]["num_heads"],
            head_dim=model_config["custom_options"]["head_dim"],
            ff_hidden_dim=model_config["custom_options"]["ff_hidden_dim"],
        )(input_layer)

        # Postprocess TrXL output with another hidden layer and compute values
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(trxl_out)
        values_out = tf.keras.layers.Dense(
            1, activation=None, name="values")(trxl_out)

        self.trxl_model = tf.keras.Model(
            inputs=[input_layer],
            outputs=[logits, values_out],
        )
        self.register_variables(self.trxl_model.variables)
        self.trxl_model.summary()

def forward_rnn(self, inputs, state, seq_lens):
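        # Shapes (illustrative): inputs is [B, T, obs_dim]; state[0] is the
        # sliding window of the last max_seq_len observations,
        # [B, max_seq_len, obs_dim].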
        state = state[0]

        # We assume state is the history of recent observations. We append
        # the current inputs to the end and only keep the most recent (up to
        # max_seq_len). This allows us to deal with timestep-wise inference
        # and full sequence training with the same logic.
        state = tf.concat((state, inputs), axis=1)[:, -self.max_seq_len:]
        logits, self._value_out = self.trxl_model(state)

        in_T = tf.shape(inputs)[1]
        logits = logits[:, -in_T:]
        self._value_out = self._value_out[:, -in_T:]

        return logits, [state]

    def get_initial_state(self):
        return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)
    models.ModelCatalog.register_custom_model("trxl", GRUTrXL)
    registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    registry.register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())
    registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush()))
    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config={
            "env": args.env,
            "env_config": {
                "repeat_delay": 2,
            },
            "gamma": 0.99,
            "num_workers": 0,
            "num_envs_per_worker": 20,
            "entropy_coeff": 0.001,
            "num_sgd_iter": 5,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "trxl",
                "max_seq_len": 10,
                "custom_options": {
                    "num_layers": 1,
                    "attn_dim": 10,
                    "num_heads": 1,
                    "head_dim": 10,
                    "ff_hidden_dim": 20,
                },
            },
        })
rllib/examples/supervised_attention.py (new file)
@@ -0,0 +1,81 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.rllib.models.tf import attention
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

def bit_shift_generator(seq_length, shift, batch_size):
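    """Yield (sequence, targets) batches for a bit-shift prediction task.

    The target at step t is the input bit from step t - shift, with the first
    `shift` targets zeroed. E.g. (illustrative) with shift=2, the input bits
    [1, 0, 1, 1, 0] give targets [0, 0, 1, 0, 1].
    """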
    while True:
        values = np.array([0., 1.], dtype=np.float32)
        seq = np.random.choice(values, (batch_size, seq_length, 1))
        targets = np.squeeze(np.roll(seq, shift, axis=1).astype(np.int32))
        targets[:, :shift] = 0
        yield seq, targets

def make_model(seq_length, num_tokens, num_layers, attn_dim, num_heads,
               head_dim, ff_hidden_dim):
    return tf.keras.Sequential((
        attention.make_TrXL(seq_length, num_layers, attn_dim, num_heads,
                            head_dim, ff_hidden_dim),
        tf.keras.layers.Dense(num_tokens),
    ))


def train_loss(targets, outputs):
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=outputs)
    return tf.reduce_mean(loss)

def train_bit_shift(seq_length, num_iterations, print_every_n):
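    """Train a plain (ungated) TrXL on the bit-shift task, printing the loss
    on a held-out batch every `print_every_n` iterations."""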
    optimizer = tf.keras.optimizers.Adam(1e-3)

    model = make_model(
        seq_length,
        num_tokens=2,
        num_layers=1,
        attn_dim=10,
        num_heads=5,
        head_dim=20,
        ff_hidden_dim=20,
    )

    shift = 10
    train_batch = 10
    test_batch = 100
    data_gen = bit_shift_generator(
        seq_length, shift=shift, batch_size=train_batch)
    test_gen = bit_shift_generator(
        seq_length, shift=shift, batch_size=test_batch)

    @tf.function
    def update_step(inputs, targets):
        loss_fn = lambda: train_loss(targets, model(inputs))
        var_fn = lambda: model.trainable_variables
        optimizer.minimize(loss_fn, var_fn)

    for i, (inputs, targets) in zip(range(num_iterations), data_gen):
        update_step(
            tf.convert_to_tensor(inputs), tf.convert_to_tensor(targets))

        if i % print_every_n == 0:
            test_inputs, test_targets = next(test_gen)
            print(i, train_loss(test_targets, model(test_inputs)))


if __name__ == "__main__":
    tf.enable_eager_execution()
    train_bit_shift(
        seq_length=20,
        num_iterations=2000,
        print_every_n=200,
    )
rllib/models/tf/attention.py (new file)
@@ -0,0 +1,287 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

def relative_position_embedding(seq_length, out_dim):
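    """Sinusoidal embedding of the relative offsets seq_length - 1, ..., 1, 0.

    Returns a [seq_length, out_dim] tensor whose first out_dim/2 columns hold
    sin(offset * inv_freq) and last out_dim/2 columns cos(offset * inv_freq),
    with inv_freq[i] = 1 / 10000**(2i / out_dim).
    """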
    inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
    pos_offsets = tf.range(seq_length - 1., -1., -1.)
    inputs = pos_offsets[:, None] * inverse_freq[None, :]
    return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)

def rel_shift(x):
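    """Relative-shift trick from Transformer-XL.

    Given positional scores x of shape [B, I, J, H] whose j axis is indexed by
    the embedding columns (offsets J - 1, ..., 0), realign them so that entry
    (i, j) holds the score for relative offset i - j. Worked example
    (illustrative, B=H=1, 3x3 input [[a, b, c], [d, e, f], [g, h, i]]): the
    result is [[c, *, *], [e, f, *], [g, h, i]], where the * entries are junk
    that the causal mask later removes.
    """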
    # Transposed version of the shift approach implemented by Dai et al. 2019
    # https://github.com/kimiyoung/transformer-xl/blob/44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31
    x_size = tf.shape(x)

    x = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])
    x = tf.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]])
    x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1])
    x = tf.reshape(x, x_size)

    return x

class MultiHeadAttention(tf.keras.layers.Layer):
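    """Plain multi-head dot-product self-attention with a causal mask
    (no positional terms); position t only attends to positions <= t."""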

    def __init__(self, out_dim, num_heads, head_dim, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)

        # no bias or non-linearity
        self._num_heads = num_heads
        self._head_dim = head_dim
        self._qkv_layer = tf.keras.layers.Dense(
            3 * num_heads * head_dim, use_bias=False)
        self._linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(out_dim, use_bias=False))

    def call(self, inputs):
        L = tf.shape(inputs)[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        qkv = self._qkv_layer(inputs)

        queries, keys, values = tf.split(qkv, 3, -1)
        queries = queries[:, -L:]  # only query based on the segment

        queries = tf.reshape(queries, [-1, L, H, D])
        keys = tf.reshape(keys, [-1, L, H, D])
        values = tf.reshape(values, [-1, L, H, D])

        score = tf.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / D**0.5

        # causal mask of the same length as the sequence
        mask = tf.sequence_mask(tf.range(1, L + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = tf.nn.softmax(masked_score, axis=2)

        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0))
        return self._linear_layer(out)

class RelativeMultiHeadAttention(tf.keras.layers.Layer):
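    """Multi-head attention with Transformer-XL style relative position
    encoding (Dai et al., 2019): content and positional scores are combined
    via the u/v bias vectors and rel_shift(), with an optional input layer
    norm and optional memory of past timesteps."""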

    def __init__(self,
                 out_dim,
                 num_heads,
                 head_dim,
                 rel_pos_encoder,
                 input_layernorm=False,
                 output_activation=None,
                 **kwargs):
        super(RelativeMultiHeadAttention, self).__init__(**kwargs)

        # no bias or non-linearity
        self._num_heads = num_heads
        self._head_dim = head_dim
        self._qkv_layer = tf.keras.layers.Dense(
            3 * num_heads * head_dim, use_bias=False)
        self._linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(
                out_dim, use_bias=False, activation=output_activation))

        self._uvar = self.add_weight(shape=(num_heads, head_dim))
        self._vvar = self.add_weight(shape=(num_heads, head_dim))

        self._pos_proj = tf.keras.layers.Dense(
            num_heads * head_dim, use_bias=False)
        self._rel_pos_encoder = rel_pos_encoder

        self._input_layernorm = None
        if input_layernorm:
            self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1)

    def call(self, inputs, memory=None):
        L = tf.shape(inputs)[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        # length of the memory segment
        M = memory.shape[1] if memory is not None else 0

        if memory is not None:
            inputs = tf.concat(
                (tf.stop_gradient(memory), inputs), axis=1)

        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        qkv = self._qkv_layer(inputs)

        queries, keys, values = tf.split(qkv, 3, -1)
        queries = queries[:, -L:]  # only query based on the segment

        queries = tf.reshape(queries, [-1, L, H, D])
        keys = tf.reshape(keys, [-1, L + M, H, D])
        values = tf.reshape(values, [-1, L + M, H, D])

        rel = self._pos_proj(self._rel_pos_encoder)
        rel = tf.reshape(rel, [L, H, D])

        score = tf.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
        pos_score = tf.einsum("bihd,jhd->bijh", queries + self._vvar, rel)
        score = score + rel_shift(pos_score)
        score = score / D**0.5

        # causal mask of the same length as the sequence
        mask = tf.sequence_mask(tf.range(M + 1, L + M + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = tf.nn.softmax(masked_score, axis=2)

        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0))
        return self._linear_layer(out)

class PositionwiseFeedforward(tf.keras.layers.Layer):
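    """Two-layer position-wise feed-forward block (ReLU hidden layer),
    applied independently at every timestep."""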

    def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs):
        super(PositionwiseFeedforward, self).__init__(**kwargs)

        self._hidden_layer = tf.keras.layers.Dense(
            hidden_dim,
            activation=tf.nn.relu,
        )
        self._output_layer = tf.keras.layers.Dense(
            out_dim, activation=output_activation)

    def call(self, inputs, **kwargs):
        del kwargs
        output = self._hidden_layer(inputs)
        return self._output_layer(output)

class SkipConnection(tf.keras.layers.Layer):
    """Skip connection layer.

    If no fan-in layer is specified, then this layer behaves as a regular
    residual layer (computing layer(inputs) + inputs). Otherwise the fan-in
    layer combines the pair (inputs, layer(inputs)).
    """

    def __init__(self, layer, fan_in_layer=None, **kwargs):
        super(SkipConnection, self).__init__(**kwargs)
        self._fan_in_layer = fan_in_layer
        self._layer = layer

    def call(self, inputs, **kwargs):
        del kwargs
        outputs = self._layer(inputs)
        if self._fan_in_layer is None:
            outputs = outputs + inputs
        else:
            outputs = self._fan_in_layer((inputs, outputs))

        return outputs

class GRUGate(tf.keras.layers.Layer):
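    """GRU-style gating layer used by GTrXL (Parisotto et al., 2019) as the
    fan-in of a SkipConnection: the block input x and block output y are
    blended as (1 - z) * x + z * h (see call() below)."""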

    def __init__(self, init_bias=0., **kwargs):
        super(GRUGate, self).__init__(**kwargs)
        self._init_bias = init_bias

    def build(self, input_shape):
        x_shape, y_shape = input_shape
        if x_shape[-1] != y_shape[-1]:
            raise ValueError(
                "Both inputs to GRUGate must have the same size along "
                "the last axis.")

        self._w_r = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
        self._w_z = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
        self._w_h = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
        self._u_r = self.add_weight(shape=(x_shape[-1], x_shape[-1]))
        self._u_z = self.add_weight(shape=(x_shape[-1], x_shape[-1]))
        self._u_h = self.add_weight(shape=(x_shape[-1], x_shape[-1]))

        def bias_initializer(shape, dtype):
            return tf.fill(shape, tf.cast(self._init_bias, dtype=dtype))

        self._bias_z = self.add_weight(
            shape=(x_shape[-1], ), initializer=bias_initializer)

    def call(self, inputs, **kwargs):
        x, y = inputs
        r = (tf.tensordot(y, self._w_r, axes=1) + tf.tensordot(
            x, self._u_r, axes=1))
        r = tf.nn.sigmoid(r)

        z = (tf.tensordot(y, self._w_z, axes=1) + tf.tensordot(
            x, self._u_z, axes=1) + self._bias_z)
        z = tf.nn.sigmoid(z)

        h = (tf.tensordot(y, self._w_h, axes=1) + tf.tensordot(
            (x * r), self._u_h, axes=1))
        h = tf.nn.tanh(h)

        return (1 - z) * x + z * h

def make_TrXL(seq_length, num_layers, attn_dim, num_heads, head_dim,
              ff_hidden_dim):
    pos_embedding = relative_position_embedding(seq_length, attn_dim)

    layers = [tf.keras.layers.Dense(attn_dim)]
    for _ in range(num_layers):
        layers.append(
            SkipConnection(
                RelativeMultiHeadAttention(attn_dim, num_heads, head_dim,
                                           pos_embedding)))
        layers.append(tf.keras.layers.LayerNormalization(axis=-1))

        layers.append(
            SkipConnection(PositionwiseFeedforward(attn_dim, ff_hidden_dim)))
        layers.append(tf.keras.layers.LayerNormalization(axis=-1))

    return tf.keras.Sequential(layers)

def make_GRU_TrXL(seq_length,
                  num_layers,
                  attn_dim,
                  num_heads,
                  head_dim,
                  ff_hidden_dim,
                  init_gate_bias=2.):
    # Default initial bias for the gate taken from
    # Parisotto, Emilio, et al. "Stabilizing Transformers for Reinforcement
    # Learning." arXiv preprint arXiv:1910.06764 (2019).
    pos_embedding = relative_position_embedding(seq_length, attn_dim)

    layers = [tf.keras.layers.Dense(attn_dim)]
    for _ in range(num_layers):
        layers.append(
            SkipConnection(
                RelativeMultiHeadAttention(
                    attn_dim,
                    num_heads,
                    head_dim,
                    pos_embedding,
                    input_layernorm=True,
                    output_activation=tf.nn.relu),
                fan_in_layer=GRUGate(init_gate_bias),
            ))

        layers.append(
            SkipConnection(
                tf.keras.Sequential(
                    (tf.keras.layers.LayerNormalization(axis=-1),
                     PositionwiseFeedforward(
                         attn_dim, ff_hidden_dim,
                         output_activation=tf.nn.relu))),
                fan_in_layer=GRUGate(init_gate_bias),
            ))

    return tf.keras.Sequential(layers)
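

# Minimal usage sketch (illustrative only; the shapes and argument values are
# assumptions, not part of the library API). The factory returns a Keras
# model mapping [batch, seq_length, any_dim] -> [batch, seq_length, attn_dim]:
#
#   model = make_GRU_TrXL(seq_length=10, num_layers=2, attn_dim=32,
#                         num_heads=4, head_dim=8, ff_hidden_dim=64)
#   out = model(tf.zeros([1, 10, 6]))  # -> shape [1, 10, 32]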