import logging
from typing import Dict, List, Optional, Tuple, Type

import gym
import numpy as np
import tree  # pip install dm_tree
from gym.spaces import Box, Discrete, MultiDiscrete

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor, one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)


@DeveloperAPI
class RecurrentNetwork(TFModelV2):
    """Helper class to simplify implementing RNN models with TFModelV2.

    Instead of implementing forward(), you can implement forward_rnn(), which
    takes batches with the time dimension added already.

    Here is an example implementation for a subclass
    ``MyRNNClass(RecurrentNetwork)``::

        def __init__(self, *args, **kwargs):
            super(MyRNNClass, self).__init__(*args, **kwargs)
            cell_size = 256

            # Define input layers.
            input_layer = tf.keras.layers.Input(
                shape=(None, obs_space.shape[0]))
            state_in_h = tf.keras.layers.Input(shape=(cell_size, ))
            state_in_c = tf.keras.layers.Input(shape=(cell_size, ))
            seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)

            # Send to LSTM cell.
            lstm_out, state_h, state_c = tf.keras.layers.LSTM(
                cell_size, return_sequences=True, return_state=True,
                name="lstm")(
                    inputs=input_layer,
                    mask=tf.sequence_mask(seq_in),
                    initial_state=[state_in_h, state_in_c])
            output_layer = tf.keras.layers.Dense(...)(lstm_out)

            # Create the RNN model.
            self.rnn_model = tf.keras.Model(
                inputs=[input_layer, seq_in, state_in_h, state_in_c],
                outputs=[output_layer, state_h, state_c])
            self.rnn_model.summary()
    """

    @override(ModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Adds time dimension to batch before sending inputs to forward_rnn().

        You should implement forward_rnn() in your subclass.
        """
        assert seq_lens is not None
        flat_inputs = input_dict["obs_flat"]
        inputs = add_time_dimension(
            padded_inputs=flat_inputs, seq_lens=seq_lens, framework="tf"
        )
        output, new_state = self.forward_rnn(inputs, state, seq_lens)
        return tf.reshape(output, [-1, self.num_outputs]), new_state

    def forward_rnn(
        self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
    ) -> Tuple[TensorType, List[TensorType]]:
        """Call the model with the given input tensors and state.

        Args:
            inputs: Observation tensor with shape [B, T, obs_size].
            state: List of state tensors, each with shape [B, state_size].
            seq_lens: 1D tensor holding the input sequence lengths.

        Returns:
            (outputs, new_state): The model output tensor of shape
                [B, T, num_outputs] and the list of new state tensors, each
                with shape [B, state_size].

        Sample implementation for the ``MyRNNClass`` example::

            def forward_rnn(self, inputs, state, seq_lens):
                model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
                return model_out, [h, c]
        """
        raise NotImplementedError("You must implement this for an RNN model")

    def get_initial_state(self) -> List[TensorType]:
        """Get the initial recurrent state values for the model.

        Returns:
            List of np.array objects, if any.

        Sample implementation for the ``MyRNNClass`` example::

            def get_initial_state(self):
                return [
                    np.zeros(self.cell_size, np.float32),
                    np.zeros(self.cell_size, np.float32),
                ]
        """
        raise NotImplementedError("You must implement this for an RNN model")
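

# What follows is a minimal, runnable sketch of the subclass described in the
# docstrings above. It simply stitches the ``MyRNNClass`` docstring snippets
# together; the class itself is illustrative only (it assumes a 1D Box
# observation space) and is not part of RLlib's API.
class MyRNNClass(RecurrentNetwork):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(MyRNNClass, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )
        self.cell_size = 256

        # Define input layers (the time axis is left as None).
        input_layer = tf.keras.layers.Input(shape=(None, obs_space.shape[0]))
        state_in_h = tf.keras.layers.Input(shape=(self.cell_size,))
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size,))
        seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)

        # Send the masked inputs through an LSTM cell.
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            self.cell_size, return_sequences=True, return_state=True, name="lstm"
        )(
            inputs=input_layer,
            mask=tf.sequence_mask(seq_in),
            initial_state=[state_in_h, state_in_c],
        )
        output_layer = tf.keras.layers.Dense(num_outputs)(lstm_out)

        # Create the RNN model.
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[output_layer, state_h, state_c],
        )

    def forward_rnn(self, inputs, state, seq_lens):
        # [B, T, obs] plus state -> [B, T, num_outputs] plus new state.
        model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
        return model_out, [h, c]

    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]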


@DeveloperAPI
class LSTMWrapper(RecurrentNetwork):
    """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        super(LSTMWrapper, self).__init__(
            obs_space, action_space, None, model_config, name
        )
        # At this point, self.num_outputs is the number of nodes coming
        # from the wrapped (underlying) model. In other words, self.num_outputs
        # is the input size for the LSTM layer.
        # If None, set it to the (flattened) size of the observation space.
        if self.num_outputs is None:
            self.num_outputs = int(np.prod(self.obs_space.shape))

        self.cell_size = model_config["lstm_cell_size"]
        self.use_prev_action = model_config["lstm_use_prev_action"]
        self.use_prev_reward = model_config["lstm_use_prev_reward"]

        self.action_space_struct = get_base_struct_from_space(self.action_space)

        # Compute the flattened (one-hot for discrete components) action dim.
        self.action_dim = 0
        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += int(np.sum(space.nvec))
            elif space.shape is not None:
                self.action_dim += int(np.prod(space.shape))
            else:
                self.action_dim += int(len(space))

        # Add prev-action/reward nodes to the LSTM input size.
        if self.use_prev_action:
            self.num_outputs += self.action_dim
        if self.use_prev_reward:
            self.num_outputs += 1

        # Define input layers.
        input_layer = tf.keras.layers.Input(
            shape=(None, self.num_outputs), name="inputs"
        )

        # Set self.num_outputs to the number of output nodes desired by the
        # caller of this constructor.
        self.num_outputs = num_outputs

        state_in_h = tf.keras.layers.Input(shape=(self.cell_size,), name="h")
        state_in_c = tf.keras.layers.Input(shape=(self.cell_size,), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Send the (masked) input through the LSTM layer.
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            self.cell_size, return_sequences=True, return_state=True, name="lstm"
        )(
            inputs=input_layer,
            mask=tf.sequence_mask(seq_in),
            initial_state=[state_in_h, state_in_c],
        )

        # Postprocess LSTM output with another hidden layer and compute values.
        logits = tf.keras.layers.Dense(
            self.num_outputs, activation=tf.keras.activations.linear, name="logits"
        )(lstm_out)
        values = tf.keras.layers.Dense(1, activation=None, name="values")(lstm_out)

        # Create the RNN model.
        self._rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c],
        )
        # Print out model summary in INFO logging mode.
        if logger.isEnabledFor(logging.INFO):
            self._rnn_model.summary()

        # Add prev-a/r to this model's view, if required.
        if model_config["lstm_use_prev_action"]:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                SampleBatch.ACTIONS, space=self.action_space, shift=-1
            )
        if model_config["lstm_use_prev_reward"]:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                SampleBatch.REWARDS, shift=-1
            )

    @override(RecurrentNetwork)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        assert seq_lens is not None
        # Push obs through the "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat prev-action/reward, if required.
        prev_a_r = []

        # Prev actions.
        if self.model_config["lstm_use_prev_action"]:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            # If actions have not been processed yet (they are still in the
            # original form in which they were sent to the environment):
            # Flatten/one-hot them into a 1D tensor here.
            if self.model_config["_disable_action_flattening"]:
                prev_a_r.append(
                    flatten_inputs_to_1d_tensor(
                        prev_a,
                        spaces_struct=self.action_space_struct,
                        time_axis=False,
                    )
                )
            # If actions are already flattened (but not one-hot'd yet!),
            # one-hot discrete/multi-discrete actions here.
            else:
                if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                    prev_a = one_hot(prev_a, self.action_space)
                prev_a_r.append(
                    tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
                )
        # Prev rewards.
        if self.model_config["lstm_use_prev_reward"]:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), [-1, 1]
                )
            )

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Push everything through our LSTM.
        input_dict["obs_flat"] = wrapped_out
        return super().forward(input_dict, state, seq_lens)

    @override(RecurrentNetwork)
    def forward_rnn(
        self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
    ) -> Tuple[TensorType, List[TensorType]]:
        model_out, self._value_out, h, c = self._rnn_model([inputs, seq_lens] + state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self) -> List[np.ndarray]:
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self) -> TensorType:
        return tf.reshape(self._value_out, [-1])
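

# Note: LSTMWrapper is normally not instantiated directly by user code;
# RLlib wraps the configured default model in it when the model config sets
# `use_lstm`. A hedged config sketch (the lstm_* keys below are exactly the
# ones this class reads from `model_config`; the surrounding algorithm config
# is up to the user):
#
#     "model": {
#         "use_lstm": True,
#         "lstm_cell_size": 256,
#         "lstm_use_prev_action": True,
#         "lstm_use_prev_reward": False,
#         "max_seq_len": 20,
#     }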


@DeveloperAPI
class Keras_LSTMWrapper(tf.keras.Model if tf else object):
    """A tf keras auto-LSTM wrapper used when `use_lstm` = True."""

    def __init__(
        self,
        input_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: Optional[int] = None,
        *,
        name: str,
        wrapped_cls: Type["tf.keras.Model"],
        max_seq_len: int = 20,
        lstm_cell_size: int = 256,
        lstm_use_prev_action: bool = False,
        lstm_use_prev_reward: bool = False,
        **kwargs,
    ):
        super().__init__(name=name)
        self.wrapped_keras_model = wrapped_cls(
            input_space, action_space, None, name="wrapped_" + name, **kwargs
        )

        self.action_space = action_space
        self.max_seq_len = max_seq_len

        # Guess the number of outputs of the wrapped model by looking
        # at the first output shape of its last layer.
        # This will be the input size for the LSTM layer (plus
        # maybe prev-actions/rewards).
        # If the wrapped model has no layers, fall back to the (flattened)
        # observation space size.
        if self.wrapped_keras_model.layers:
            assert self.wrapped_keras_model.layers[-1].outputs
            assert len(self.wrapped_keras_model.layers[-1].outputs[0].shape) == 2
            wrapped_num_outputs = int(
                self.wrapped_keras_model.layers[-1].outputs[0].shape[1]
            )
        else:
            wrapped_num_outputs = int(np.prod(input_space.shape))

        self.lstm_cell_size = lstm_cell_size
        self.lstm_use_prev_action = lstm_use_prev_action
        self.lstm_use_prev_reward = lstm_use_prev_reward

        # Compute the flattened (one-hot for discrete/multi-discrete) action dim.
        if isinstance(self.action_space, Discrete):
            self.action_dim = self.action_space.n
        elif isinstance(self.action_space, MultiDiscrete):
            self.action_dim = int(np.sum(self.action_space.nvec))
        elif self.action_space.shape is not None:
            self.action_dim = int(np.prod(self.action_space.shape))
        else:
            self.action_dim = int(len(self.action_space))

        # Add prev-action/reward nodes to the LSTM input size.
        if self.lstm_use_prev_action:
            wrapped_num_outputs += self.action_dim
        if self.lstm_use_prev_reward:
            wrapped_num_outputs += 1

        # Define input layers.
        input_layer = tf.keras.layers.Input(
            shape=(None, wrapped_num_outputs), name="inputs"
        )
        state_in_h = tf.keras.layers.Input(shape=(self.lstm_cell_size,), name="h")
        state_in_c = tf.keras.layers.Input(shape=(self.lstm_cell_size,), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Send the (masked) input through the LSTM layer.
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            self.lstm_cell_size, return_sequences=True, return_state=True, name="lstm"
        )(
            inputs=input_layer,
            mask=tf.sequence_mask(seq_in),
            initial_state=[state_in_h, state_in_c],
        )

        # Postprocess LSTM output with another hidden layer,
        # if num_outputs is not None.
        if num_outputs:
            logits = tf.keras.layers.Dense(
                num_outputs, activation=tf.keras.activations.linear, name="logits"
            )(lstm_out)
        else:
            logits = lstm_out
        # Compute values.
        values = tf.keras.layers.Dense(1, activation=None, name="values")(lstm_out)

        # Create the RNN model.
        self._rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c],
        )

        # Use the view-requirements of the wrapped model and add this
        # wrapper's own requirements.
        self.view_requirements = getattr(
            self.wrapped_keras_model,
            "view_requirements",
            {SampleBatch.OBS: ViewRequirement(space=input_space)},
        )

        # Add prev-a/r to this model's view, if required.
        if self.lstm_use_prev_action:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                SampleBatch.ACTIONS, space=self.action_space, shift=-1
            )
        if self.lstm_use_prev_reward:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                SampleBatch.REWARDS, shift=-1
            )

        # Internal states view requirements: The model receives its own
        # `state_out`s, shifted by -1 timestep, as `state_in`s (stored only
        # once every `max_seq_len` timesteps within the train batch).
        for i in range(2):
            space = Box(-1.0, 1.0, shape=(self.lstm_cell_size,))
            self.view_requirements["state_in_{}".format(i)] = ViewRequirement(
                "state_out_{}".format(i),
                shift=-1,
                used_for_compute_actions=True,
                batch_repeat_value=max_seq_len,
                space=space,
            )
            self.view_requirements["state_out_{}".format(i)] = ViewRequirement(
                space=space, used_for_training=True
            )

    def call(
        self, input_dict: SampleBatch
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        assert input_dict.get(SampleBatch.SEQ_LENS) is not None
        # Push obs through the underlying (wrapped) model first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

        # Concat prev-action/reward, if required.
        prev_a_r = []
        if self.lstm_use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(prev_a, self.action_space)
            prev_a_r.append(
                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
            )
        if self.lstm_use_prev_reward:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), [-1, 1]
                )
            )

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Fold the padded batch into [B, T, ...], push it through the LSTM,
        # then flatten the time dimension back out of the outputs.
        wrapped_out_plus_time_dim = add_time_dimension(
            wrapped_out, seq_lens=input_dict[SampleBatch.SEQ_LENS], framework="tf"
        )
        model_out, value_out, h, c = self._rnn_model(
            [
                wrapped_out_plus_time_dim,
                input_dict[SampleBatch.SEQ_LENS],
                input_dict["state_in_0"],
                input_dict["state_in_1"],
            ]
        )
        model_out_no_time_dim = tf.reshape(
            model_out, tf.concat([[-1], tf.shape(model_out)[2:]], axis=0)
        )
        return (
            model_out_no_time_dim,
            [h, c],
            {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])},
        )
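

# A hedged construction sketch for Keras_LSTMWrapper. ``MyKerasModel`` is a
# hypothetical tf.keras.Model subclass following the constructor signature
# this wrapper expects for `wrapped_cls` (input_space, action_space,
# num_outputs, name=..., **kwargs); the spaces are illustrative:
#
#     wrapper = Keras_LSTMWrapper(
#         input_space=Box(-1.0, 1.0, shape=(4,)),
#         action_space=Discrete(2),
#         num_outputs=2,
#         name="lstm_model",
#         wrapped_cls=MyKerasModel,
#         lstm_cell_size=64,
#         lstm_use_prev_action=True,
#     )
#
# Calling `wrapper(sample_batch)` (a SampleBatch holding OBS, SEQ_LENS, and
# state_in_0/state_in_1) then returns the flattened model output, the new
# [h, c] state, and a dict with the value predictions under
# SampleBatch.VF_PREDS.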