# Mirror of https://github.com/vale981/ray
# Synced 2025-03-06 02:21:39 -05:00
from collections import OrderedDict
|
|
import gym
|
|
|
|
from ray.rllib.models.preprocessors import get_preprocessor
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|
|
|
tf = try_import_tf()
|
|
torch, _ = try_import_torch()
|
|
|
|
|
|
@PublicAPI
|
|
class ModelV2:
    """Defines a Keras-style abstract network model for use with RLlib.

    Custom models should extend either TFModelV2 or TorchModelV2 instead of
    this class directly.

    Data flow:
        obs -> forward() -> model_out
        value_function() -> V(s)

    Attributes:
        obs_space (Space): observation space of the target gym env. This
            may have an `original_space` attribute that specifies how to
            unflatten the tensor into a ragged tensor.
        action_space (Space): action space of the target gym env
        num_outputs (int): number of output units of the model
        model_config (dict): config for the model, documented in ModelCatalog
        name (str): name (scope) for the model
        framework (str): either "tf" or "torch"
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, framework):
        """Initialize the model.

        This method should create any variables used by the model.

        Args:
            obs_space (Space): Observation space of the target gym env.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): Number of output units of the model.
            model_config (dict): Config for the model (see ModelCatalog).
            name (str): Name (scope) for the model; falls back to
                "default_model" if empty/None.
            framework (str): Either "tf" or "torch".
        """
        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.model_config = model_config
        # Use a generic scope name if none was provided.
        self.name = name or "default_model"
        self.framework = framework
        # Holds the output of the most recent forward pass
        # (see `last_output()`).
        self._last_output = None

    def get_initial_state(self):
        """Get the initial recurrent state values for the model.

        Returns:
            List[np.ndarray]: List of np.array objects containing the initial
                hidden state of an RNN, if applicable.

        Examples:
            >>> def get_initial_state(self):
            >>>     return [
            >>>         np.zeros(self.cell_size, np.float32),
            >>>         np.zeros(self.cell_size, np.float32),
            >>>     ]
        """
        # Non-recurrent models have no state.
        return []

    def forward(self, input_dict, state, seq_lens):
        """Call the model with the given input tensors and state.

        Any complex observations (dicts, tuples, etc.) will be unpacked by
        __call__ before being passed to forward(). To access the flattened
        observation tensor, refer to input_dict["obs_flat"].

        This method can be called any number of times. In eager execution,
        each call to forward() will eagerly evaluate the model. In symbolic
        execution, each call to forward creates a computation graph that
        operates over the variables of this model (i.e., shares weights).

        Custom models should override this instead of __call__.

        Args:
            input_dict (dict): dictionary of input tensors, including "obs",
                "obs_flat", "prev_action", "prev_reward", "is_training"
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension
            seq_lens (Tensor): 1d tensor holding input sequence lengths

        Returns:
            (outputs, state): The model output tensor of size
                [BATCH, num_outputs]

        Examples:
            >>> def forward(self, input_dict, state, seq_lens):
            >>>     model_out, self._value_out = self.base_model(
            ...         input_dict["obs"])
            >>>     return model_out, state
        """
        raise NotImplementedError

    def value_function(self):
        """Returns the value function output for the most recent forward pass.

        Note that a `forward` call has to be performed first, before this
        methods can return anything and thus that calling this method does not
        cause an extra forward pass through the network.

        Returns:
            value estimate tensor of shape [BATCH].
        """
        raise NotImplementedError

    def custom_loss(self, policy_loss, loss_inputs):
        """Override to customize the loss function used to optimize this model.

        This can be used to incorporate self-supervised losses (by defining
        a loss over existing input and output tensors of this model), and
        supervised losses (by defining losses over a variable-sharing copy of
        this model's layers).

        You can find an runnable example in examples/custom_loss.py.

        Arguments:
            policy_loss (Union[List[Tensor],Tensor]): List of or single policy
                loss(es) from the policy.
            loss_inputs (dict): map of input placeholders for rollout data.

        Returns:
            Union[List[Tensor],Tensor]: List of or scalar tensor for the
                customized loss(es) for this model.
        """
        # Default: leave the policy loss unmodified.
        return policy_loss

    def metrics(self):
        """Override to return custom metrics from your model.

        The stats will be reported as part of the learner stats, i.e.,
        info:
            learner:
                model:
                    key1: metric1
                    key2: metric2

        Returns:
            Dict of string keys to scalar tensors.
        """
        return {}

    def __call__(self, input_dict, state=None, seq_lens=None):
        """Call the model with the given input tensors and state.

        This is the method used by RLlib to execute the forward pass. It calls
        forward() internally after unpacking nested observation tensors.

        Custom models should override forward() instead of __call__.

        Arguments:
            input_dict (dict): dictionary of input tensors, including "obs",
                "prev_action", "prev_reward", "is_training"
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension
            seq_lens (Tensor): 1d tensor holding input sequence lengths

        Returns:
            (outputs, state): The model output tensor of size
                [BATCH, output_spec.size] or a list of tensors corresponding to
                output_spec.shape_list, and a list of state tensors of
                [BATCH, state_size_i].

        Raises:
            ValueError: If forward() does not return a (output, state) pair,
                if the output is not a rank-2 tensor of width num_outputs, or
                if the returned state is not a list.
        """
        # Unpack any flattened Dict/Tuple observation back into its original
        # nested structure; keep the flat version available as "obs_flat".
        restored = input_dict.copy()
        restored["obs"] = restore_original_dimensions(
            input_dict["obs"], self.obs_space, self.framework)
        if len(input_dict["obs"].shape) > 2:
            restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
        else:
            # Already flat: reuse the tensor directly.
            restored["obs_flat"] = input_dict["obs"]
        with self.context():
            res = self.forward(restored, state or [], seq_lens)
        if not isinstance(res, (list, tuple)) or len(res) != 2:
            raise ValueError(
                "forward() must return a tuple of (output, state) tensors, "
                "got {}".format(res))
        outputs, state = res

        # Sanity-check that the model emitted a [BATCH, num_outputs] tensor.
        try:
            shape = outputs.shape
        except AttributeError:
            raise ValueError("Output is not a tensor: {}".format(outputs))
        else:
            if len(shape) != 2 or shape[1] != self.num_outputs:
                raise ValueError(
                    "Expected output shape of [None, {}], got {}".format(
                        self.num_outputs, shape))
        if not isinstance(state, list):
            raise ValueError("State output is not a list: {}".format(state))

        self._last_output = outputs
        return outputs, state

    def from_batch(self, train_batch, is_training=True):
        """Convenience function that calls this model with a tensor batch.

        All this does is unpack the tensor batch to call this model with the
        right input dict, state, and seq len arguments.

        Args:
            train_batch (SampleBatch): Batch of experiences to feed in.
            is_training (bool): Whether the forward pass is for training.

        Returns:
            (outputs, state): Same as `__call__`.
        """
        input_dict = {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "is_training": is_training,
        }
        # Only include optional keys when present in the batch.
        if SampleBatch.PREV_ACTIONS in train_batch:
            input_dict["prev_actions"] = train_batch[SampleBatch.PREV_ACTIONS]
        if SampleBatch.PREV_REWARDS in train_batch:
            input_dict["prev_rewards"] = train_batch[SampleBatch.PREV_REWARDS]
        # Collect state_in_0, state_in_1, ... until the first missing index.
        states = []
        i = 0
        while "state_in_{}".format(i) in train_batch:
            states.append(train_batch["state_in_{}".format(i)])
            i += 1
        return self.__call__(input_dict, states, train_batch.get("seq_lens"))

    def import_from_h5(self, h5_file):
        """Imports weights from an h5 file.

        Args:
            h5_file (str): The h5 file name to import weights from.

        Example:
            >>> trainer = MyTrainer()
            >>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
            >>> for _ in range(10):
            >>>     trainer.train()
        """
        raise NotImplementedError

    def last_output(self):
        """Returns the last output returned from calling the model."""
        return self._last_output

    def context(self):
        """Returns a contextmanager for the current forward pass."""
        # Default: no-op scope; frameworks may override to enter e.g. a
        # variable scope or graph context.
        return NullContextManager()

    def variables(self, as_dict=False):
        """Returns the list (or a dict) of variables for this model.

        Args:
            as_dict(bool): Whether variables should be returned as dict-values
                (using descriptive keys).

        Returns:
            Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is
                True) of all variables of this ModelV2.
        """
        raise NotImplementedError

    def trainable_variables(self, as_dict=False):
        """Returns the list of trainable variables for this model.

        Args:
            as_dict(bool): Whether variables should be returned as dict-values
                (using descriptive keys).

        Returns:
            Union[List[any],Dict[str,any]]: The list (or dict if `as_dict` is
                True) of all trainable (tf)/requires_grad (torch) variables
                of this ModelV2.
        """
        raise NotImplementedError
|
|
|
|
|
|
class NullContextManager:
    """A do-nothing context manager.

    Used as the default scope for a model's forward pass when the framework
    needs no special context.
    """

    def __init__(self):
        pass

    def __enter__(self):
        """Enter the (empty) context; nothing to set up."""
        return None

    def __exit__(self, *args):
        """Leave the (empty) context; exceptions propagate unchanged."""
        return None
|
|
|
|
|
|
@DeveloperAPI
|
|
def flatten(obs, framework):
    """Flatten the given tensor to rank 2 ([batch, -1]).

    Args:
        obs: The observation tensor to flatten.
        framework (str): One of "tf" or "torch".

    Returns:
        The flattened tensor.

    Raises:
        NotImplementedError: If `framework` is neither "tf" nor "torch".
    """
    if framework == "torch":
        assert torch is not None
        return torch.flatten(obs, start_dim=1)
    elif framework == "tf":
        return tf.layers.flatten(obs)
    raise NotImplementedError("flatten", framework)
|
|
|
|
|
|
@DeveloperAPI
|
|
def restore_original_dimensions(obs, obs_space, tensorlib="tf"):
    """Unpacks Dict and Tuple space observations into their original form.

    This is needed since we flatten Dict and Tuple observations in transit.
    Before sending them to the model though, we should unflatten them into
    Dicts or Tuples of tensors.

    Arguments:
        obs: The flattened observation tensor.
        obs_space: The flattened obs space. If this has the `original_space`
            attribute, we will unflatten the tensor to that shape.
        tensorlib: The library used to unflatten (reshape) the array/tensor.
            Either the string "tf"/"torch", or a module exposing `reshape`.
            The default is the string "tf" (rather than the tf module itself)
            so that this signature does not bind tf at import time; the
            string is resolved to the module lazily below.

    Returns:
        single tensor or dict / tuple of tensors matching the original
        observation space.
    """
    if hasattr(obs_space, "original_space"):
        # Resolve framework strings to the actual tensor library lazily,
        # only when unflattening is actually needed.
        if tensorlib == "tf":
            tensorlib = tf
        elif tensorlib == "torch":
            assert torch is not None
            tensorlib = torch
        return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib)
    else:
        # Primitive space: the tensor was never flattened.
        return obs
|
|
|
|
|
|
# Cache of preprocessors keyed by id() of the space object, so that callers
# invoking _unpack_obs() repeatedly on the same space do not rebuild a
# preprocessor each time.
_cache = {}
|
|
|
|
|
|
def _unpack_obs(obs, space, tensorlib=tf):
    """Unpack a flattened Dict or Tuple observation array/tensor.

    Arguments:
        obs: The flattened observation tensor
        space: The original space prior to flattening
        tensorlib: The library used to unflatten (reshape) the array/tensor
    """
    # Primitive spaces were never flattened, so there is nothing to unpack.
    if not isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
        return obs

    # Fetch (or build and best-effort cache) the preprocessor that describes
    # the flattened layout of this space.
    # NOTE(review): keying on id() could alias a new space object with a
    # previously collected one whose id was reused -- confirm acceptable.
    prep = _cache.get(id(space))
    if prep is None:
        prep = get_preprocessor(space)(space)
        # Only cache while there is room left (bounded cache).
        if len(_cache) < 999:
            _cache[id(space)] = prep

    # The input must be a [batch, flat_size] tensor matching the space.
    if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]:
        raise ValueError(
            "Expected flattened obs shape of [None, {}], got {}".format(
                prep.shape[0], obs.shape))
    assert len(prep.preprocessors) == len(space.spaces), \
        (len(prep.preprocessors) == len(space.spaces))

    # Walk the flat tensor left to right, slicing out each child space's
    # segment and recursively unpacking it.
    offset = 0
    if isinstance(space, gym.spaces.Tuple):
        unpacked = []
        for child_prep, child_space in zip(prep.preprocessors, space.spaces):
            segment = obs[:, offset:offset + child_prep.size]
            offset += child_prep.size
            unpacked.append(
                _unpack_obs(
                    tensorlib.reshape(segment,
                                      [-1] + list(child_prep.shape)),
                    child_space,
                    tensorlib=tensorlib))
    else:
        unpacked = OrderedDict()
        for child_prep, (key, child_space) in zip(prep.preprocessors,
                                                  space.spaces.items()):
            segment = obs[:, offset:offset + child_prep.size]
            offset += child_prep.size
            unpacked[key] = _unpack_obs(
                tensorlib.reshape(segment, [-1] + list(child_prep.shape)),
                child_space,
                tensorlib=tensorlib)
    return unpacked
|