import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

class RNNModel(RecurrentNetwork):
    """Example of using the Keras functional API to define an RNN model."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 hiddens_size=256,
                 cell_size=64):
        super(RNNModel, self).__init__(obs_space, action_space, num_outputs,
                                       model_config, name)
        self.cell_size = cell_size

        # Define input layers.
        input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]), name="inputs")
        state_in_h = tf.keras.layers.Input(shape=(cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(cell_size, ), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Preprocess observation with a hidden layer and send to LSTM cell.
        dense1 = tf.keras.layers.Dense(
            hiddens_size, activation=tf.nn.relu, name="dense1")(input_layer)
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True, name="lstm")(
                inputs=dense1,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])

        # Postprocess LSTM output with another hidden layer and compute values.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(lstm_out)
        values = tf.keras.layers.Dense(
            1, activation=None, name="values")(lstm_out)

        # Create the RNN model.
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c])
        self.rnn_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] +
                                                          state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])

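
# A minimal sketch (an illustrative addition, not part of the original file)
# of how RNNModel would be hooked into RLlib. The env name and the
# max_seq_len value below are assumptions, not taken from this file:
#
#   from ray.rllib.models import ModelCatalog
#   ModelCatalog.register_custom_model("rnn", RNNModel)
#   config = {
#       "env": "CartPole-v1",
#       "framework": "tf",
#       "model": {"custom_model": "rnn", "max_seq_len": 20},
#   }
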
class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=64,
                 lstm_state_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Build the Module from fc + LSTM + 2xfc (action + value outs).
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Place hidden states on same device as model.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the LSTM.

        Returns the resulting outputs as a sequence (B x T x ...).
        The LSTM outputs are also stored in `self._features`
        (B x T x lstm_state_size), from which `value_function()` later
        computes the value estimates.

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (h- and c-states).
        """
        x = nn.functional.relu(self.fc1(inputs))
        # nn.LSTM expects hidden states with a leading num_layers dimension.
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0),
                torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
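

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the
    # original model file): register the torch model with the ModelCatalog
    # and run a few PPO iterations against it. The env name, stop criterion,
    # and max_seq_len below are assumptions.
    import ray
    from ray import tune
    from ray.rllib.models import ModelCatalog

    ModelCatalog.register_custom_model("rnn", TorchRNNModel)

    ray.init()
    tune.run(
        "PPO",
        stop={"training_iteration": 5},
        config={
            "env": "CartPole-v1",
            "framework": "torch",
            "num_workers": 0,
            "model": {
                "custom_model": "rnn",
                # Max length of the sequences fed to forward_rnn().
                "max_seq_len": 20,
            },
        },
    )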