import pickle

import numpy as np
import ray
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class SpyLayer(tf.keras.layers.Layer):
    """Keras layer that wraps a Dense layer and records the tensors it sees.

    Every forward pass runs a ``tf1.py_func`` side effect (``spy``) that
    pickles selected input tensors into Ray's internal KV store before the
    Dense output is produced.
    """

    # Dummy value returned by the py_func; its dtype must agree with the
    # ``tf.int64`` output type declared in ``call``.
    output = np.array(0, dtype=np.int64)

    def __init__(self, num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            units=num_outputs, kernel_initializer=normc_initializer(0.01)
        )

    def call(self, inputs, **kwargs):
        """Apply the Dense layer to ``inputs[1]`` after spying on the others.

        Args:
            inputs: Sequence of 7 tensors, wired up by ``RNNSpyModel`` in the
                order: observations, LSTM output, seq_lens, h_in, c_in,
                h_out, c_out.
        """
        del kwargs
        # Everything except the LSTM output (index 1) goes to ``spy``.
        spied = [inputs[idx] for idx in (0, 2, 3, 4, 5, 6)]
        spy_op = tf1.py_func(self.spy, spied, tf.int64, stateful=True)
        # Make the Dense output depend on the spy op so it executes whenever
        # the output is evaluated.
        with tf1.control_dependencies([spy_op]):
            return self.dense(inputs[1])

    @staticmethod
    def spy(inputs, seq_lens, h_in, c_in, h_out, c_out):
        """py_func callback: pickle the given arrays into Ray's internal KV.

        Batches whose leading dimension is 1 are skipped (the original
        author treats these as inference calls). Otherwise the arrays are
        stored under ``rnn_spy_in_<RNNSpyModel.capture_index>`` and the
        counter is advanced.

        Returns:
            ``SpyLayer.output`` in every case.
        """
        if len(inputs) == 1:
            return SpyLayer.output

        payload = {
            "sequences": inputs,
            "seq_lens": seq_lens,
            "state_in": [h_in, c_in],
            "state_out": [h_out, c_out],
        }
        # Per the original author, TF runs this callback in an isolated
        # context, so the captured data is handed back to the test suite
        # through Ray's internal KV store.
        ray.experimental.internal_kv._internal_kv_put(
            f"rnn_spy_in_{RNNSpyModel.capture_index}",
            pickle.dumps(payload),
            overwrite=True,
        )
        RNNSpyModel.capture_index += 1
        return SpyLayer.output


class RNNSpyModel(RecurrentNetwork):
    """Single-layer LSTM model whose logits head is a ``SpyLayer``."""

    # Suffix of the next ``rnn_spy_in_<n>`` KV entry written by SpyLayer.spy.
    capture_index = 0
    # Size of the LSTM hidden (h) and cell (c) states.
    cell_size = 3

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self.cell_size = RNNSpyModel.cell_size

        # NOTE: layers are created in the same order as before; creation
        # order can influence Keras' auto-generated layer names.
        obs_in = tf.keras.layers.Input(shape=(None,) + obs_space.shape, name="input")
        h_in = tf.keras.layers.Input(shape=(self.cell_size,), name="h")
        c_in = tf.keras.layers.Input(shape=(self.cell_size,), name="c")
        seq_lens_in = tf.keras.layers.Input(
            shape=(), name="seq_lens", dtype=tf.int32
        )

        lstm = tf.keras.layers.LSTM(
            self.cell_size, return_sequences=True, return_state=True, name="lstm"
        )
        lstm_out, h_out, c_out = lstm(
            inputs=obs_in,
            # Mask out time steps beyond each sequence's length.
            mask=tf.sequence_mask(seq_lens_in),
            initial_state=[h_in, c_in],
        )

        # Logits head; also spies on observations, seq_lens and states.
        spy_layer = SpyLayer(num_outputs=self.num_outputs)
        logits = spy_layer(
            [obs_in, lstm_out, seq_lens_in, h_in, c_in, h_out, c_out]
        )

        # Value head.
        value_head = tf.keras.layers.Dense(
            units=1, kernel_initializer=normc_initializer(1.0)
        )
        value_out = value_head(lstm_out)

        self.base_model = tf.keras.Model(
            [obs_in, seq_lens_in, h_in, c_in],
            [logits, value_out, h_out, c_out],
        )
        self.base_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        """Run the Keras model on a batch of sequences.

        Returns:
            Tuple of (logits, [h, c]). The value output is cached on
            ``self._value_out`` for ``value_function``.
        """
        # Deserialization used to create a fresh class object, which reset
        # ``capture_index`` implicitly. That no longer happens, so reset the
        # counter by hand.
        RNNSpyModel.capture_index = 0
        h_state, c_state = state[0], state[1]
        logits, self._value_out, h_next, c_next = self.base_model(
            [inputs, seq_lens, h_state, c_state]
        )
        return logits, [h_next, c_next]

    @override(ModelV2)
    def value_function(self):
        """Return the value output cached by ``forward_rnn``, flattened to 1-D."""
        flat_values = tf.reshape(self._value_out, [-1])
        return flat_values

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-filled [h, c] states, each of shape (cell_size,)."""
        return [np.zeros(self.cell_size, np.float32) for _ in range(2)]