ray/rllib/models/tf/recurrent_tf_modelv2.py

94 lines
3.6 KiB
Python

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@DeveloperAPI
class RecurrentTFModelV2(TFModelV2):
"""Helper class to simplify implementing RNN models with TFModelV2.
Instead of implementing forward(), you can implement forward_rnn() which
takes batches with the time dimension added already.
Here is an example implementation for a subclass
``MyRNNClass(RecurrentTFModelV2)``::
def __init__(self, *args, **kwargs):
super(MyModelClass, self).__init__(*args, **kwargs)
cell_size = 256
# Define input layers
input_layer = tf.keras.layers.Input(
shape=(None, obs_space.shape[0]))
state_in_h = tf.keras.layers.Input(shape=(256, ))
state_in_c = tf.keras.layers.Input(shape=(256, ))
seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)
# Send to LSTM cell
lstm_out, state_h, state_c = tf.keras.layers.LSTM(
cell_size, return_sequences=True, return_state=True,
name="lstm")(
inputs=input_layer,
mask=tf.sequence_mask(seq_in),
initial_state=[state_in_h, state_in_c])
output_layer = tf.keras.layers.Dense(...)(lstm_out)
# Create the RNN model
self.rnn_model = tf.keras.Model(
inputs=[input_layer, seq_in, state_in_h, state_in_c],
outputs=[output_layer, state_h, state_c])
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
"""
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
"""Adds time dimension to batch before sending inputs to forward_rnn().
You should implement forward_rnn() in your subclass."""
output, new_state = self.forward_rnn(
add_time_dimension(
input_dict["obs_flat"], seq_lens, framework="tf"), state,
seq_lens)
return tf.reshape(output, [-1, self.num_outputs]), new_state
def forward_rnn(self, inputs, state, seq_lens):
"""Call the model with the given input tensors and state.
Arguments:
inputs (dict): observation tensor with shape [B, T, obs_size].
state (list): list of state tensors, each with shape [B, T, size].
seq_lens (Tensor): 1d tensor holding input sequence lengths.
Returns:
(outputs, new_state): The model output tensor of shape
[B, T, num_outputs] and the list of new state tensors each with
shape [B, size].
Sample implementation for the ``MyRNNClass`` example::
def forward_rnn(self, inputs, state, seq_lens):
model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
return model_out, [h, c]
"""
raise NotImplementedError("You must implement this for a RNN model")
def get_initial_state(self):
"""Get the initial recurrent state values for the model.
Returns:
list of np.array objects, if any
Sample implementation for the ``MyRNNClass`` example::
def get_initial_state(self):
return [
np.zeros(self.cell_size, np.float32),
np.zeros(self.cell_size, np.float32),
]
"""
raise NotImplementedError("You must implement this for a RNN model")