2017-07-17 01:58:54 -07:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
2018-10-20 15:21:22 -07:00
|
|
|
from collections import OrderedDict
|
2019-07-27 02:08:16 -07:00
|
|
|
import logging
|
2018-10-20 15:21:22 -07:00
|
|
|
import gym
|
2017-08-24 12:43:51 -07:00
|
|
|
|
2019-07-27 02:08:16 -07:00
|
|
|
from ray.rllib.models.tf.misc import linear, normc_initializer
|
2018-10-20 15:21:22 -07:00
|
|
|
from ray.rllib.models.preprocessors import get_preprocessor
|
2019-02-24 15:36:13 -08:00
|
|
|
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
2019-05-10 20:36:18 -07:00
|
|
|
from ray.rllib.utils import try_import_tf
|
|
|
|
|
|
|
|
tf = try_import_tf()
|
2019-07-27 02:08:16 -07:00
|
|
|
logger = logging.getLogger(__name__)
|
2018-10-20 15:21:22 -07:00
|
|
|
|
2017-07-17 01:58:54 -07:00
|
|
|
|
|
|
|
class Model(object):
    """This class is deprecated, please use TFModelV2 instead.

    Legacy TF model base class. Subclasses override ``_build_layers_v2``
    (preferred) or the older ``_build_layers``; the constructor calls one of
    them to create ``self.outputs`` and ``self.last_layer``.
    """

    def __init__(self,
                 input_dict,
                 obs_space,
                 action_space,
                 num_outputs,
                 options,
                 state_in=None,
                 seq_lens=None):
        """Records the model inputs and builds the output layers.

        Arguments:
            input_dict (dict): Dictionary of input tensors; must contain
                "obs".
            obs_space: Observation space. "obs" is passed through
                ``restore_original_dimensions`` with this space before being
                handed to ``_build_layers_v2``.
            action_space: Action space.
            num_outputs (int): Width of ``self.outputs``. When
                ``options["free_log_std"]`` is truthy this must be even: the
                network builds the first half and a trainable ``log_std``
                variable (shared across the batch) fills the second half.
            options (dict): Model options.
            state_in (list): RNN state input tensors, if any.
            seq_lens (Tensor): RNN sequence lengths; an int32 placeholder
                named "seq_lens" is created when this is None.
        """
        assert isinstance(input_dict, dict), input_dict

        # Default attribute values for the non-RNN case
        self.state_init = []
        self.state_in = state_in or []
        self.state_out = []
        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.options = options
        self.scope = tf.get_variable_scope()
        self.session = tf.get_default_session()
        self.input_dict = input_dict
        if seq_lens is not None:
            self.seq_lens = seq_lens
        else:
            self.seq_lens = tf.placeholder(
                dtype=tf.int32, shape=[None], name="seq_lens")

        # Remember the full requested width for _validate_output_shape(),
        # since the local num_outputs may be halved below.
        self._num_outputs = num_outputs
        # Read the option once so that halving the network width and
        # appending log_std below are always decided consistently.
        free_log_std = options.get("free_log_std", False)
        if free_log_std:
            assert num_outputs % 2 == 0
            num_outputs = num_outputs // 2

        # In TF 1.14, you cannot construct variable scopes in exception
        # handlers so we have to set the OK flag and check it here:
        ok = True
        try:
            restored = input_dict.copy()
            restored["obs"] = restore_original_dimensions(
                input_dict["obs"], obs_space)
            self.outputs, self.last_layer = self._build_layers_v2(
                restored, num_outputs, options)
        except NotImplementedError:
            ok = False
        if not ok:
            # Legacy fallback: _build_layers only receives the raw "obs".
            self.outputs, self.last_layer = self._build_layers(
                input_dict["obs"], num_outputs, options)

        if free_log_std:
            log_std = tf.get_variable(
                name="log_std",
                shape=[num_outputs],
                initializer=tf.zeros_initializer)
            # Adding log_std to 0.0 * outputs broadcasts the [num_outputs]
            # variable to every row of the [BATCH_SIZE, num_outputs] output.
            self.outputs = tf.concat(
                [self.outputs, 0.0 * self.outputs + log_std], 1)

    def _build_layers(self, inputs, num_outputs, options):
        """Builds and returns the output and last layer of the network.

        Deprecated: use _build_layers_v2 instead, which has better support
        for dict and tuple spaces.
        """
        raise NotImplementedError

    @PublicAPI
    def _build_layers_v2(self, input_dict, num_outputs, options):
        """Define the layers of a custom model.

        Arguments:
            input_dict (dict): Dictionary of input tensors, including "obs",
                "prev_actions", "prev_rewards", "is_training".
            num_outputs (int): Output tensor must be of size
                [BATCH_SIZE, num_outputs].
            options (dict): Model options.

        Returns:
            (outputs, feature_layer): Tensors of size [BATCH_SIZE, num_outputs]
                and [BATCH_SIZE, desired_feature_size].

        When using dict or tuple observation spaces, you can access
        the nested sub-observation batches here as well:

        Examples:
            >>> print(input_dict)
            {'prev_actions': <tf.Tensor shape=(?,) dtype=int64>,
             'prev_rewards': <tf.Tensor shape=(?,) dtype=float32>,
             'is_training': <tf.Tensor shape=(), dtype=bool>,
             'obs': OrderedDict([
                ('sensors', OrderedDict([
                    ('front_cam', [
                        <tf.Tensor shape=(?, 10, 10, 3) dtype=float32>,
                        <tf.Tensor shape=(?, 10, 10, 3) dtype=float32>]),
                    ('position', <tf.Tensor shape=(?, 3) dtype=float32>),
                    ('velocity', <tf.Tensor shape=(?, 3) dtype=float32>)]))])}
        """
        raise NotImplementedError

    @PublicAPI
    def value_function(self):
        """Builds the value function output.

        This method can be overridden to customize the implementation of the
        value function (e.g., not sharing hidden layers). By default a single
        linear unit named "value" is stacked on top of ``self.last_layer``.

        Returns:
            Tensor of size [BATCH_SIZE] for the value function.
        """
        return tf.reshape(
            linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])

    @PublicAPI
    def custom_loss(self, policy_loss, loss_inputs):
        """Override to customize the loss function used to optimize this model.

        This can be used to incorporate self-supervised losses (by defining
        a loss over existing input and output tensors of this model), and
        supervised losses (by defining losses over a variable-sharing copy of
        this model's layers).

        You can find a runnable example in examples/custom_loss.py.

        Arguments:
            policy_loss (Tensor): scalar policy loss from the policy.
            loss_inputs (dict): map of input placeholders for rollout data.

        Returns:
            Scalar tensor for the customized loss for this model.

        Raises:
            DeprecationWarning: if a subclass still overrides the deprecated
                ``loss()`` method to return a non-None value.
        """
        if self.loss() is not None:
            raise DeprecationWarning(
                "self.loss() is deprecated, use self.custom_loss() instead.")
        return policy_loss

    @PublicAPI
    def custom_stats(self):
        """Override to return custom metrics from your model.

        The stats will be reported as part of the learner stats, i.e.,
            info:
                learner:
                    model:
                        key1: metric1
                        key2: metric2

        Returns:
            Dict of string keys to scalar tensors.
        """
        return {}

    def loss(self):
        """Deprecated: use self.custom_loss()."""
        return None

    @classmethod
    def get_initial_state(cls, obs_space, action_space, num_outputs, options):
        """Not implemented here; recurrent custom models must override it.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError(
            "In order to use recurrent models with ModelV2, you should define "
            "the get_initial_state @classmethod on your custom model class.")

    def _validate_output_shape(self):
        """Checks that the model has the correct number of outputs.

        Raises:
            ValueError: if ``self.outputs`` cannot be converted to a tensor,
                or its static shape is not [None, num_outputs] where
                num_outputs is the width requested at construction time.
        """
        try:
            out = tf.convert_to_tensor(self.outputs)
            shape = out.shape.as_list()
        except Exception:
            raise ValueError("Output is not a tensor: {}".format(self.outputs))
        else:
            if len(shape) != 2 or shape[1] != self._num_outputs:
                raise ValueError(
                    "Expected output shape of [None, {}], got {}".format(
                        self._num_outputs, shape))
|
|
|
|
|
2018-10-20 15:21:22 -07:00
|
|
|
|
2019-09-19 12:10:31 -07:00
|
|
|
@DeveloperAPI
def flatten(obs, framework):
    """Flatten the given tensor.

    Collapses every dimension after the first (batch) one, using the
    backend named by ``framework``.

    Arguments:
        obs: The tensor to flatten.
        framework (str): "tf" (uses ``tf.layers.flatten``) or "torch"
            (uses ``torch.flatten(obs, start_dim=1)``).

    Returns:
        The flattened tensor.

    Raises:
        NotImplementedError: for any other ``framework`` value.
    """
    if framework not in ("tf", "torch"):
        raise NotImplementedError("flatten", framework)

    if framework == "torch":
        # Imported lazily so TF-only installs never need torch.
        import torch
        return torch.flatten(obs, start_dim=1)

    return tf.layers.flatten(obs)
|
|
|
|
|
|
|
|
|
2019-02-24 15:36:13 -08:00
|
|
|
@DeveloperAPI
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
    """Unpacks Dict and Tuple space observations into their original form.

    This is needed since we flatten Dict and Tuple observations in transit.
    Before sending them to the model though, we should unflatten them into
    Dicts or Tuples of tensors.

    Arguments:
        obs: The flattened observation tensor.
        obs_space: The flattened obs space. If this has the `original_space`
            attribute, we will unflatten the tensor to that shape.
        tensorlib: The library used to unflatten (reshape) the array/tensor.
            Either a module providing ``reshape`` (e.g. ``tf``, ``torch`` or
            ``numpy``) or one of the names "tf", "torch" or "numpy".

    Returns:
        ``obs`` unchanged if ``obs_space`` has no ``original_space``;
        otherwise a single tensor, or an OrderedDict (Dict spaces) / list
        (Tuple spaces) of tensors matching the original observation space.

    Raises:
        ValueError: if ``tensorlib`` is a string naming an unknown library.
    """
    if not hasattr(obs_space, "original_space"):
        return obs

    # Resolve string aliases to the corresponding module; modules pass
    # through untouched.
    if tensorlib == "tf":
        tensorlib = tf
    elif tensorlib == "torch":
        import torch
        tensorlib = torch
    elif tensorlib == "numpy":
        import numpy as np
        tensorlib = np
    elif isinstance(tensorlib, str):
        # Previously an unknown name surfaced deep in _unpack_obs as
        # "'str' object has no attribute 'reshape'".
        raise ValueError(
            "Unsupported tensorlib {!r}: expected 'tf', 'torch', 'numpy' or "
            "a module providing reshape().".format(tensorlib))

    return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib)
|
2018-10-20 15:21:22 -07:00
|
|
|
|
|
|
|
|
2019-11-21 15:55:56 -08:00
|
|
|
# Module-level memo of preprocessors built by _unpack_obs, keyed by the id()
# of the space they were built for, so callers that unpack observations
# frequently do not rebuild the same preprocessor every time.
_cache = dict()
|
|
|
|
|
|
|
|
|
2018-12-18 10:40:01 -08:00
|
|
|
def _get_cached_preprocessor(space):
    """Return a preprocessor instance for ``space``, memoized in ``_cache``.

    Entries are stored as ``(space, preprocessor)``. Holding a reference to
    the space keeps its ``id()`` from being recycled by a different object
    while the entry exists; the identity check also rejects any stale entry.
    The cache stops accepting new ids once it holds 999 entries.
    """
    key = id(space)
    entry = _cache.get(key)
    if entry is not None and entry[0] is space:
        return entry[1]
    prep = get_preprocessor(space)(space)
    # Make an attempt to cache the result, if enough space left (replacing
    # an entry stored under the same id never grows the cache).
    if key in _cache or len(_cache) < 999:
        _cache[key] = (space, prep)
    return prep


def _unpack_obs(obs, space, tensorlib=tf):
    """Unpack a flattened Dict or Tuple observation array/tensor.

    Arguments:
        obs: The flattened observation tensor, of shape [BATCH, flat_size].
        space: The original space prior to flattening.
        tensorlib: The library used to unflatten (reshape) the array/tensor.

    Returns:
        For a Tuple space, a list of unpacked sub-observations; for a Dict
        space, an OrderedDict with the same keys as ``space.spaces``; for any
        other space, ``obs`` unchanged. Nested Dict/Tuple sub-spaces are
        unpacked recursively.

    Raises:
        ValueError: if ``obs`` is not 2-D with the flattened width expected
            by the space's preprocessor.
    """
    if not isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
        return obs

    prep = _get_cached_preprocessor(space)
    if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]:
        raise ValueError(
            "Expected flattened obs shape of [None, {}], got {}".format(
                prep.shape[0], obs.shape))
    assert len(prep.preprocessors) == len(space.spaces), \
        "Preprocessor count {} does not match sub-space count {}".format(
            len(prep.preprocessors), len(space.spaces))

    is_tuple = isinstance(space, gym.spaces.Tuple)
    subspaces = space.spaces if is_tuple else list(space.spaces.values())

    unpacked = []
    offset = 0
    for p, subspace in zip(prep.preprocessors, subspaces):
        # Each sub-preprocessor owns the next p.size columns of the flat obs.
        obs_slice = obs[:, offset:offset + p.size]
        offset += p.size
        unpacked.append(
            _unpack_obs(
                tensorlib.reshape(obs_slice, [-1] + list(p.shape)),
                subspace,
                tensorlib=tensorlib))

    if is_tuple:
        return unpacked
    return OrderedDict(zip(space.spaces.keys(), unpacked))
|