import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


@DeveloperAPI
class RecurrentNetwork(TFModelV2):
    """Helper class to simplify implementing RNN models with TFModelV2.

    Instead of implementing forward(), you can implement forward_rnn() which
    takes batches with the time dimension added already.

    Here is an example implementation for a subclass
    ``MyRNNClass(RecurrentNetwork)``::

        def __init__(self, *args, **kwargs):
            super(MyRNNClass, self).__init__(*args, **kwargs)
            self.cell_size = 256

            # Define input layers
            input_layer = tf.keras.layers.Input(
                shape=(None, self.obs_space.shape[0]))
            state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ))
            state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ))
            seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)

            # Send to LSTM cell
            lstm_out, state_h, state_c = tf.keras.layers.LSTM(
                self.cell_size, return_sequences=True,
                return_state=True, name="lstm")(
                    inputs=input_layer,
                    mask=tf.sequence_mask(seq_in),
                    initial_state=[state_in_h, state_in_c])
            output_layer = tf.keras.layers.Dense(...)(lstm_out)

            # Create the RNN model
            self.rnn_model = tf.keras.Model(
                inputs=[input_layer, seq_in, state_in_h, state_in_c],
                outputs=[output_layer, state_h, state_c])
            self.register_variables(self.rnn_model.variables)
            self.rnn_model.summary()
    """

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass.

        Args:
            input_dict (dict): Must contain the flat, padded observation
                batch under "obs_flat".
            state (list): Recurrent state tensors, passed to forward_rnn().
            seq_lens (Tensor): 1d tensor of sequence lengths; required.

        Returns:
            (outputs, new_state): forward_rnn()'s output flattened back to
                shape [-1, num_outputs], and the new state list.
        """
        assert seq_lens is not None
        padded_inputs = input_dict["obs_flat"]
        # Assumes the flat batch is padded so that each of the
        # len(seq_lens) sequences occupies the same number of rows; the
        # quotient is then the (padded) max sequence length T.
        max_seq_len = tf.shape(padded_inputs)[0] // tf.shape(seq_lens)[0]
        output, new_state = self.forward_rnn(
            add_time_dimension(
                padded_inputs, max_seq_len=max_seq_len, framework="tf"),
            state, seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state

    def forward_rnn(self, inputs, state, seq_lens):
        """Call the model with the given input tensors and state.

        Args:
            inputs (Tensor): observation tensor with shape [B, T, obs_size].
            state (list): list of state tensors, each with shape [B, size].
            seq_lens (Tensor): 1d tensor holding input sequence lengths.

        Returns:
            (outputs, new_state): The model output tensor of shape
                [B, T, num_outputs] and the list of new state tensors each with
                shape [B, size].

        Sample implementation for the ``MyRNNClass`` example::

            def forward_rnn(self, inputs, state, seq_lens):
                model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
                return model_out, [h, c]
        """
        raise NotImplementedError("You must implement this for a RNN model")

    def get_initial_state(self):
        """Get the initial recurrent state values for the model.

        Returns:
            list of np.array objects, if any

        Sample implementation for the ``MyRNNClass`` example::

            def get_initial_state(self):
                return [
                    np.zeros(self.cell_size, np.float32),
                    np.zeros(self.cell_size, np.float32),
                ]
        """
        raise NotImplementedError("You must implement this for a RNN model")


class LSTMWrapper(RecurrentNetwork):
    """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # The next class in the MRO is constructed with num_outputs=None.
        # NOTE(review): this presumably leaves self.num_outputs set to the
        # wrapped model's output feature size, which becomes the LSTM input
        # size below -- confirm against the wrapped model class.
        super(LSTMWrapper, self).__init__(obs_space, action_space, None,
                                          model_config, name)

        self.cell_size = model_config["lstm_cell_size"]
        self.use_prev_action_reward = model_config[
            "lstm_use_prev_action_reward"]
        self.action_dim = int(np.prod(action_space.shape))

        # Add prev-action/reward nodes to input to LSTM.
        if self.use_prev_action_reward:
            self.num_outputs += 1 + self.action_dim

        # Define input layers. The feature size here is the LSTM *input*
        # size computed above ...
        input_layer = tf.keras.layers.Input(
            shape=(None, self.num_outputs), name="inputs")

        # ... after which num_outputs is reset to this wrapper's own
        # requested output (logits) size.
        self.num_outputs = num_outputs

        state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Preprocess observation with a hidden layer and send to LSTM cell
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            self.cell_size,
            return_sequences=True,
            return_state=True,
            name="lstm")(
                inputs=input_layer,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])

        # Postprocess LSTM output with another hidden layer and compute values
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(lstm_out)
        values = tf.keras.layers.Dense(
            1, activation=None, name="values")(lstm_out)

        # Create the RNN model
        self._rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c])
        self.register_variables(self._rnn_model.variables)
        self._rnn_model.summary()

        # Set by forward_rnn(); read by value_function().
        self._value_out = None

    @override(RecurrentNetwork)
    def forward(self, input_dict, state, seq_lens):
        """Run the wrapped model, optionally append prev action/reward, then
        feed the result through the LSTM via RecurrentNetwork.forward().
        """
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        # NOTE(review): `_wrapped_forward` is not defined in this file; it is
        # presumably attached by whatever mixes this wrapper into a model.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required (same flag that sized the
        # LSTM input layer in __init__).
        if self.use_prev_action_reward:
            wrapped_out = tf.concat(
                [
                    wrapped_out,
                    tf.reshape(
                        tf.cast(input_dict[SampleBatch.PREV_ACTIONS],
                                tf.float32), [-1, self.action_dim]),
                    tf.reshape(
                        tf.cast(input_dict[SampleBatch.PREV_REWARDS],
                                tf.float32), [-1, 1]),
                ],
                axis=1)

        # Then through our LSTM. Note: this overwrites "obs_flat" in the
        # caller's input_dict in place.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        """Run the Keras LSTM model; caches the value head for
        value_function() and returns (logits, [h, c]).
        """
        model_out, self._value_out, h, c = self._rnn_model([inputs, seq_lens] +
                                                           state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-initialized [h, c] states, each of shape (cell_size,)
        and dtype float32."""
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self):
        """Return the value-head output from the last forward pass, flattened
        to 1d.
        """
        assert self._value_out is not None, "must call forward() first"
        return tf.reshape(self._value_out, [-1])