import numpy as np from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.recurrent_net import RecurrentNetwork from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf, try_import_torch tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() class MobileV2PlusRNNModel(RecurrentNetwork): """A conv. + recurrent keras net example using a pre-trained MobileNet.""" def __init__( self, obs_space, action_space, num_outputs, model_config, name, cnn_shape ): super(MobileV2PlusRNNModel, self).__init__( obs_space, action_space, num_outputs, model_config, name ) self.cell_size = 16 visual_size = cnn_shape[0] * cnn_shape[1] * cnn_shape[2] state_in_h = tf.keras.layers.Input(shape=(self.cell_size,), name="h") state_in_c = tf.keras.layers.Input(shape=(self.cell_size,), name="c") seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32) inputs = tf.keras.layers.Input(shape=(None, visual_size), name="visual_inputs") input_visual = inputs input_visual = tf.reshape( input_visual, [-1, cnn_shape[0], cnn_shape[1], cnn_shape[2]] ) cnn_input = tf.keras.layers.Input(shape=cnn_shape, name="cnn_input") cnn_model = tf.keras.applications.mobilenet_v2.MobileNetV2( alpha=1.0, include_top=True, weights=None, input_tensor=cnn_input, pooling=None, ) vision_out = cnn_model(input_visual) vision_out = tf.reshape( vision_out, [-1, tf.shape(inputs)[1], vision_out.shape.as_list()[-1]] ) lstm_out, state_h, state_c = tf.keras.layers.LSTM( self.cell_size, return_sequences=True, return_state=True, name="lstm" )( inputs=vision_out, mask=tf.sequence_mask(seq_in), initial_state=[state_in_h, state_in_c], ) # Postprocess LSTM output with another hidden layer and compute values. logits = tf.keras.layers.Dense( self.num_outputs, activation=tf.keras.activations.linear, name="logits" )(lstm_out) values = tf.keras.layers.Dense(1, activation=None, name="values")(lstm_out) # Create the RNN model self.rnn_model = tf.keras.Model( inputs=[inputs, seq_in, state_in_h, state_in_c], outputs=[logits, values, state_h, state_c], ) self.rnn_model.summary() @override(RecurrentNetwork) def forward_rnn(self, inputs, state, seq_lens): model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] + state) return model_out, [h, c] @override(ModelV2) def get_initial_state(self): return [ np.zeros(self.cell_size, np.float32), np.zeros(self.cell_size, np.float32), ] @override(ModelV2) def value_function(self): return tf.reshape(self._value_out, [-1]) class TorchMobileV2PlusRNNModel(TorchRNN, nn.Module): """A conv. + recurrent torch net example using a pre-trained MobileNet.""" def __init__( self, obs_space, action_space, num_outputs, model_config, name, cnn_shape ): TorchRNN.__init__( self, obs_space, action_space, num_outputs, model_config, name ) nn.Module.__init__(self) self.lstm_state_size = 16 self.cnn_shape = list(cnn_shape) self.visual_size_in = cnn_shape[0] * cnn_shape[1] * cnn_shape[2] # MobileNetV2 has a flat output of (1000,). self.visual_size_out = 1000 # Load the MobileNetV2 from torch.hub. self.cnn_model = torch.hub.load( "pytorch/vision:v0.6.0", "mobilenet_v2", pretrained=True ) self.lstm = nn.LSTM( self.visual_size_out, self.lstm_state_size, batch_first=True ) # Postprocess LSTM output with another hidden layer and compute values. self.logits = SlimFC(self.lstm_state_size, self.num_outputs) self.value_branch = SlimFC(self.lstm_state_size, 1) # Holds the current "base" output (before logits layer). self._features = None @override(TorchRNN) def forward_rnn(self, inputs, state, seq_lens): # Create image dims. vision_in = torch.reshape(inputs, [-1] + self.cnn_shape) vision_out = self.cnn_model(vision_in) # Flatten. vision_out_time_ranked = torch.reshape( vision_out, [inputs.shape[0], inputs.shape[1], vision_out.shape[-1]] ) if len(state[0].shape) == 2: state[0] = state[0].unsqueeze(0) state[1] = state[1].unsqueeze(0) # Forward through LSTM. self._features, [h, c] = self.lstm(vision_out_time_ranked, state) # Forward LSTM out through logits layer and value layer. logits = self.logits(self._features) return logits, [h.squeeze(0), c.squeeze(0)] @override(ModelV2) def get_initial_state(self): # Place hidden states on same device as model. h = [ list(self.cnn_model.modules())[-1] .weight.new(1, self.lstm_state_size) .zero_() .squeeze(0), list(self.cnn_model.modules())[-1] .weight.new(1, self.lstm_state_size) .zero_() .squeeze(0), ] return h @override(ModelV2) def value_function(self): assert self._features is not None, "must call forward() first" return torch.reshape(self.value_branch(self._features), [-1])