from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.model import restore_original_dimensions, flatten
from ray.rllib.utils.annotations import PublicAPI


@PublicAPI
class ModelV2:
    """Defines a Keras-style abstract network model for use with RLlib.

    Custom models should extend either TFModelV2 or TorchModelV2 instead of
    this class directly.

    Data flow:
        obs -> forward() -> model_out
        value_function() -> V(s)

    Attributes:
        obs_space (Space): observation space of the target gym env. This
            may have an `original_space` attribute that specifies how to
            unflatten the tensor into a ragged tensor.
        action_space (Space): action space of the target gym env
        num_outputs (int): number of output units of the model
        model_config (dict): config for the model, documented in ModelCatalog
        name (str): name (scope) for the model
        framework (str): either "tf" or "torch"
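
    A minimal subclass might look like the following sketch (``MyTFModel``
    and its layer attributes are hypothetical, not part of RLlib):

    Examples:
        >>> class MyTFModel(TFModelV2):
        >>>     def __init__(self, obs_space, action_space, num_outputs,
        >>>                  model_config, name):
        >>>         super().__init__(obs_space, action_space, num_outputs,
        >>>                          model_config, name)
        >>>         # ... create layers here ...
        >>>
        >>>     def forward(self, input_dict, state, seq_lens):
        >>>         model_out = ...  # [BATCH, num_outputs] tensor
        >>>         return model_out, state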
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, framework):
        """Initialize the model.

        This method should create any variables used by the model.
        """

        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.model_config = model_config
        self.name = name or "default_model"
        self.framework = framework
        self._last_output = None

    def get_initial_state(self):
        """Get the initial recurrent state values for the model.

        Returns:
            List[np.ndarray]: List of np.array objects containing the initial
                hidden state of an RNN, if applicable.

        Examples:
            >>> def get_initial_state(self):
            >>>     return [
            >>>         np.zeros(self.cell_size, np.float32),
            >>>         np.zeros(self.cell_size, np.float32),
            >>>     ]
        """
        return []

    def forward(self, input_dict, state, seq_lens):
        """Call the model with the given input tensors and state.

        Any complex observations (dicts, tuples, etc.) will be unpacked by
        __call__ before being passed to forward(). To access the flattened
        observation tensor, refer to input_dict["obs_flat"].

        This method can be called any number of times. In eager execution,
        each call to forward() will eagerly evaluate the model. In symbolic
        execution, each call to forward() creates a computation graph that
        operates over the variables of this model (i.e., shares weights).

        Custom models should override this instead of __call__.

        Arguments:
            input_dict (dict): dictionary of input tensors, including "obs",
                "obs_flat", "prev_actions", "prev_rewards", "is_training"
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension
            seq_lens (Tensor): 1d tensor holding input sequence lengths

        Returns:
            (outputs, state): The model output tensor of shape
                [BATCH, num_outputs], and a list of new RNN state tensors.
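
        A typical override is sketched below (``self.my_layers`` is a
        hypothetical layer stack, not an RLlib attribute):

        Examples:
            >>> def forward(self, input_dict, state, seq_lens):
            >>>     model_out = self.my_layers(input_dict["obs_flat"])
            >>>     return model_out, state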
        """
        raise NotImplementedError

    def value_function(self):
        """Return the value function estimate for the most recent forward pass.

        Returns:
            value estimate tensor of shape [BATCH].
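
        A common pattern is sketched below, assuming forward() stored its
        value branch output in ``self._value_out`` (a hypothetical name):

        Examples:
            >>> def value_function(self):
            >>>     # Reshape the [BATCH, 1] value output to [BATCH].
            >>>     return tf.reshape(self._value_out, [-1])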
        """
        raise NotImplementedError

    def custom_loss(self, policy_loss, loss_inputs):
        """Override to customize the loss function used to optimize this model.

        This can be used to incorporate self-supervised losses (by defining
        a loss over existing input and output tensors of this model), and
        supervised losses (by defining losses over a variable-sharing copy of
        this model's layers).

        You can find a runnable example in examples/custom_loss.py.

        Arguments:
            policy_loss (Tensor): scalar policy loss from the policy.
            loss_inputs (dict): map of input placeholders for rollout data.

        Returns:
            Scalar tensor for the customized loss for this model.
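
        A minimal sketch: add a weighted auxiliary term to the policy loss
        (``self.aux_loss`` is a hypothetical scalar computed in forward()):

        Examples:
            >>> def custom_loss(self, policy_loss, loss_inputs):
            >>>     return policy_loss + 0.1 * self.aux_loss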
        """
        return policy_loss

    def metrics(self):
        """Override to return custom metrics from your model.

        The stats will be reported as part of the learner stats, i.e.,
            info:
                learner:
                    model:
                        key1: metric1
                        key2: metric2

        Returns:
            Dict of string keys to scalar tensors.
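
        A minimal sketch (``self.aux_loss`` is a hypothetical attribute):

        Examples:
            >>> def metrics(self):
            >>>     return {"aux_loss": self.aux_loss}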
        """
        return {}

    def __call__(self, input_dict, state=None, seq_lens=None):
        """Call the model with the given input tensors and state.

        This is the method used by RLlib to execute the forward pass. It calls
        forward() internally after unpacking nested observation tensors.

        Custom models should override forward() instead of __call__.

        Arguments:
            input_dict (dict): dictionary of input tensors, including "obs",
                "prev_actions", "prev_rewards", "is_training"
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension
            seq_lens (Tensor): 1d tensor holding input sequence lengths

        Returns:
            (outputs, state): The model output tensor of shape
                [BATCH, num_outputs], and a list of state tensors each of
                shape [BATCH, state_size_i].
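
        A hedged usage sketch (``flat_obs`` stands in for a real observation
        batch): for a Dict observation space, __call__ restores the nested
        structure before calling forward(), so forward() sees e.g. a dict of
        tensors under input_dict["obs"]:

        Examples:
            >>> outputs, state_out = model({"obs": flat_obs}, [], None)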
        """

        # Restore any nested (dict/tuple) observation structure before
        # handing the inputs to forward().
        restored = input_dict.copy()
        restored["obs"] = restore_original_dimensions(
            input_dict["obs"], self.obs_space, self.framework)
        if len(input_dict["obs"].shape) > 2:
            restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
        else:
            restored["obs_flat"] = input_dict["obs"]
        with self.context():
            res = self.forward(restored, state or [], seq_lens)
        if not isinstance(res, (list, tuple)) or len(res) != 2:
            raise ValueError(
                "forward() must return a tuple of (output, state) tensors, "
                "got {}".format(res))
        outputs, state = res

        # Sanity-check the output shape against num_outputs.
        try:
            shape = outputs.shape
        except AttributeError:
            raise ValueError("Output is not a tensor: {}".format(outputs))
        else:
            if len(shape) != 2 or shape[1] != self.num_outputs:
                raise ValueError(
                    "Expected output shape of [None, {}], got {}".format(
                        self.num_outputs, shape))
        if not isinstance(state, list):
            raise ValueError("State output is not a list: {}".format(state))

        self._last_output = outputs
        return outputs, state

    def from_batch(self, train_batch, is_training=True):
        """Convenience function that calls this model with a tensor batch.

        All this does is unpack the tensor batch to call this model with the
        right input dict, state, and seq len arguments.
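
        A hedged usage sketch (``batch`` is assumed to be a SampleBatch
        collected from a rollout):

        Examples:
            >>> model_out, state_outs = model.from_batch(batch)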
        """

        input_dict = {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "is_training": is_training,
        }
        if SampleBatch.PREV_ACTIONS in train_batch:
            input_dict["prev_actions"] = train_batch[SampleBatch.PREV_ACTIONS]
        if SampleBatch.PREV_REWARDS in train_batch:
            input_dict["prev_rewards"] = train_batch[SampleBatch.PREV_REWARDS]
        # Gather all RNN state inputs present in the batch.
        states = []
        i = 0
        while "state_in_{}".format(i) in train_batch:
            states.append(train_batch["state_in_{}".format(i)])
            i += 1
        return self.__call__(input_dict, states, train_batch.get("seq_lens"))

    def import_from_h5(self, h5_file):
        """Imports weights from an h5 file.

        Args:
            h5_file (str): The h5 file name to import weights from.

        Example:
            >>> trainer = MyTrainer()
            >>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
            >>> for _ in range(10):
            >>>     trainer.train()
        """
        raise NotImplementedError

    def last_output(self):
        """Returns the last output returned from calling the model."""
        return self._last_output

    def context(self):
        """Returns a contextmanager for the current forward pass."""
        return NullContextManager()


class NullContextManager:
    """No-op context manager."""

    def __init__(self):
        pass

    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass