import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf = try_import_tf()
torch, nn = try_import_torch()


class RNNModel(RecurrentNetwork):
    """Example of using the Keras functional API to define a RNN model.

    Architecture: Dense(hiddens_size, relu) -> LSTM(cell_size) -> two linear
    heads (``num_outputs`` logits and a single value output).

    Args:
        obs_space: Observation space. ``obs_space.shape[0]`` is used as the
            per-timestep input width (assumes a flat 1-D observation).
        action_space: Action space, passed through to the base class.
        num_outputs: Width of the logits output.
        model_config: Model config, passed through to the base class.
        name: Model name, passed through to the base class.
        hiddens_size: Width of the dense layer in front of the LSTM.
        cell_size: Size of each LSTM state tensor (h and c).
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 hiddens_size=256,
                 cell_size=64):
        super(RNNModel, self).__init__(obs_space, action_space, num_outputs,
                                       model_config, name)
        self.cell_size = cell_size
        # Written by forward_rnn(), read by value_function().
        self._value_out = None

        # Define input layers: a (time, features) observation sequence, the
        # two LSTM state tensors, and the per-sequence lengths.
        input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]), name="inputs")
        state_in_h = tf.keras.layers.Input(shape=(cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(cell_size, ), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Preprocess observation with a hidden layer and send to LSTM cell.
        dense1 = tf.keras.layers.Dense(
            hiddens_size, activation=tf.nn.relu, name="dense1")(input_layer)
        # tf.sequence_mask(seq_in) marks only the first seq_in timesteps of
        # each sequence as valid, so padding is masked out of the LSTM.
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True, name="lstm")(
                inputs=dense1,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])

        # Postprocess LSTM output with another hidden layer and compute values.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(lstm_out)
        values = tf.keras.layers.Dense(
            1, activation=None, name="values")(lstm_out)

        # Create the RNN model. Input order must match forward_rnn():
        # [inputs, seq_lens] + [h, c].
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c])
        self.register_variables(self.rnn_model.variables)
        self.rnn_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        """Run the Keras model on one batch of sequences.

        Args:
            inputs: Observation sequences fed to the "inputs" Keras input.
            state: List ``[h, c]`` of LSTM states.
            seq_lens: Sequence lengths fed to the "seq_in" Keras input.

        Returns:
            Tuple of (logits, [h, c]) where [h, c] is the updated LSTM state.
            The value output is cached for value_function().
        """
        model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] +
                                                          state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-filled [h, c] states, each of shape (cell_size,)."""
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self):
        """Return the value output from the last forward pass, flattened."""
        assert self._value_out is not None, "must call forward() first"
        return tf.reshape(self._value_out, [-1])


class TorchRNNModel(TorchRNN):
    """PyTorch model: Linear(fc_size, relu) -> LSTM -> action and value heads.

    Args:
        obs_space: Observation space; its preprocessor's ``size`` is the
            per-timestep input width.
        action_space: Action space, passed through to the base class.
        num_outputs: Width of the action (logits) head.
        model_config: Model config, passed through to the base class.
        name: Model name, passed through to the base class.
        fc_size: Width of the fully connected layer in front of the LSTM.
        lstm_state_size: Size of each LSTM state tensor (h and c).
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=64,
                 lstm_state_size=256):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Build the Module from fc + LSTM + 2xfc (action + value outs).
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero [h, c] states, each of shape (lstm_state_size,)."""
        # new_zeros() matches the dtype and device of the model weights, so
        # the initial state lives on the same device as the model.
        return [
            self.fc1.weight.new_zeros(self.lstm_state_size),
            self.fc1.weight.new_zeros(self.lstm_state_size),
        ]

    @override(ModelV2)
    def value_function(self):
        """Return value estimates for the last forward pass, flattened."""
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the LSTM.

        The LSTM outputs (B x T x lstm_state_size) are cached in
        self._features; value_function() later maps them to values and
        flattens the result, so its length covers both the B and T dims.

        Args:
            inputs: Batch-first observation sequences (B x T x obs_size).
            state: List of two items [h, c]; each is presumably
                (B x lstm_state_size), i.e. batched from get_initial_state().
            seq_lens: Sequence lengths (not used by this implementation).

        Returns:
            NN Outputs (B x T x num_outputs) as sequence.
            The state batches as a List of two items [h, c].
        """
        x = nn.functional.relu(self.fc1(inputs))
        # nn.LSTM expects each state with a leading num_layers dim (here 1),
        # passed as an (h_0, c_0) tuple.
        self._features, (h, c) = self.lstm(
            x, (torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)))
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]