from functools import partial
import gym
import logging
import numpy as np

from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
    RLLIB_ACTION_DIST, _global_registry
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.tf.lstm_v1 import LSTM
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
from ray.rllib.models.tf.recurrent_net import LSTMWrapper
from ray.rllib.models.tf.tf_action_dist import Categorical, \
    Deterministic, DiagGaussian, Dirichlet, \
    MultiActionDistribution, MultiCategorical
from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
    TorchDeterministic, TorchDiagGaussian, \
    TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space

tf = try_import_tf()
tree = try_import_tree()

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
MODEL_DEFAULTS = {
    # === Built-in options ===
    # Filter config. List of [out_channels, kernel, stride] for each filter.
    "conv_filters": None,
    # Nonlinearity for built-in convnet.
    "conv_activation": "relu",
    # Nonlinearity for fully connected net (tanh, relu).
    "fcnet_activation": "tanh",
    # Sizes of the hidden layers of the fully connected net.
    "fcnet_hiddens": [256, 256],
    # For DiagGaussian action distributions, make the second half of the model
    # outputs floating bias variables instead of state-dependent ones. This
    # only has an effect if using the default fully connected net.
    "free_log_std": False,
    # Whether to skip the final linear layer used to resize the hidden layer
    # outputs to size `num_outputs`. If True, then the last hidden layer
    # should already match num_outputs.
    "no_final_linear": False,
    # Whether layers should be shared for the value function.
    "vf_share_layers": True,

    # == LSTM ==
    # Whether to wrap the model with an LSTM.
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20.
    "max_seq_len": 20,
    # Size of the LSTM cell.
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1}, r_{t-1} to LSTM.
    "lstm_use_prev_action_reward": False,
    # When using modelv1 models with a modelv2 algorithm, you may have to
    # define the state shape here (e.g., [256, 256]).
    "state_shape": None,

    # == Atari ==
    # Whether to enable framestack for Atari envs.
    "framestack": True,
    # Final resized frame dimension.
    "dim": 84,
    # (deprecated) Converts ATARI frames to 1-channel grayscale images.
    "grayscale": False,
    # (deprecated) Rescales frames to the range [-1, 1] if True.
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom model to use.
    "custom_model": None,
    # Extra options to pass to the custom classes. These will be available in
    # the Model's constructor (custom ModelV2 classes receive them as kwargs).
    "custom_model_config": {},
    # Name of a custom action distribution to use.
    "custom_action_dist": None,
    # Custom preprocessors are deprecated. Please use a wrapper class around
    # your environment instead to preprocess observations.
    "custom_preprocessor": None,

    # Deprecated config keys.
    "custom_options": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
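
# Example (an illustrative sketch): MODEL_DEFAULTS entries are normally
# overridden through the trainer's "model" config dict rather than edited
# here, e.g.:
#
#     config["model"] = {
#         "fcnet_hiddens": [64, 64],
#         "use_lstm": True,
#     }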


@PublicAPI
class ModelCatalog:
    """Registry of models, preprocessors, and action distributions for envs.

    Examples:
        >>> prep = ModelCatalog.get_preprocessor(env)
        >>> observation = prep.transform(raw_observation)

        >>> dist_class, dist_dim = ModelCatalog.get_action_dist(
        ...     env.action_space, {})
        >>> model = ModelCatalog.get_model_v2(
        ...     obs_space, action_space, num_outputs, options)
        >>> dist = dist_class(model.outputs, model)
        >>> action = dist.sample()
    """

    @staticmethod
    @DeveloperAPI
    def get_action_dist(action_space,
                        config,
                        dist_type=None,
                        framework="tf",
                        **kwargs):
        """Returns a distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            config (Optional[dict]): Optional model config.
            dist_type (Optional[Union[str, type]]): Identifier of the action
                distribution (e.g. "deterministic") or an ActionDistribution
                subclass, interpreted as a hint.
            framework (str): One of "tf", "tfe", or "torch".
            kwargs (dict): Optional kwargs to pass on to the Distribution's
                constructor.

        Returns:
            Tuple:
                - dist_class (ActionDistribution): Python class of the
                    distribution.
                - dist_dim (int): The size of the input vector to the
                    distribution.
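
        Examples:
            >>> # An illustrative sketch: `env` is assumed to be a gym env
            >>> # with a Discrete or Box action space; an empty config falls
            >>> # back to MODEL_DEFAULTS.
            >>> dist_class, dist_dim = ModelCatalog.get_action_dist(
            ...     env.action_space, {}, framework="tf")
            >>> dist = dist_class(tf.zeros([1, dist_dim]), None)
            >>> action = dist.sample()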
        """
        dist = None
        config = config or MODEL_DEFAULTS
        # Custom distribution given.
        if config.get("custom_action_dist"):
            action_dist_name = config["custom_action_dist"]
            logger.debug(
                "Using custom action distribution {}".format(action_dist_name))
            dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
        # Dist_type is given directly as a class.
        elif type(dist_type) is type and \
                issubclass(dist_type, ActionDistribution) and \
                dist_type not in (
                    MultiActionDistribution, TorchMultiActionDistribution):
            dist = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, gym.spaces.Box):
            if len(action_space.shape) > 1:
                raise UnsupportedSpaceException(
                    "Action space has multiple dimensions "
                    "{}. ".format(action_space.shape) +
                    "Consider reshaping this into a single dimension, "
                    "using a custom action distribution, "
                    "using a Tuple action space, or the multi-agent API.")
            # TODO(sven): Check for bounds and return SquashedNormal, etc..
            if dist_type is None:
                dist = TorchDiagGaussian if framework == "torch" \
                    else DiagGaussian
            elif dist_type == "deterministic":
                dist = TorchDeterministic if framework == "torch" \
                    else Deterministic
        # Discrete Space -> Categorical.
        elif isinstance(action_space, gym.spaces.Discrete):
            dist = TorchCategorical if framework == "torch" else Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif dist_type in (MultiActionDistribution,
                           TorchMultiActionDistribution) or \
                isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(
                    s, config, framework=framework), flat_action_space)
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return partial(
                (TorchMultiActionDistribution
                 if framework == "torch" else MultiActionDistribution),
                action_space=action_space,
                child_distributions=child_dists,
                input_lens=input_lens), int(sum(input_lens))
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch.")
            dist = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            dist = TorchMultiCategorical if framework == "torch" else \
                MultiCategorical
            return partial(dist, input_lens=action_space.nvec), \
                int(sum(action_space.nvec))
        # Unknown type -> Error.
        else:
            raise NotImplementedError("Unsupported args: {} {}".format(
                action_space, dist_type))

        return dist, dist.required_model_output_shape(action_space, config)

    @staticmethod
    @DeveloperAPI
    def get_action_shape(action_space):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space (Space): Action space of the target gym env.

        Returns:
            (dtype, shape): Dtype and shape of the actions tensor.
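
        Examples:
            >>> # An illustrative sketch: a Discrete(4) space yields int64
            >>> # action ids with shape (None,), i.e. a 1D batch.
            >>> dtype, shape = ModelCatalog.get_action_shape(
            ...     gym.spaces.Discrete(4))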
        """
        if isinstance(action_space, gym.spaces.Discrete):
            return (tf.int64, (None, ))
        elif isinstance(action_space, (gym.spaces.Box, Simplex)):
            return (tf.float32, (None, ) + action_space.shape)
        elif isinstance(action_space, gym.spaces.MultiDiscrete):
            return (tf.as_dtype(action_space.dtype),
                    (None, ) + action_space.shape)
        elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
            flat_action_space = flatten_space(action_space)
            size = 0
            all_discrete = True
            for i in range(len(flat_action_space)):
                if isinstance(flat_action_space[i], gym.spaces.Discrete):
                    size += 1
                else:
                    all_discrete = False
                    size += np.product(flat_action_space[i].shape)
            size = int(size)
            return (tf.int64 if all_discrete else tf.float32, (None, size))
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space))

    @staticmethod
    @DeveloperAPI
    def get_action_placeholder(action_space, name="action"):
        """Returns an action placeholder consistent with the action space.

        Args:
            action_space (Space): Action space of the target gym env.
            name (str): An optional string to name the placeholder by.
                Default: "action".

        Returns:
            action_placeholder (Tensor): A placeholder for the actions.
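
        Examples:
            >>> # An illustrative sketch (tf1 graph mode): build a feedable
            >>> # input for a 3-dim continuous action space.
            >>> actions = ModelCatalog.get_action_placeholder(
            ...     gym.spaces.Box(-1.0, 1.0, shape=(3, )), name="actions_in")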
        """
        dtype, shape = ModelCatalog.get_action_shape(action_space)

        return tf.placeholder(dtype, shape=shape, name=name)

    @staticmethod
    @DeveloperAPI
    def get_model_v2(obs_space,
                     action_space,
                     num_outputs,
                     model_config,
                     framework="tf",
                     name="default_model",
                     model_interface=None,
                     default_model=None,
                     **model_kwargs):
        """Returns a suitable model compatible with given spaces and output.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (dict): The "model" config dict to use (see
                MODEL_DEFAULTS for the available keys).
            framework (str): One of "tf", "tfe", or "torch".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model.
            default_model (cls): Override the default class for the model.
                This only has an effect when not using a custom model.
            model_kwargs (dict): Args to pass to the ModelV2 constructor.

        Returns:
            model (ModelV2): Model to use for the policy.
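
        Examples:
            >>> # An illustrative sketch: the spaces come from a gym env and
            >>> # 256 outputs is an arbitrary choice for this example.
            >>> model = ModelCatalog.get_model_v2(
            ...     env.observation_space, env.action_space, 256,
            ...     MODEL_DEFAULTS, framework="tf")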
        """

        if model_config.get("custom_model"):

            if "custom_options" in model_config and \
                    model_config["custom_options"] != DEPRECATED_VALUE:
                deprecation_warning(
                    "model.custom_options",
                    "model.custom_model_config",
                    error=False)
                model_config["custom_model_config"] = \
                    model_config.pop("custom_options")

            if isinstance(model_config["custom_model"], type):
                model_cls = model_config["custom_model"]
            else:
                model_cls = _global_registry.get(RLLIB_MODEL,
                                                 model_config["custom_model"])

            # TODO(sven): Hard-deprecate Model(V1).
            if issubclass(model_cls, ModelV2):
                logger.info("Wrapping {} as {}".format(model_cls,
                                                       model_interface))
                model_cls = ModelCatalog._wrap_if_needed(
                    model_cls, model_interface)

                if framework in ["tf", "tfe"]:
                    # Track and warn if vars were created but not registered.
                    created = set()

                    def track_var_creation(next_creator, **kw):
                        v = next_creator(**kw)
                        created.add(v)
                        return v

                    with tf.variable_creator_scope(track_var_creation):
                        # Try calling with kwargs first (custom ModelV2 should
                        # accept these as kwargs, not get them from
                        # config["custom_model_config"] anymore).
                        try:
                            instance = model_cls(obs_space, action_space,
                                                 num_outputs, model_config,
                                                 name, **model_kwargs)
                        except TypeError as e:
                            # Keyword error: Try old way w/o kwargs.
                            if "__init__() got an unexpected " in e.args[0]:
                                logger.warning(
                                    "Custom ModelV2 should accept all custom "
                                    "options as **kwargs, instead of expecting"
                                    " them in config['custom_model_config']!")
                                instance = model_cls(obs_space, action_space,
                                                     num_outputs, model_config,
                                                     name)
                            # Other error -> re-raise.
                            else:
                                raise e
                    registered = set(instance.variables())
                    not_registered = set()
                    for var in created:
                        if var not in registered:
                            not_registered.add(var)
                    if not_registered:
                        raise ValueError(
                            "It looks like variables {} were created as part "
                            "of {} but do not appear in model.variables() "
                            "({}). Did you forget to call "
                            "model.register_variables() on the variables in "
                            "question?".format(not_registered, instance,
                                               registered))
                else:
                    # PyTorch automatically tracks nn.Modules inside the parent
                    # nn.Module's constructor.
                    # TODO(sven): Do this for TF as well.
                    instance = model_cls(obs_space, action_space, num_outputs,
                                         model_config, name, **model_kwargs)
                return instance
            # TODO(sven): Hard-deprecate Model(V1). This check will be
            # superfluous then.
            elif tf.executing_eagerly():
                raise ValueError(
                    "Eager execution requires a TFModelV2 model to be "
                    "used, however you specified a custom model {}".format(
                        model_cls))

        if framework in ["tf", "tfe"]:
            v2_class = None
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)

                if model_config.get("use_lstm"):
                    wrapped_cls = v2_class
                    forward = wrapped_cls.forward
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, LSTMWrapper)
                    v2_class._wrapped_forward = forward

            # Fall back to a default v1 model.
            if v2_class is None:
                if tf.executing_eagerly():
                    raise ValueError(
                        "Eager execution requires a TFModelV2 model to be "
                        "used, however there is no default V2 model for this "
                        "observation space: {}, use_lstm={}".format(
                            obs_space, model_config.get("use_lstm")))
                v2_class = make_v1_wrapper(ModelCatalog.get_model)
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        elif framework == "torch":
            v2_class = \
                default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)
            if model_config.get("use_lstm"):
                from ray.rllib.models.torch.recurrent_net import LSTMWrapper \
                    as TorchLSTMWrapper
                wrapped_cls = v2_class
                forward = wrapped_cls.forward
                v2_class = ModelCatalog._wrap_if_needed(
                    wrapped_cls, TorchLSTMWrapper)
                v2_class._wrapped_forward = forward
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        else:
            raise NotImplementedError(
                "`framework` must be 'tf|tfe|torch', but is "
                "{}!".format(framework))

    @staticmethod
    @DeveloperAPI
    def get_preprocessor(env, options=None):
        """Returns a suitable preprocessor for the given env.

        This is a wrapper for get_preprocessor_for_space().
        """

        return ModelCatalog.get_preprocessor_for_space(env.observation_space,
                                                       options)

    @staticmethod
    @DeveloperAPI
    def get_preprocessor_for_space(observation_space, options=None):
        """Returns a suitable preprocessor for the given observation space.

        Args:
            observation_space (Space): The input observation space.
            options (dict): Options to pass to the preprocessor.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the observations.
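
        Examples:
            >>> # An illustrative sketch: flatten a Dict observation space
            >>> # into a single 1D float vector.
            >>> space = gym.spaces.Dict({
            ...     "pos": gym.spaces.Box(-1.0, 1.0, shape=(2, )),
            ...     "goal": gym.spaces.Discrete(5)})
            >>> prep = ModelCatalog.get_preprocessor_for_space(space)
            >>> flat_obs = prep.transform(space.sample())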
        """
        options = options or MODEL_DEFAULTS
        for k in options.keys():
            if k not in MODEL_DEFAULTS:
                raise Exception("Unknown config key `{}`, all keys: {}".format(
                    k, list(MODEL_DEFAULTS)))

        if options.get("custom_preprocessor"):
            preprocessor = options["custom_preprocessor"]
            logger.info("Using custom preprocessor {}".format(preprocessor))
            logger.warning(
                "DeprecationWarning: Custom preprocessors are deprecated, "
                "since they sometimes conflict with the built-in "
                "preprocessors for handling complex observation spaces. "
                "Please use wrapper classes around your environment "
                "instead of preprocessors.")
            prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
                observation_space, options)
        else:
            cls = get_preprocessor(observation_space)
            prep = cls(observation_space, options)

        logger.debug("Created preprocessor {}: {} -> {}".format(
            prep, observation_space, prep.shape))
        return prep

    @staticmethod
    @PublicAPI
    def register_custom_preprocessor(preprocessor_name, preprocessor_class):
        """Register a custom preprocessor class by name.

        The preprocessor can be later used by specifying
        {"custom_preprocessor": preprocessor_name} in the model config.

        Args:
            preprocessor_name (str): Name to register the preprocessor under.
            preprocessor_class (type): Python class of the preprocessor.
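
        Examples:
            >>> # An illustrative sketch: `MyPreprocessor` is a hypothetical
            >>> # Preprocessor subclass defined by the user.
            >>> ModelCatalog.register_custom_preprocessor(
            ...     "my_prep", MyPreprocessor)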
        """
        _global_registry.register(RLLIB_PREPROCESSOR, preprocessor_name,
                                   preprocessor_class)

    @staticmethod
    @PublicAPI
    def register_custom_model(model_name, model_class):
        """Register a custom model class by name.

        The model can be later used by specifying {"custom_model": model_name}
        in the model config.

        Args:
            model_name (str): Name to register the model under.
            model_class (type): Python class of the model.
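
        Examples:
            >>> # An illustrative sketch: `MyTFModel` is a hypothetical
            >>> # TFModelV2 subclass; it is then referenced via
            >>> # config["model"]["custom_model"] = "my_model".
            >>> ModelCatalog.register_custom_model("my_model", MyTFModel)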
        """
        _global_registry.register(RLLIB_MODEL, model_name, model_class)

    @staticmethod
    @PublicAPI
    def register_custom_action_dist(action_dist_name, action_dist_class):
        """Register a custom action distribution class by name.

        The action distribution can later be used by specifying
        {"custom_action_dist": action_dist_name} in the model config.

        Args:
            action_dist_name (str): Name to register the action distribution
                under.
            action_dist_class (type): Python class of the action distribution.
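
        Examples:
            >>> # An illustrative sketch: `MyDist` is a hypothetical
            >>> # ActionDistribution subclass; it is then referenced via
            >>> # config["model"]["custom_action_dist"] = "my_dist".
            >>> ModelCatalog.register_custom_action_dist("my_dist", MyDist)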
        """
        _global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
                                   action_dist_class)

    @staticmethod
    def _wrap_if_needed(model_cls, model_interface):
        assert issubclass(model_cls, ModelV2), model_cls

        if not model_interface or issubclass(model_cls, model_interface):
            return model_cls

        class wrapper(model_interface, model_cls):
            pass

        name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
        wrapper.__name__ = name
        wrapper.__qualname__ = name

        return wrapper

    @staticmethod
    def _get_v2_model_class(input_space, model_config, framework="tf"):
        if framework == "torch":
            from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
                                                      FCNet)
            from ray.rllib.models.torch.visionnet import (VisionNetwork as
                                                          VisionNet)
        else:
            from ray.rllib.models.tf.fcnet import \
                FullyConnectedNetwork as FCNet
            from ray.rllib.models.tf.visionnet import \
                VisionNetwork as VisionNet

        # Discrete/1D obs-spaces.
        if isinstance(input_space, gym.spaces.Discrete) or \
                len(input_space.shape) <= 2:
            return FCNet
        # Default Conv2D net.
        else:
            return VisionNet

    # -------------------
    # DEPRECATED METHODS.
    # -------------------
    @staticmethod
    def get_model(input_dict,
                  obs_space,
                  action_space,
                  num_outputs,
                  options,
                  state_in=None,
                  seq_lens=None):
        """Deprecated: Use get_model_v2() instead."""

        deprecation_warning("get_model", "get_model_v2", error=False)
        assert isinstance(input_dict, dict)
        options = options or MODEL_DEFAULTS
        model = ModelCatalog._get_model(input_dict, obs_space, action_space,
                                        num_outputs, options, state_in,
                                        seq_lens)

        if options.get("use_lstm"):
            copy = dict(input_dict)
            copy["obs"] = model.last_layer
            feature_space = gym.spaces.Box(
                -1, 1, shape=(model.last_layer.shape[1], ))
            model = LSTM(copy, feature_space, action_space, num_outputs,
                         options, state_in, seq_lens)

        logger.debug(
            "Created model {}: ({} of {}, {}, {}, {}) -> {}, {}".format(
                model, input_dict, obs_space, action_space, state_in, seq_lens,
                model.outputs, model.state_out))

        model._validate_output_shape()
        return model

    @staticmethod
    def _get_model(input_dict, obs_space, action_space, num_outputs, options,
                   state_in, seq_lens):
        deprecation_warning("_get_model", "get_model_v2", error=False)
        if options.get("custom_model"):
            model = options["custom_model"]
            logger.debug("Using custom model {}".format(model))
            return _global_registry.get(RLLIB_MODEL, model)(
                input_dict,
                obs_space,
                action_space,
                num_outputs,
                options,
                state_in=state_in,
                seq_lens=seq_lens)

        obs_rank = len(input_dict["obs"].shape) - 1  # drops batch dim

        if obs_rank > 2:
            return VisionNetwork(input_dict, obs_space, action_space,
                                 num_outputs, options)

        return FullyConnectedNetwork(input_dict, obs_space, action_space,
                                     num_outputs, options)

    @staticmethod
    def get_torch_model(obs_space,
                        num_outputs,
                        options=None,
                        default_model_cls=None):
        raise DeprecationWarning("Please use get_model_v2() instead.")