[RLlib] Added TransformerXL and "stabilized for RL" variant, GTrXL (#6470)

This commit is contained in:
gehring 2020-05-08 08:10:23 -04:00 committed by GitHub
parent 2c599dbf05
commit 7f14fb577d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 541 additions and 0 deletions

View file

@ -0,0 +1,173 @@
import argparse
import gym
import numpy as np
import ray
from ray import tune
from ray.tune import registry
from ray.rllib import models
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf import attention
from ray.rllib.models.tf import recurrent_tf_modelv2
from ray.rllib.examples.custom_keras_rnn_model import RepeatAfterMeEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv
tf = try_import_tf()
# Command-line interface of the example script.
parser = argparse.ArgumentParser()
# Trainable name handed to `tune.run` as its first argument.
parser.add_argument("--run", default="PPO", type=str)
# Name of the registered environment placed under "env" in the config.
parser.add_argument("--env", default="RepeatAfterMeEnv", type=str)
# `episode_reward_mean` value used as the stopping criterion.
parser.add_argument("--stop", default=90, type=int)
# CPU count given to `ray.init`; 0 is converted to None at the call site.
parser.add_argument("--num-cpus", default=0, type=int)
class OneHot(gym.Wrapper):
    """Wrapper that turns discrete observations into one-hot vectors.

    The wrapped environment's observation space must expose an ``n``
    attribute; each observation is used as an index into a vector of
    length ``n``.
    """

    def __init__(self, env):
        super(OneHot, self).__init__(env)
        self.observation_space = gym.spaces.Box(0., 1.,
                                                (env.observation_space.n, ))

    def reset(self, **kwargs):
        """Reset the wrapped env and return the encoded first observation."""
        obs = self.env.reset(**kwargs)
        return self._encode_obs(obs)

    def step(self, action):
        """Step the wrapped env; only the observation is re-encoded."""
        obs, reward, done, info = self.env.step(action)
        return self._encode_obs(obs), reward, done, info

    def _encode_obs(self, obs):
        """Return a vector of zeros with a single 1.0 at index ``obs``."""
        # The vector must start from zeros: starting from ones made the
        # assignment below a no-op, so every observation was encoded as the
        # same all-ones vector and carried no information.
        new_obs = np.zeros(self.env.observation_space.n)
        new_obs[obs] = 1.0
        return new_obs
class LookAndPush(gym.Env):
    """Small memory task with two actions and five discrete states.

    ``reset`` puts the agent in state 2 and draws a hidden case (0 or 1).
    While the state is not 4, every step costs -1:

    * a truthy action moves state 0 to state 2 and any other state up by
      one;
    * a falsy action taken in state 2 moves to the state whose index equals
      the hidden case; elsewhere it leaves the state unchanged.

    A step taken in state 4 ends the episode with reward 10 when both the
    action and the hidden case are truthy, and -10 otherwise.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(5)
        # Both are set by `reset`.
        self._state = None
        self._case = None

    def reset(self):
        """Start a new episode and return the initial state (always 2)."""
        self._state = 2
        self._case = np.random.choice(2)
        return self._state

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``."""
        assert self.action_space.contains(action)

        # Terminal decision state.
        if self._state == 4:
            reward = 10. if (action and self._case) else -10
            return self._state, reward, True, {}

        if action:
            self._state = 2 if self._state == 0 else self._state + 1
        elif self._state == 2:
            self._state = self._case

        return self._state, -1, False, {}
class GRUTrXL(recurrent_tf_modelv2.RecurrentTFModelV2):
    """Recurrent model built around a GRU-gated TransformerXL network.

    The recurrent state is a window holding the most recent
    ``max_seq_len`` observations; the network is applied to that whole
    window on every call to ``forward_rnn``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(GRUTrXL, self).__init__(obs_space, action_space, num_outputs,
                                      model_config, name)
        options = model_config["custom_options"]
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = obs_space.shape[0]

        window = tf.keras.layers.Input(
            shape=(self.max_seq_len, self.obs_dim),
            name="inputs",
        )
        transformer = attention.make_GRU_TrXL(
            seq_length=self.max_seq_len,
            num_layers=options["num_layers"],
            attn_dim=options["attn_dim"],
            num_heads=options["num_heads"],
            head_dim=options["head_dim"],
            ff_hidden_dim=options["ff_hidden_dim"],
        )
        features = transformer(window)

        # Two linear heads on the transformer output: action logits and a
        # scalar value estimate.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(features)
        values_out = tf.keras.layers.Dense(
            1, activation=None, name="values")(features)

        self.trxl_model = tf.keras.Model(
            inputs=[window],
            outputs=[logits, values_out],
        )
        self.register_variables(self.trxl_model.variables)
        self.trxl_model.summary()

    def forward_rnn(self, inputs, state, seq_lens):
        """Run the network on the stored history extended by ``inputs``.

        Returns the logits for the time steps contained in ``inputs`` and
        the updated history window as the new state.
        """
        # The state is the history of recent observations. New inputs are
        # appended at the end and only the latest `max_seq_len` steps are
        # kept, so single-step inference and full-sequence training share
        # the same code path.
        history = tf.concat((state[0], inputs), axis=1)
        history = history[:, -self.max_seq_len:]

        logits, values = self.trxl_model(history)

        # Keep only the outputs belonging to the steps that were just fed.
        num_new_steps = tf.shape(inputs)[1]
        self._value_out = values[:, -num_new_steps:]
        return logits[:, -num_new_steps:], [history]

    def get_initial_state(self):
        """Return an all-zero history of shape (max_seq_len, obs_dim)."""
        return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)]

    def value_function(self):
        """Return the values of the last ``forward_rnn`` call, flattened."""
        return tf.reshape(self._value_out, [-1])
if __name__ == "__main__":
    args = parser.parse_args()

    # A `--num-cpus` of 0 is passed on as None.
    ray.init(num_cpus=args.num_cpus or None)

    # Make the custom model and the environments available by name.
    models.ModelCatalog.register_custom_model("trxl", GRUTrXL)
    registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    registry.register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())
    registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush()))

    model_config = {
        "custom_model": "trxl",
        "max_seq_len": 10,
        "custom_options": {
            "num_layers": 1,
            "attn_dim": 10,
            "num_heads": 1,
            "head_dim": 10,
            "ff_hidden_dim": 20,
        },
    }
    trial_config = {
        "env": args.env,
        "env_config": {
            "repeat_delay": 2,
        },
        "gamma": 0.99,
        "num_workers": 0,
        "num_envs_per_worker": 20,
        "entropy_coeff": 0.001,
        "num_sgd_iter": 5,
        "vf_loss_coeff": 1e-5,
        "model": model_config,
    }

    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config=trial_config)

View file

@ -0,0 +1,81 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from rllib.models.tf import attention
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def bit_shift_generator(seq_length, shift, batch_size):
    """Endlessly yield random bit sequences together with delayed targets.

    Args:
        seq_length: Number of time steps in every sequence.
        shift: Number of steps by which the targets lag behind the inputs.
        batch_size: Number of sequences in every yielded batch.

    Yields:
        A ``(seq, targets)`` pair. ``seq`` is a float32 array of shape
        ``(batch_size, seq_length, 1)`` containing 0s and 1s. ``targets``
        is an int32 array of shape ``(batch_size, seq_length)`` equal to
        ``seq`` delayed by ``shift`` steps, with the first ``shift`` steps
        set to 0.
    """
    # Loop-invariant, so build it once.
    values = np.array([0., 1.], dtype=np.float32)
    while True:
        seq = np.random.choice(values, (batch_size, seq_length, 1))
        rolled = np.roll(seq, shift, axis=1).astype(np.int32)
        # Drop only the trailing feature axis. A bare np.squeeze would also
        # remove the batch or time axis when either has length one, which
        # breaks the two-dimensional indexing below.
        targets = np.squeeze(rolled, axis=-1)
        # np.roll wraps around; the first `shift` steps have no real
        # predecessor, so they are zeroed.
        targets[:, :shift] = 0
        yield seq, targets
def make_model(seq_length, num_tokens, num_layers, attn_dim, num_heads,
               head_dim, ff_hidden_dim):
    """Build a TransformerXL followed by a dense layer of ``num_tokens`` units.

    All arguments except ``num_tokens`` are forwarded to
    ``attention.make_TrXL``.
    """
    encoder = attention.make_TrXL(
        seq_length=seq_length,
        num_layers=num_layers,
        attn_dim=attn_dim,
        num_heads=num_heads,
        head_dim=head_dim,
        ff_hidden_dim=ff_hidden_dim)
    token_logits = tf.keras.layers.Dense(num_tokens)
    return tf.keras.Sequential((encoder, token_logits))
def train_loss(targets, outputs):
    """Return the mean sparse softmax cross-entropy.

    ``targets`` are passed as integer class labels and ``outputs`` as the
    corresponding logits.
    """
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=outputs, labels=targets))
def train_bit_shift(seq_length, num_iterations, print_every_n):
    """Train a small TransformerXL on the bit-shift task.

    Args:
        seq_length: Length of the generated bit sequences.
        num_iterations: Number of optimizer updates to perform.
        print_every_n: The loss on a fresh test batch is printed whenever
            the iteration index is a multiple of this value.
    """
    # Fixed task settings.
    shift, train_batch, test_batch = 10, 10, 100

    optimizer = tf.keras.optimizers.Adam(1e-3)
    model = make_model(
        seq_length,
        num_tokens=2,
        num_layers=1,
        attn_dim=10,
        num_heads=5,
        head_dim=20,
        ff_hidden_dim=20,
    )

    data_gen = bit_shift_generator(
        seq_length, shift=shift, batch_size=train_batch)
    test_gen = bit_shift_generator(
        seq_length, shift=shift, batch_size=test_batch)

    @tf.function
    def update_step(inputs, targets):
        # The optimizer takes callables for both the loss and the variables.
        def loss_fn():
            return train_loss(targets, model(inputs))

        def var_fn():
            return model.trainable_variables

        optimizer.minimize(loss_fn, var_fn)

    for step, batch in zip(range(num_iterations), data_gen):
        batch_inputs, batch_targets = batch
        update_step(
            tf.convert_to_tensor(batch_inputs),
            tf.convert_to_tensor(batch_targets))

        if step % print_every_n == 0:
            test_inputs, test_targets = next(test_gen)
            print(step, train_loss(test_targets, model(test_inputs)))
if __name__ == "__main__":
    # Switch TensorFlow to eager mode before the model is built and trained.
    tf.enable_eager_execution()
    train_bit_shift(seq_length=20, num_iterations=2000, print_every_n=200)

View file

@ -0,0 +1,287 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def relative_position_embedding(seq_length, out_dim):
    """Return sinusoidal embeddings for the offsets seq_length-1 down to 0.

    Each row belongs to one offset, ordered from the largest offset to 0.
    A row is the concatenation of the sines and then the cosines of the
    offset multiplied by a set of inverse frequencies, one per even index
    in ``range(0, out_dim, 2)``.
    """
    exponents = tf.range(0, out_dim, 2.0) / out_dim
    inverse_freq = 1 / (10000**exponents)
    pos_offsets = tf.range(seq_length - 1., -1., -1.)
    # Outer product: one row per offset, one column per frequency.
    angles = pos_offsets[:, None] * inverse_freq[None, :]
    return tf.concat((tf.sin(angles), tf.cos(angles)), axis=-1)
def rel_shift(x):
    """Shift the entries of a rank-4 tensor with a pad/reshape/slice trick.

    One zero entry is padded in front of axis 2, the tensor is reshaped
    with the sizes of axes 1 and 2 exchanged, the first slice along axis 1
    is dropped, and the result is reshaped back to the shape of ``x``.
    """
    # Transposed version of the shift approach implemented by Dai et al. 2019
    # https://github.com/kimiyoung/transformer-xl/blob/44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31
    orig_shape = tf.shape(x)

    padded = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])
    exchanged = tf.reshape(
        padded,
        [orig_shape[0], orig_shape[2] + 1, orig_shape[1], orig_shape[3]])
    trimmed = tf.slice(exchanged, [0, 1, 0, 0], [-1, -1, -1, -1])

    return tf.reshape(trimmed, orig_shape)
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head self-attention layer with a causal mask.

    Queries, keys and values come from a single bias-free dense projection
    of the input; the concatenated head outputs go through a bias-free
    dense layer with ``out_dim`` units.
    """

    def __init__(self, out_dim, num_heads, head_dim, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)

        self._num_heads = num_heads
        self._head_dim = head_dim
        # no bias or non-linearity
        self._qkv_layer = tf.keras.layers.Dense(
            3 * num_heads * head_dim, use_bias=False)
        self._linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(out_dim, use_bias=False))

    def call(self, inputs):
        seg_len = tf.shape(inputs)[1]  # length of segment
        heads = self._num_heads  # number of attention heads
        depth = self._head_dim  # attention head dimension
        per_head_shape = [-1, seg_len, heads, depth]

        queries, keys, values = tf.split(self._qkv_layer(inputs), 3, -1)
        # only query based on the segment
        queries = tf.reshape(queries[:, -seg_len:], per_head_shape)
        keys = tf.reshape(keys, per_head_shape)
        values = tf.reshape(values, per_head_shape)

        score = tf.einsum("bihd,bjhd->bijh", queries, keys) / depth**0.5

        # causal mask of the same length as the sequence: row i keeps the
        # first i + 1 key positions.
        mask = tf.sequence_mask(
            tf.range(1, seg_len + 1), dtype=score.dtype)[None, :, :, None]

        # Masked-out positions receive a very large negative score before
        # the softmax over the key axis.
        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = tf.nn.softmax(masked_score, axis=2)

        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        merged_shape = tf.concat(
            (tf.shape(out)[:2], [heads * depth]), axis=0)
        return self._linear_layer(tf.reshape(out, merged_shape))
class RelativeMultiHeadAttention(tf.keras.layers.Layer):
    """Causal multi-head attention with relative positional scores.

    Args:
        out_dim: Number of units of the final dense projection.
        num_heads: Number of attention heads.
        head_dim: Size of every attention head.
        rel_pos_encoder: Positional embedding tensor that is projected and
            used for the position-dependent part of the score. It needs one
            row per key position (segment length plus memory length).
        input_layernorm: If True, layer normalization is applied to the
            (memory-extended) input before the projections.
        output_activation: Optional activation of the final projection.
    """

    def __init__(self,
                 out_dim,
                 num_heads,
                 head_dim,
                 rel_pos_encoder,
                 input_layernorm=False,
                 output_activation=None,
                 **kwargs):
        super(RelativeMultiHeadAttention, self).__init__(**kwargs)

        # no bias or non-linearity
        self._num_heads = num_heads
        self._head_dim = head_dim
        self._qkv_layer = tf.keras.layers.Dense(
            3 * num_heads * head_dim, use_bias=False)
        self._linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(
                out_dim, use_bias=False, activation=output_activation))

        # Learned per-head offsets added to the queries for the content
        # score (u) and the positional score (v).
        self._uvar = self.add_weight(shape=(num_heads, head_dim))
        self._vvar = self.add_weight(shape=(num_heads, head_dim))

        self._pos_proj = tf.keras.layers.Dense(
            num_heads * head_dim, use_bias=False)
        self._rel_pos_encoder = rel_pos_encoder

        self._input_layernorm = None
        if input_layernorm:
            self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1)

    def call(self, inputs, memory=None):
        """Attend from the current segment to memory plus segment.

        Args:
            inputs: Current segment; axis 1 is treated as time.
            memory: Optional tensor prepended to ``inputs`` along axis 1
                with gradients stopped. Presumably shaped like ``inputs``
                except for the time axis - TODO confirm with callers.

        Returns:
            The output of the final projection for the positions of the
            current segment.
        """
        L = tf.shape(inputs)[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        # length of the memory segment
        M = 0
        if memory is not None:
            # The memory is joined along the time axis (1), so its length
            # has to be read from that same axis, and the join must be a
            # TensorFlow op to stay inside the graph (np.concatenate cannot
            # handle symbolic tensors).
            M = tf.shape(memory)[1]
            inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1)

        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        qkv = self._qkv_layer(inputs)

        queries, keys, values = tf.split(qkv, 3, -1)
        queries = queries[:, -L:]  # only query based on the segment

        queries = tf.reshape(queries, [-1, L, H, D])
        keys = tf.reshape(keys, [-1, L + M, H, D])
        values = tf.reshape(values, [-1, L + M, H, D])

        # One positional row per key position, so that the positional score
        # has the same key axis as the content score.
        rel = self._pos_proj(self._rel_pos_encoder)
        rel = tf.reshape(rel, [L + M, H, D])

        score = tf.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
        pos_score = tf.einsum("bihd,jhd->bijh", queries + self._vvar, rel)
        score = score + rel_shift(pos_score)
        score = score / D**0.5

        # causal mask: query i may see the memory and the first i + 1
        # positions of the segment
        mask = tf.sequence_mask(
            tf.range(M + 1, L + M + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = tf.nn.softmax(masked_score, axis=2)

        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0))
        return self._linear_layer(out)
class PositionwiseFeedforward(tf.keras.layers.Layer):
    """Two dense layers applied in sequence.

    The first layer has ``hidden_dim`` units and a ReLU activation; the
    second has ``out_dim`` units and the optional ``output_activation``.
    """

    def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs):
        super(PositionwiseFeedforward, self).__init__(**kwargs)

        self._hidden_layer = tf.keras.layers.Dense(
            hidden_dim, activation=tf.nn.relu)
        self._output_layer = tf.keras.layers.Dense(
            out_dim, activation=output_activation)

    def call(self, inputs, **kwargs):
        # Extra keyword arguments are accepted but not used.
        hidden = self._hidden_layer(inputs)
        return self._output_layer(hidden)
class SkipConnection(tf.keras.layers.Layer):
    """Skip connection layer.

    Wraps ``layer`` and combines its output with its input. Without a
    fan-in layer the two are added, as in a regular residual layer;
    otherwise the fan-in layer is called with the pair
    ``(inputs, outputs)``.
    """

    def __init__(self, layer, fan_in_layer=None, **kwargs):
        super(SkipConnection, self).__init__(**kwargs)
        self._layer = layer
        self._fan_in_layer = fan_in_layer

    def call(self, inputs, **kwargs):
        # Extra keyword arguments are accepted but not used.
        branch = self._layer(inputs)
        if self._fan_in_layer is not None:
            return self._fan_in_layer((inputs, branch))
        return branch + inputs
class GRUGate(tf.keras.layers.Layer):
    """GRU-style gate merging a skip stream ``x`` with a layer output ``y``.

    The layer is called with the pair ``(x, y)``; both must have the same
    size along the last axis.

    Args:
        init_bias: Initial value of the update-gate bias. The bias is
            subtracted inside the update gate, so a positive value keeps
            the gate mostly closed at initialization and the output close
            to ``x``.
    """

    def __init__(self, init_bias=0., **kwargs):
        super(GRUGate, self).__init__(**kwargs)
        self._init_bias = init_bias

    def build(self, input_shape):
        x_shape, y_shape = input_shape
        if x_shape[-1] != y_shape[-1]:
            raise ValueError(
                "Both inputs to GRUGate must equal size last axis.")

        # Square weight matrices for the reset (r), update (z) and
        # candidate (h) terms; w_* act on y, u_* act on x.
        self._w_r = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
        self._w_z = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
        self._w_h = self.add_weight(shape=(y_shape[-1], y_shape[-1]))
        self._u_r = self.add_weight(shape=(x_shape[-1], x_shape[-1]))
        self._u_z = self.add_weight(shape=(x_shape[-1], x_shape[-1]))
        self._u_h = self.add_weight(shape=(x_shape[-1], x_shape[-1]))

        def bias_initializer(shape, dtype):
            return tf.fill(shape, tf.cast(self._init_bias, dtype=dtype))

        self._bias_z = self.add_weight(
            shape=(x_shape[-1], ), initializer=bias_initializer)

    def call(self, inputs, **kwargs):
        x, y = inputs

        # Reset gate.
        r = (tf.tensordot(y, self._w_r, axes=1) + tf.tensordot(
            x, self._u_r, axes=1))
        r = tf.nn.sigmoid(r)

        # Update gate. The bias is subtracted, following the gating of
        # Parisotto et al. (2019): with a positive initial bias z starts
        # small, so the layer initially passes x through almost unchanged.
        # Adding the bias would instead open the gate at initialization.
        z = (tf.tensordot(y, self._w_z, axes=1) + tf.tensordot(
            x, self._u_z, axes=1) - self._bias_z)
        z = tf.nn.sigmoid(z)

        # Candidate output computed from y and the reset-scaled x.
        h = (tf.tensordot(y, self._w_h, axes=1) + tf.tensordot(
            (x * r), self._u_h, axes=1))
        h = tf.nn.tanh(h)

        return (1 - z) * x + z * h
def make_TrXL(seq_length, num_layers, attn_dim, num_heads, head_dim,
              ff_hidden_dim):
    """Build a TransformerXL as a ``tf.keras.Sequential`` model.

    The model starts with a dense layer of ``attn_dim`` units, followed by
    ``num_layers`` repetitions of: residual relative multi-head attention,
    layer normalization, residual position-wise feedforward, layer
    normalization. All attention layers share one relative positional
    embedding built for ``seq_length`` positions.
    """
    pos_embedding = relative_position_embedding(seq_length, attn_dim)

    stack = [tf.keras.layers.Dense(attn_dim)]
    for _ in range(num_layers):
        stack += [
            SkipConnection(
                RelativeMultiHeadAttention(attn_dim, num_heads, head_dim,
                                           pos_embedding)),
            tf.keras.layers.LayerNormalization(axis=-1),
            SkipConnection(PositionwiseFeedforward(attn_dim, ff_hidden_dim)),
            tf.keras.layers.LayerNormalization(axis=-1),
        ]

    return tf.keras.Sequential(stack)
def make_GRU_TrXL(seq_length,
                  num_layers,
                  attn_dim,
                  num_heads,
                  head_dim,
                  ff_hidden_dim,
                  init_gate_bias=2.):
    """Build a GRU-gated TransformerXL as a ``tf.keras.Sequential`` model.

    The model starts with a dense layer of ``attn_dim`` units, followed by
    ``num_layers`` repetitions of two gated blocks: relative multi-head
    attention with input layer normalization and ReLU output, and layer
    normalization followed by a position-wise feedforward with ReLU output.
    Each block is merged with its input by its own ``GRUGate`` created with
    ``init_gate_bias``.
    """
    # Default initial bias for the gate taken from
    # Parisotto, Emilio, et al. "Stabilizing Transformers for Reinforcement
    # Learning." arXiv preprint arXiv:1910.06764 (2019).
    pos_embedding = relative_position_embedding(seq_length, attn_dim)

    stack = [tf.keras.layers.Dense(attn_dim)]
    for _ in range(num_layers):
        attention_layer = RelativeMultiHeadAttention(
            attn_dim,
            num_heads,
            head_dim,
            pos_embedding,
            input_layernorm=True,
            output_activation=tf.nn.relu)
        stack.append(
            SkipConnection(
                attention_layer, fan_in_layer=GRUGate(init_gate_bias)))

        normed_feedforward = tf.keras.Sequential(
            (tf.keras.layers.LayerNormalization(axis=-1),
             PositionwiseFeedforward(
                 attn_dim, ff_hidden_dim, output_activation=tf.nn.relu)))
        stack.append(
            SkipConnection(
                normed_feedforward, fan_in_layer=GRUGate(init_gate_bias)))

    return tf.keras.Sequential(stack)