ray/rllib/examples/models/rnn_model.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

141 lines
5 KiB
Python

import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
class RNNModel(RecurrentNetwork):
    """Example of using the Keras functional API to define a RNN model.

    Layout: Dense(hiddens_size, relu) -> LSTM(cell_size) -> two linear heads
    (``num_outputs`` logits per timestep and one value estimate per timestep).
    A Keras summary of the model is printed on construction.

    Args:
        obs_space: Observation space. The input width is the flattened size
            reported by the space's RLlib preprocessor, matching the flattened
            observations ("obs_flat") the RecurrentNetwork base class feeds to
            forward_rnn().
        action_space: Action space, forwarded to the base class.
        num_outputs: Number of logits produced per timestep.
        model_config: Model config dict, forwarded to the base class.
        name: Model name, forwarded to the base class.
        hiddens_size: Width of the dense layer in front of the LSTM.
        cell_size: Size of the LSTM hidden (h) and cell (c) states.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        hiddens_size=256,
        cell_size=64,
    ):
        super(RNNModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )
        self.cell_size = cell_size
        # Written by forward_rnn(), read by value_function().
        self._value_out = None

        # Size the input by the flattened observation (as TorchRNNModel does)
        # instead of obs_space.shape[0]: the latter raises for scalar-shaped
        # spaces such as Discrete and is wrong for multi-dimensional Boxes.
        # For a 1-D Box both give the same width.
        obs_size = get_preprocessor(obs_space)(obs_space).size

        # Define input layers: a (B, T, obs_size) observation sequence, the
        # (B, cell_size) h/c states and the (B,) sequence lengths.
        input_layer = tf.keras.layers.Input(shape=(None, obs_size), name="inputs")
        state_in_h = tf.keras.layers.Input(shape=(cell_size,), name="h")
        state_in_c = tf.keras.layers.Input(shape=(cell_size,), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Preprocess observation with a hidden layer and send to LSTM cell.
        dense1 = tf.keras.layers.Dense(
            hiddens_size, activation=tf.nn.relu, name="dense1"
        )(input_layer)
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True, name="lstm"
        )(
            inputs=dense1,
            # Padding timesteps beyond each sequence's length are masked out.
            mask=tf.sequence_mask(seq_in),
            initial_state=[state_in_h, state_in_c],
        )

        # Postprocess LSTM output with another hidden layer and compute values.
        logits = tf.keras.layers.Dense(
            self.num_outputs, activation=tf.keras.activations.linear, name="logits"
        )(lstm_out)
        values = tf.keras.layers.Dense(1, activation=None, name="values")(lstm_out)

        # Create the RNN model.
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c],
        )
        self.rnn_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        """Run the Keras RNN over a padded batch of sequences.

        Args:
            inputs: Observation sequences, shaped (B, T, obs_size).
            state: List ``[h, c]`` of LSTM states, each (B, cell_size).
            seq_lens: Per-sequence lengths, shaped (B,).

        Returns:
            The logits sequence (B, T, num_outputs) and the new ``[h, c]``.
            The value outputs are cached for value_function().
        """
        model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] + state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-filled float32 ``[h, c]`` states of size cell_size."""
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self):
        """Return the values from the last forward pass, flattened to 1-D."""
        assert self._value_out is not None, "must call forward() first"
        return tf.reshape(self._value_out, [-1])
class TorchRNNModel(TorchRNN, nn.Module):
    """Torch example RNN model: FC(relu) -> LSTM -> action and value heads.

    Args:
        obs_space: Observation space. The input width is the flattened size
            reported by the space's RLlib preprocessor.
        action_space: Action space, forwarded to the base class.
        num_outputs: Number of action logits produced per timestep.
        model_config: Model config dict, forwarded to the base class.
        name: Model name, forwarded to the base class.
        fc_size: Width of the ReLU layer in front of the LSTM.
        lstm_state_size: Size of the LSTM hidden (h) and cell (c) states.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        fc_size=64,
        lstm_state_size=256,
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Build the Module from fc + LSTM + 2xfc (action + value outs).
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # LSTM output features from the last forward pass (input to the value
        # head); None until forward_rnn() has run.
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-filled ``[h, c]`` states, each of size lstm_state_size."""
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        # View API is supported across all of RLlib.
        # new_zeros() places the states on the same device (and dtype) as the
        # model parameters.
        size = (self.lstm_state_size,)
        return [
            self.fc1.weight.new_zeros(size),
            self.fc1.weight.new_zeros(size),
        ]

    @override(ModelV2)
    def value_function(self):
        """Return value estimates for the last forward pass, flattened to 1-D.

        The value head is applied lazily to the cached LSTM features, so the
        result has one entry per (batch, time) step.
        """
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feed `inputs` (B x T x ..) through the FC layer and the LSTM.

        The LSTM output features (B x T x lstm_state_size) are cached in
        ``self._features``; value_function() applies the value head to them
        later.

        Returns:
            The action logits as a sequence (B x T x num_outputs), and the new
            state batches as a list of two items ``[h, c]``.
        """
        x = nn.functional.relu(self.fc1(inputs))
        # nn.LSTM expects initial states shaped (num_layers, B, hidden) as an
        # (h_0, c_0) tuple. Add the leading num_layers=1 dim here and strip it
        # again from the returned states.
        h_in = torch.unsqueeze(state[0], 0)
        c_in = torch.unsqueeze(state[1], 0)
        self._features, (h, c) = self.lstm(x, (h_in, c_in))
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]