From 7f14fb577d3ff758ee82855e610514258fef3089 Mon Sep 17 00:00:00 2001
From: gehring
Date: Fri, 8 May 2020 08:10:23 -0400
Subject: [PATCH] [RLlib] Added TransformerXL and "stabilized for RL" variant, GTrXL (#6470)
---
 rllib/examples/rl_attention.py         | 173 +++++++++++++++
 rllib/examples/supervised_attention.py |  81 +++++++
 rllib/models/tf/attention.py           | 287 +++++++++++++++++++++++++
 3 files changed, 541 insertions(+)
 create mode 100644 rllib/examples/rl_attention.py
 create mode 100644 rllib/examples/supervised_attention.py
 create mode 100644 rllib/models/tf/attention.py

diff --git a/rllib/examples/rl_attention.py b/rllib/examples/rl_attention.py
new file mode 100644
index 000000000..a18d38be6
--- /dev/null
+++ b/rllib/examples/rl_attention.py
@@ -0,0 +1,173 @@
import argparse

import gym

import numpy as np

import ray
from ray import tune

from ray.tune import registry

from ray.rllib import models
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf import attention
from ray.rllib.models.tf import recurrent_tf_modelv2
from ray.rllib.examples.custom_keras_rnn_model import RepeatAfterMeEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)


class OneHot(gym.Wrapper):
    """Wrapper that one-hot encodes the discrete observations of `env`."""

    def __init__(self, env):
        super(OneHot, self).__init__(env)
        self.observation_space = gym.spaces.Box(0., 1.,
                                                (env.observation_space.n, ))

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        return self._encode_obs(obs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self._encode_obs(obs), reward, done, info

    def _encode_obs(self, obs):
        # One-hot encode the discrete observation.
        new_obs = np.zeros(self.env.observation_space.n)
        new_obs[obs] = 1.0
        return new_obs


class LookAndPush(gym.Env):
    """Memory env: the reward in the final state depends on a hidden "case"
    that can only be observed by taking action 0 ("look") in the initial
    state.
    """

    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(5)
        self._state = None
        self._case = None

    def reset(self):
        self._state = 2
        self._case = np.random.choice(2)
        return self._state

    def step(self, action):
        assert self.action_space.contains(action)

        if self._state == 4:
            if action and self._case:
                return self._state, 10., True, {}
            else:
                return self._state, -10, True, {}
        else:
            if action:
                if self._state == 0:
                    self._state = 2
                else:
                    self._state += 1
            elif self._state == 2:
                self._state = self._case

        return self._state, -1, False, {}


class GRUTrXL(recurrent_tf_modelv2.RecurrentTFModelV2):
    """RLlib model wrapping the GTrXL network built by `make_GRU_TrXL`.

    The RNN state is used to carry the most recent observations (up to
    `max_seq_len`), which are fed to the transformer as a full sequence at
    every step.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(GRUTrXL, self).__init__(obs_space, action_space, num_outputs,
                                      model_config, name)
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = obs_space.shape[0]
        input_layer = tf.keras.layers.Input(
            shape=(self.max_seq_len, obs_space.shape[0]),
            name="inputs",
        )

        trxl_out = attention.make_GRU_TrXL(
            seq_length=model_config["max_seq_len"],
            num_layers=model_config["custom_options"]["num_layers"],
            attn_dim=model_config["custom_options"]["attn_dim"],
            num_heads=model_config["custom_options"]["num_heads"],
            head_dim=model_config["custom_options"]["head_dim"],
            ff_hidden_dim=model_config["custom_options"]["ff_hidden_dim"],
        )(input_layer)

        # Postprocess TrXL output with another hidden layer and compute values
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(trxl_out)
        values_out = tf.keras.layers.Dense(
            1, activation=None, name="values")(trxl_out)

        self.trxl_model = tf.keras.Model(
            inputs=[input_layer],
            outputs=[logits, values_out],
        )
        self.register_variables(self.trxl_model.variables)
        self.trxl_model.summary()

    def forward_rnn(self, inputs, state, seq_lens):
        state = state[0]

        # We assume state is the history of recent observations and append
        # the current inputs to the end and only keep the most recent (up to
        # max_seq_len). This allows us to deal with timestep-wise inference
        # and full sequence training with the same logic.
        state = tf.concat((state, inputs), axis=1)[:, -self.max_seq_len:]
        logits, self._value_out = self.trxl_model(state)

        in_T = tf.shape(inputs)[1]
        logits = logits[:, -in_T:]
        self._value_out = self._value_out[:, -in_T:]

        return logits, [state]

    def get_initial_state(self):
        return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)
    models.ModelCatalog.register_custom_model("trxl", GRUTrXL)
    registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    registry.register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())
    registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush()))
    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config={
            "env": args.env,
            "env_config": {
                "repeat_delay": 2,
            },
            "gamma": 0.99,
            "num_workers": 0,
            "num_envs_per_worker": 20,
            "entropy_coeff": 0.001,
            "num_sgd_iter": 5,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "trxl",
                "max_seq_len": 10,
                "custom_options": {
                    "num_layers": 1,
                    "attn_dim": 10,
                    "num_heads": 1,
                    "head_dim": 10,
                    "ff_hidden_dim": 20,
                },
            },
        })

diff --git a/rllib/examples/supervised_attention.py b/rllib/examples/supervised_attention.py
new file mode 100644
index 000000000..e55ed8e1a
--- /dev/null
+++ b/rllib/examples/supervised_attention.py
@@ -0,0 +1,81 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.rllib.models.tf import attention
from ray.rllib.utils import try_import_tf

tf = try_import_tf()


def bit_shift_generator(seq_length, shift, batch_size):
    while True:
        values = np.array([0., 1.], dtype=np.float32)
        seq = np.random.choice(values, (batch_size, seq_length, 1))
        targets = np.squeeze(np.roll(seq, shift, axis=1).astype(np.int32))
        targets[:, :shift] = 0
        yield seq, targets


def make_model(seq_length, num_tokens, num_layers, attn_dim, num_heads,
               head_dim, ff_hidden_dim):

    return tf.keras.Sequential((
        attention.make_TrXL(seq_length, num_layers, attn_dim, num_heads,
                            head_dim, ff_hidden_dim),
        tf.keras.layers.Dense(num_tokens),
    ))


def train_loss(targets, outputs):
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=outputs)
    return tf.reduce_mean(loss)


def train_bit_shift(seq_length, num_iterations, print_every_n):

    optimizer = tf.keras.optimizers.Adam(1e-3)

    model = make_model(
        seq_length,
        num_tokens=2,
        num_layers=1,
        attn_dim=10,
        num_heads=5,
        head_dim=20,
        ff_hidden_dim=20,
    )

    shift = 10
    train_batch = 10
    test_batch = 100
    data_gen = bit_shift_generator(
        seq_length, shift=shift, batch_size=train_batch)
    test_gen = bit_shift_generator(
        seq_length, shift=shift, batch_size=test_batch)

    @tf.function
    def update_step(inputs, targets):
        loss_fn = lambda: train_loss(targets, model(inputs))
        var_fn = lambda: model.trainable_variables
        optimizer.minimize(loss_fn, var_fn)

    for i, (inputs, targets) in zip(range(num_iterations), data_gen):
        update_step(
            tf.convert_to_tensor(inputs), tf.convert_to_tensor(targets))

        if i % print_every_n == 0:
            test_inputs, test_targets = next(test_gen)
            print(i, train_loss(test_targets, model(test_inputs)))


if __name__ == "__main__":
    tf.enable_eager_execution()
    train_bit_shift(
        seq_length=20,
        num_iterations=2000,
        print_every_n=200,
    )

diff --git a/rllib/models/tf/attention.py b/rllib/models/tf/attention.py
new file mode 100644
index 000000000..1f9a54b6a
--- /dev/null
+++ b/rllib/models/tf/attention.py
@@ -0,0 +1,287 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import try_import_tf

tf = try_import_tf()


def relative_position_embedding(seq_length, out_dim):
    inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim))
    pos_offsets = tf.range(seq_length - 1., -1., -1.)
    inputs = pos_offsets[:, None] * inverse_freq[None, :]
    return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1)


def rel_shift(x):
    # Transposed version of the shift approach implemented by Dai et al. 2019
    # https://github.com/kimiyoung/transformer-xl/blob/44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31
    x_size = tf.shape(x)

    x = tf.pad(x, [[0, 0], [0, 0], [1, 0], [0, 0]])
    x = tf.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]])
    x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1])
    x = tf.reshape(x, x_size)

    return x


class MultiHeadAttention(tf.keras.layers.Layer):

    def __init__(self, out_dim, num_heads, head_dim, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)

        # no bias or non-linearity
        self._num_heads = num_heads
        self._head_dim = head_dim
        self._qkv_layer = tf.keras.layers.Dense(
            3 * num_heads * head_dim, use_bias=False)
        self._linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(out_dim, use_bias=False))

    def call(self, inputs):
        L = tf.shape(inputs)[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        qkv = self._qkv_layer(inputs)

        queries, keys, values = tf.split(qkv, 3, -1)
        queries = queries[:, -L:]  # only query based on the segment

        queries = tf.reshape(queries, [-1, L, H, D])
        keys = tf.reshape(keys, [-1, L, H, D])
        values = tf.reshape(values, [-1, L, H, D])

        score = tf.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / D ** 0.5

        # causal mask of the same length as the sequence
        mask = tf.sequence_mask(tf.range(1, L + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask - 1.)
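        # Scores at positions allowed by the causal mask are kept; masked
        # (future) positions are pushed towards -1e30 so that the softmax
        # below assigns them effectively zero attention weight.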
        wmat = tf.nn.softmax(masked_score, axis=2)

        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0))
        return self._linear_layer(out)


class RelativeMultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self,
                 out_dim,
                 num_heads,
                 head_dim,
                 rel_pos_encoder,
                 input_layernorm=False,
                 output_activation=None,
                 **kwargs):
        super(RelativeMultiHeadAttention, self).__init__(**kwargs)

        # no bias or non-linearity
        self._num_heads = num_heads
        self._head_dim = head_dim
        self._qkv_layer = tf.keras.layers.Dense(
            3 * num_heads * head_dim, use_bias=False)
        self._linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(
                out_dim, use_bias=False, activation=output_activation))

        self._uvar = self.add_weight(shape=(num_heads, head_dim))
        self._vvar = self.add_weight(shape=(num_heads, head_dim))

        self._pos_proj = tf.keras.layers.Dense(
            num_heads * head_dim, use_bias=False)
        self._rel_pos_encoder = rel_pos_encoder

        self._input_layernorm = None
        if input_layernorm:
            self._input_layernorm = tf.keras.layers.LayerNormalization(axis=-1)

    def call(self, inputs, memory=None):
        L = tf.shape(inputs)[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        # length of the memory segment (along the time axis)
        M = tf.shape(memory)[1] if memory is not None else 0

        if memory is not None:
            inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1)

        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        qkv = self._qkv_layer(inputs)

        queries, keys, values = tf.split(qkv, 3, -1)
        queries = queries[:, -L:]  # only query based on the segment

        queries = tf.reshape(queries, [-1, L, H, D])
        keys = tf.reshape(keys, [-1, L + M, H, D])
        values = tf.reshape(values, [-1, L + M, H, D])

        rel = self._pos_proj(self._rel_pos_encoder)
        rel = tf.reshape(rel, [L, H, D])

        score = tf.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
        pos_score = tf.einsum("bihd,jhd->bijh", queries + self._vvar, rel)
        score = score + rel_shift(pos_score)
        score = score / D**0.5

        # causal mask of the same length as the sequence
        mask = tf.sequence_mask(tf.range(M + 1, L + M + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = tf.nn.softmax(masked_score, axis=2)

        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        out = tf.reshape(out, tf.concat((tf.shape(out)[:2], [H * D]), axis=0))
        return self._linear_layer(out)


class PositionwiseFeedforward(tf.keras.layers.Layer):

    def __init__(self, out_dim, hidden_dim, output_activation=None, **kwargs):
        super(PositionwiseFeedforward, self).__init__(**kwargs)

        self._hidden_layer = tf.keras.layers.Dense(
            hidden_dim,
            activation=tf.nn.relu,
        )
        self._output_layer = tf.keras.layers.Dense(
            out_dim, activation=output_activation)

    def call(self, inputs, **kwargs):
        del kwargs
        output = self._hidden_layer(inputs)
        return self._output_layer(output)


class SkipConnection(tf.keras.layers.Layer):
    """Skip connection layer.

    If no fan-in layer is specified, then this layer behaves as a regular
    residual layer. Otherwise, the fan-in layer is called with the tuple
    (inputs, outputs) to combine the skip branch with the wrapped layer's
    output.
+ """ + + def __init__(self, layer, fan_in_layer=None, **kwargs): + super(SkipConnection, self).__init__(**kwargs) + self._fan_in_layer = fan_in_layer + self._layer = layer + + def call(self, inputs, **kwargs): + del kwargs + outputs = self._layer(inputs) + if self._fan_in_layer is None: + outputs = outputs + inputs + else: + outputs = self._fan_in_layer((inputs, outputs)) + + return outputs + + +class GRUGate(tf.keras.layers.Layer): + + def __init__(self, init_bias=0., **kwargs): + super(GRUGate, self).__init__(**kwargs) + self._init_bias = init_bias + + def build(self, input_shape): + x_shape, y_shape = input_shape + if x_shape[-1] != y_shape[-1]: + raise ValueError( + "Both inputs to GRUGate must equal size last axis.") + + self._w_r = self.add_weight(shape=(y_shape[-1], y_shape[-1])) + self._w_z = self.add_weight(shape=(y_shape[-1], y_shape[-1])) + self._w_h = self.add_weight(shape=(y_shape[-1], y_shape[-1])) + self._u_r = self.add_weight(shape=(x_shape[-1], x_shape[-1])) + self._u_z = self.add_weight(shape=(x_shape[-1], x_shape[-1])) + self._u_h = self.add_weight(shape=(x_shape[-1], x_shape[-1])) + + def bias_initializer(shape, dtype): + return tf.fill(shape, tf.cast(self._init_bias, dtype=dtype)) + + self._bias_z = self.add_weight( + shape=(x_shape[-1], ), initializer=bias_initializer) + + def call(self, inputs, **kwargs): + x, y = inputs + r = (tf.tensordot(y, self._w_r, axes=1) + tf.tensordot( + x, self._u_r, axes=1)) + r = tf.nn.sigmoid(r) + + z = (tf.tensordot(y, self._w_z, axes=1) + tf.tensordot( + x, self._u_z, axes=1) + self._bias_z) + z = tf.nn.sigmoid(z) + + h = (tf.tensordot(y, self._w_h, axes=1) + tf.tensordot( + (x * r), self._u_h, axes=1)) + h = tf.nn.tanh(h) + + return (1 - z) * x + z * h + + +def make_TrXL(seq_length, num_layers, attn_dim, num_heads, head_dim, + ff_hidden_dim): + pos_embedding = relative_position_embedding(seq_length, attn_dim) + + layers = [tf.keras.layers.Dense(attn_dim)] + for _ in range(num_layers): + layers.append( + SkipConnection( + RelativeMultiHeadAttention(attn_dim, num_heads, head_dim, + pos_embedding))) + layers.append(tf.keras.layers.LayerNormalization(axis=-1)) + + layers.append( + SkipConnection(PositionwiseFeedforward(attn_dim, ff_hidden_dim))) + layers.append(tf.keras.layers.LayerNormalization(axis=-1)) + + return tf.keras.Sequential(layers) + + +def make_GRU_TrXL(seq_length, + num_layers, + attn_dim, + num_heads, + head_dim, + ff_hidden_dim, + init_gate_bias=2.): + # Default initial bias for the gate taken from + # Parisotto, Emilio, et al. "Stabilizing Transformers for Reinforcement Learning." arXiv preprint arXiv:1910.06764 (2019). + pos_embedding = relative_position_embedding(seq_length, attn_dim) + + layers = [tf.keras.layers.Dense(attn_dim)] + for _ in range(num_layers): + layers.append( + SkipConnection( + RelativeMultiHeadAttention( + attn_dim, + num_heads, + head_dim, + pos_embedding, + input_layernorm=True, + output_activation=tf.nn.relu), + fan_in_layer=GRUGate(init_gate_bias), + )) + + layers.append( + SkipConnection( + tf.keras.Sequential( + (tf.keras.layers.LayerNormalization(axis=-1), + PositionwiseFeedforward( + attn_dim, ff_hidden_dim, + output_activation=tf.nn.relu))), + fan_in_layer=GRUGate(init_gate_bias), + )) + + return tf.keras.Sequential(layers)