from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

@DeveloperAPI
class RecurrentTFModelV2(TFModelV2):
    """Helper class to simplify implementing RNN models with TFModelV2.

    Instead of implementing forward(), you can implement forward_rnn(),
    which takes batches with the time dimension already added.

    Here is an example implementation for a subclass
    ``MyRNNClass(RecurrentTFModelV2)``::

        def __init__(self, *args, **kwargs):
            super(MyRNNClass, self).__init__(*args, **kwargs)
            self.cell_size = 256

            # Define input layers.
            input_layer = tf.keras.layers.Input(
                shape=(None, obs_space.shape[0]))
            state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ))
            state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ))
            seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)

            # Send the inputs through an LSTM cell.
            lstm_out, state_h, state_c = tf.keras.layers.LSTM(
                self.cell_size, return_sequences=True, return_state=True,
                name="lstm")(
                    inputs=input_layer,
                    mask=tf.sequence_mask(seq_in),
                    initial_state=[state_in_h, state_in_c])
            output_layer = tf.keras.layers.Dense(...)(lstm_out)

            # Create the RNN model.
            self.rnn_model = tf.keras.Model(
                inputs=[input_layer, seq_in, state_in_h, state_in_c],
                outputs=[output_layer, state_h, state_c])
            self.register_variables(self.rnn_model.variables)
            self.rnn_model.summary()
    """

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Adds the time dimension to the batch before calling forward_rnn().

        You should implement forward_rnn() in your subclass."""
        output, new_state = self.forward_rnn(
            add_time_dimension(
                input_dict["obs_flat"], seq_lens, framework="tf"), state,
            seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state

    def forward_rnn(self, inputs, state, seq_lens):
        """Call the model with the given input tensors and state.

        Arguments:
            inputs (Tensor): observation tensor with shape [B, T, obs_size].
            state (list): list of state tensors, each with shape [B, size].
            seq_lens (Tensor): 1d tensor holding input sequence lengths.

        Returns:
            (outputs, new_state): The model output tensor of shape
                [B, T, num_outputs] and the list of new state tensors, each
                with shape [B, size].

        Sample implementation for the ``MyRNNClass`` example::

            def forward_rnn(self, inputs, state, seq_lens):
                model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
                return model_out, [h, c]
        """
        raise NotImplementedError("You must implement this for an RNN model")

    def get_initial_state(self):
        """Get the initial recurrent state values for the model.

        Returns:
            list of np.array objects, if any.

        Sample implementation for the ``MyRNNClass`` example::

            def get_initial_state(self):
                return [
                    np.zeros(self.cell_size, np.float32),
                    np.zeros(self.cell_size, np.float32),
                ]
        """
        raise NotImplementedError("You must implement this for an RNN model")
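

# ---------------------------------------------------------------------------
# A minimal end-to-end sketch (not part of the original module): the
# ``MyRNNClass`` subclass from the docstrings above, assembled into one
# runnable model. The class name and the fixed cell size of 256 come from
# the docstring examples; the ``Dense(self.num_outputs)`` output head is an
# assumption inferred from forward(), which reshapes the output to
# [-1, self.num_outputs].
# ---------------------------------------------------------------------------
import numpy as np


class MyRNNClass(RecurrentTFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(MyRNNClass, self).__init__(obs_space, action_space, num_outputs,
                                         model_config, name)
        self.cell_size = 256

        # Input layers: a [B, T, obs_size] observation batch, the two LSTM
        # state inputs, and the per-sequence lengths used for masking.
        input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]))
        state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ))
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ))
        seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)

        # Run the observations through an LSTM, masking padded timesteps.
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            self.cell_size, return_sequences=True, return_state=True,
            name="lstm")(
                inputs=input_layer,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])
        output_layer = tf.keras.layers.Dense(self.num_outputs)(lstm_out)

        # Wrap everything in a Keras model and register its variables.
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[output_layer, state_h, state_c])
        self.register_variables(self.rnn_model.variables)

    @override(RecurrentTFModelV2)
    def forward_rnn(self, inputs, state, seq_lens):
        model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]


# Usage sketch: register the model under a custom key ("my_rnn" is an
# arbitrary illustrative name) so a trainer config can reference it:
#
#     from ray.rllib.models import ModelCatalog
#     ModelCatalog.register_custom_model("my_rnn", MyRNNClass)
#     config = {"model": {"custom_model": "my_rnn"}}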