ray/rllib/models/catalog.py

from functools import partial
import gym
import logging
import numpy as np
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
RLLIB_ACTION_DIST, _global_registry
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.tf.lstm_v1 import LSTM
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
from ray.rllib.models.tf.tf_action_dist import Categorical, \
Deterministic, DiagGaussian, Dirichlet, \
MultiActionDistribution, MultiCategorical
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDeterministic, TorchDiagGaussian, \
TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.utils import try_import_tf, try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space
tf = try_import_tf()
tree = try_import_tree()
logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
MODEL_DEFAULTS = {
# === Built-in options ===
# Filter config. List of [out_channels, kernel, stride] for each filter
"conv_filters": None,
# Nonlinearity for built-in convnet
"conv_activation": "relu",
# Nonlinearity for fully connected net (tanh, relu)
"fcnet_activation": "tanh",
# Number of hidden layers for fully connected net
"fcnet_hiddens": [256, 256],
# For DiagGaussian action distributions, make the second half of the model
# outputs floating bias variables instead of state-dependent. This only
# has an effect if using the default fully connected net.
"free_log_std": False,
# Whether to skip the final linear layer used to resize the hidden layer
# outputs to size `num_outputs`. If True, then the last hidden layer
# should already match num_outputs.
"no_final_linear": False,
# Whether layers should be shared for the value function.
"vf_share_layers": True,
# == LSTM ==
# Whether to wrap the model with an LSTM
"use_lstm": False,
# Max seq len for training the LSTM, defaults to 20
"max_seq_len": 20,
# Size of the LSTM cell
"lstm_cell_size": 256,
# Whether to feed a_{t-1}, r_{t-1} to LSTM
"lstm_use_prev_action_reward": False,
# When using modelv1 models with a modelv2 algorithm, you may have to
# define the state shape here (e.g., [256, 256]).
"state_shape": None,
# == Atari ==
# Whether to enable framestack for Atari envs
"framestack": True,
# Final resized frame dimension
"dim": 84,
# (deprecated) Converts ATARI frame to 1 Channel Grayscale image
"grayscale": False,
# (deprecated) Changes frame to range from [-1, 1] if true
"zero_mean": True,
# === Options for custom models ===
# Name of a custom model to use
"custom_model": None,
# Extra options to pass to the custom classes.
# These will be available in the Model's model_config["custom_model_config"].
"custom_model_config": {},
# Name of a custom action distribution to use.
"custom_action_dist": None,
# Custom preprocessors are deprecated. Please use a wrapper class around
# your environment instead to preprocess observations.
"custom_preprocessor": None,
# Deprecated config keys.
"custom_options": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
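# Example of overriding some of the defaults above via a Trainer's "model"
# config (an illustrative sketch only; PPOTrainer and "CartPole-v0" are
# arbitrary example choices):
#
#     from ray.rllib.agents.ppo import PPOTrainer
#     trainer = PPOTrainer(env="CartPole-v0", config={
#         "model": {
#             "fcnet_hiddens": [64, 64],
#             "fcnet_activation": "relu",
#             "use_lstm": True,
#         },
#     })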
@PublicAPI
class ModelCatalog:
"""Registry of models, preprocessors, and action distributions for envs.
Examples:
>>> prep = ModelCatalog.get_preprocessor(env)
>>> observation = prep.transform(raw_observation)
>>> dist_class, dist_dim = ModelCatalog.get_action_dist(
env.action_space, {})
>>> model = ModelCatalog.get_model(inputs, dist_dim, options)
>>> dist = dist_class(model.outputs, model)
>>> action = dist.sample()
"""
@staticmethod
@DeveloperAPI
def get_action_dist(action_space,
config,
dist_type=None,
framework="tf",
**kwargs):
"""Returns a distribution class and size for the given action space.
Args:
action_space (Space): Action space of the target gym env.
config (Optional[dict]): Optional model config.
dist_type (Optional[str]): Identifier of the action distribution.
framework (str): One of "tf" or "torch".
kwargs (dict): Optional kwargs to pass on to the Distribution's
constructor.
Returns:
dist_class (ActionDistribution): Python class of the distribution.
dist_dim (int): The size of the input vector to the distribution.
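Examples:
Illustrative sketch (assumes a simple Discrete(2) action space and
the default model config):
>>> import gym
>>> dist_class, dist_dim = ModelCatalog.get_action_dist(
...     gym.spaces.Discrete(2), MODEL_DEFAULTS, framework="tf")
>>> # dist_class -> Categorical, dist_dim -> 2 (== action_space.n)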
"""
dist = None
config = config or MODEL_DEFAULTS
# Custom distribution given.
if config.get("custom_action_dist"):
action_dist_name = config["custom_action_dist"]
logger.debug(
"Using custom action distribution {}".format(action_dist_name))
dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
# Dist_type is given directly as a class.
elif type(dist_type) is type and \
issubclass(dist_type, ActionDistribution) and \
dist_type not in (
MultiActionDistribution, TorchMultiActionDistribution):
dist = dist_type
# Box space -> DiagGaussian OR Deterministic.
elif isinstance(action_space, gym.spaces.Box):
if len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space has multiple dimensions "
"{}. ".format(action_space.shape) +
"Consider reshaping this into a single dimension, "
"using a custom action distribution, "
"using a Tuple action space, or the multi-agent API.")
# TODO(sven): Check for bounds and return SquashedNormal, etc..
if dist_type is None:
dist = DiagGaussian if framework == "tf" else TorchDiagGaussian
elif dist_type == "deterministic":
dist = Deterministic if framework == "tf" else \
TorchDeterministic
# Discrete Space -> Categorical.
elif isinstance(action_space, gym.spaces.Discrete):
dist = Categorical if framework == "tf" else TorchCategorical
# Tuple/Dict Spaces -> MultiAction.
elif dist_type in (MultiActionDistribution,
TorchMultiActionDistribution) or \
isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
flat_action_space = flatten_space(action_space)
child_dists_and_in_lens = tree.map_structure(
lambda s: ModelCatalog.get_action_dist(
s, config, framework=framework), flat_action_space)
child_dists = [e[0] for e in child_dists_and_in_lens]
input_lens = [int(e[1]) for e in child_dists_and_in_lens]
return partial(
(TorchMultiActionDistribution
if framework == "torch" else MultiActionDistribution),
action_space=action_space,
child_distributions=child_dists,
input_lens=input_lens), int(sum(input_lens))
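# For illustration: a Tuple(Discrete(2), Box(-1, 1, (3,))) action space
# would yield child distributions [Categorical, DiagGaussian] with
# input_lens [2, 6], i.e. a total flattened input size of 8.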
# Simplex -> Dirichlet.
elif isinstance(action_space, Simplex):
if framework == "torch":
# TODO(sven): implement
raise NotImplementedError(
"Simplex action spaces not supported for torch.")
dist = Dirichlet
# MultiDiscrete -> MultiCategorical.
elif isinstance(action_space, gym.spaces.MultiDiscrete):
dist = MultiCategorical if framework == "tf" else \
TorchMultiCategorical
return partial(dist, input_lens=action_space.nvec), \
int(sum(action_space.nvec))
# Unknown type -> Error.
else:
raise NotImplementedError("Unsupported args: {} {}".format(
action_space, dist_type))
return dist, dist.required_model_output_shape(action_space, config)
@staticmethod
@DeveloperAPI
def get_action_shape(action_space):
"""Returns action tensor dtype and shape for the action space.
Args:
action_space (Space): Action space of the target gym env.
Returns:
(dtype, shape): Dtype and shape of the actions tensor.
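Examples:
Illustrative sketch (dtypes/shapes as returned under graph-mode tf):
>>> import gym
>>> dtype, shape = ModelCatalog.get_action_shape(gym.spaces.Discrete(4))
>>> # dtype -> tf.int64, shape -> (None,)
>>> dtype, shape = ModelCatalog.get_action_shape(
...     gym.spaces.Box(-1.0, 1.0, shape=(3,)))
>>> # dtype -> tf.float32, shape -> (None, 3)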
"""
if isinstance(action_space, gym.spaces.Discrete):
return (tf.int64, (None, ))
elif isinstance(action_space, (gym.spaces.Box, Simplex)):
return (tf.float32, (None, ) + action_space.shape)
elif isinstance(action_space, gym.spaces.MultiDiscrete):
return (tf.as_dtype(action_space.dtype),
(None, ) + action_space.shape)
elif isinstance(action_space, (gym.spaces.Tuple, gym.spaces.Dict)):
flat_action_space = flatten_space(action_space)
size = 0
all_discrete = True
for i in range(len(flat_action_space)):
if isinstance(flat_action_space[i], gym.spaces.Discrete):
size += 1
else:
all_discrete = False
size += np.product(flat_action_space[i].shape)
size = int(size)
return (tf.int64 if all_discrete else tf.float32, (None, size))
else:
raise NotImplementedError(
"Action space {} not supported".format(action_space))
@staticmethod
@DeveloperAPI
def get_action_placeholder(action_space, name="action"):
"""Returns an action placeholder consistent with the action space
Args:
action_space (Space): Action space of the target gym env.
name (str): An optional string to name the placeholder by.
Default: "action".
Returns:
action_placeholder (Tensor): A placeholder for the actions.
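Examples:
Illustrative sketch (assumes graph-mode tf, where placeholders exist):
>>> import gym
>>> ph = ModelCatalog.get_action_placeholder(
...     gym.spaces.Box(-1.0, 1.0, shape=(3,)), name="my_action")
>>> # ph is a float32 placeholder of shape (None, 3).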
"""
dtype, shape = ModelCatalog.get_action_shape(action_space)
return tf.placeholder(dtype, shape=shape, name=name)
@staticmethod
@DeveloperAPI
def get_model_v2(obs_space,
action_space,
num_outputs,
model_config,
framework="tf",
name="default_model",
model_interface=None,
default_model=None,
**model_kwargs):
"""Returns a suitable model compatible with given spaces and output.
Args:
obs_space (Space): Observation space of the target gym env. This
may have an `original_space` attribute that specifies how to
unflatten the tensor into a ragged tensor.
action_space (Space): Action space of the target gym env.
num_outputs (int): The size of the output vector of the model.
framework (str): One of "tf" or "torch".
name (str): Name (scope) for the model.
model_interface (cls): Interface required for the model.
default_model (cls): Override the default class for the model. This
only has an effect when not using a custom model.
model_kwargs (dict): Args to pass to the ModelV2 constructor.
Returns:
model (ModelV2): Model to use for the policy.
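Examples:
Illustrative sketch (assumes flat Box observations, for which the
default fully connected ModelV2 is returned):
>>> import gym
>>> model = ModelCatalog.get_model_v2(
...     obs_space=gym.spaces.Box(-1.0, 1.0, shape=(4,)),
...     action_space=gym.spaces.Discrete(2),
...     num_outputs=2,
...     model_config=MODEL_DEFAULTS,
...     framework="tf")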
"""
if model_config.get("custom_model"):
if "custom_options" in model_config and \
model_config["custom_options"] != DEPRECATED_VALUE:
deprecation_warning(
"model.custom_options",
"model.custom_model_config",
error=False)
model_config["custom_model_config"] = \
model_config.pop("custom_options")
if isinstance(model_config["custom_model"], type):
model_cls = model_config["custom_model"]
else:
model_cls = _global_registry.get(RLLIB_MODEL,
model_config["custom_model"])
# TODO(sven): Hard-deprecate Model(V1).
if issubclass(model_cls, ModelV2):
logger.info("Wrapping {} as {}".format(model_cls,
model_interface))
model_cls = ModelCatalog._wrap_if_needed(
model_cls, model_interface)
if framework == "tf":
# Track and warn if vars were created but not registered.
created = set()
def track_var_creation(next_creator, **kw):
v = next_creator(**kw)
created.add(v)
return v
with tf.variable_creator_scope(track_var_creation):
# Try calling with kwargs first (custom ModelV2 should
# accept these as kwargs, not get them from
# config["custom_model_config"] anymore).
try:
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name, **model_kwargs)
except TypeError as e:
# Keyword error: Try old way w/o kwargs.
if "__init__() got an unexpected " in e.args[0]:
logger.warning(
"Custom ModelV2 should accept all custom "
"options as **kwargs, instead of expecting"
" them in config['custom_model_config']!")
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name)
# Other error -> re-raise.
else:
raise e
registered = set(instance.variables())
not_registered = set()
for var in created:
if var not in registered:
not_registered.add(var)
if not_registered:
raise ValueError(
"It looks like variables {} were created as part "
"of {} but does not appear in model.variables() "
"({}). Did you forget to call "
"model.register_variables() on the variables in "
"question?".format(not_registered, instance,
registered))
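# A custom TFModelV2 typically avoids the error above by registering
# its variables explicitly, e.g. (sketch; assumes a Keras model stored
# on `self.base_model`):
#     self.register_variables(self.base_model.variables)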
else:
# PyTorch automatically tracks nn.Modules inside the parent
# nn.Module's constructor.
# TODO(sven): Do this for TF as well.
instance = model_cls(obs_space, action_space, num_outputs,
model_config, name, **model_kwargs)
return instance
# TODO(sven): Hard-deprecate Model(V1). This check will be
# superfluous then.
elif tf.executing_eagerly():
raise ValueError(
"Eager execution requires a TFModelV2 model to be "
"used, however you specified a custom model {}".format(
model_cls))
if framework == "tf":
v2_class = None
# try to get a default v2 model
if not model_config.get("custom_model"):
v2_class = default_model or ModelCatalog._get_v2_model_class(
obs_space, model_config, framework=framework)
# fallback to a default v1 model
if v2_class is None:
if tf.executing_eagerly():
raise ValueError(
"Eager execution requires a TFModelV2 model to be "
"used, however there is no default V2 model for this "
"observation space: {}, use_lstm={}".format(
obs_space, model_config.get("use_lstm")))
v2_class = make_v1_wrapper(ModelCatalog.get_model)
# wrap in the requested interface
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
return wrapper(obs_space, action_space, num_outputs, model_config,
name, **model_kwargs)
elif framework == "torch":
v2_class = \
default_model or ModelCatalog._get_v2_model_class(
obs_space, model_config, framework=framework)
# Wrap in the requested interface.
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
return wrapper(obs_space, action_space, num_outputs, model_config,
name, **model_kwargs)
else:
raise NotImplementedError(
"Framework must be 'tf' or 'torch': {}".format(framework))
@staticmethod
@DeveloperAPI
def get_preprocessor(env, options=None):
"""Returns a suitable preprocessor for the given env.
This is a wrapper for get_preprocessor_for_space().
"""
return ModelCatalog.get_preprocessor_for_space(env.observation_space,
options)
@staticmethod
@DeveloperAPI
def get_preprocessor_for_space(observation_space, options=None):
"""Returns a suitable preprocessor for the given observation space.
Args:
observation_space (Space): The input observation space.
options (dict): Options to pass to the preprocessor.
Returns:
preprocessor (Preprocessor): Preprocessor for the observations.
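Examples:
Illustrative sketch (a flat Box space; `raw_obs` stands in for an
observation returned by the env):
>>> import gym
>>> prep = ModelCatalog.get_preprocessor_for_space(
...     gym.spaces.Box(-1.0, 1.0, shape=(4,)))
>>> flat_obs = prep.transform(raw_obs)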
"""
options = options or MODEL_DEFAULTS
for k in options.keys():
if k not in MODEL_DEFAULTS:
raise Exception("Unknown config key `{}`, all keys: {}".format(
k, list(MODEL_DEFAULTS)))
if options.get("custom_preprocessor"):
preprocessor = options["custom_preprocessor"]
logger.info("Using custom preprocessor {}".format(preprocessor))
logger.warning(
"DeprecationWarning: Custom preprocessors are deprecated, "
"since they sometimes conflict with the built-in "
"preprocessors for handling complex observation spaces. "
"Please use wrapper classes around your environment "
"instead of preprocessors.")
prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
observation_space, options)
else:
cls = get_preprocessor(observation_space)
prep = cls(observation_space, options)
logger.debug("Created preprocessor {}: {} -> {}".format(
prep, observation_space, prep.shape))
return prep
@staticmethod
@PublicAPI
def register_custom_preprocessor(preprocessor_name, preprocessor_class):
"""Register a custom preprocessor class by name.
The preprocessor can later be used by specifying
{"custom_preprocessor": preprocessor_name} in the model config.
Args:
preprocessor_name (str): Name to register the preprocessor under.
preprocessor_class (type): Python class of the preprocessor.
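Examples:
Illustrative sketch (`MyPreprocessor` is a hypothetical Preprocessor
subclass):
>>> ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessor)
>>> # Then, in the trainer config:
>>> # config = {"model": {"custom_preprocessor": "my_prep"}}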
"""
_global_registry.register(RLLIB_PREPROCESSOR, preprocessor_name,
preprocessor_class)
@staticmethod
@PublicAPI
def register_custom_model(model_name, model_class):
"""Register a custom model class by name.
The model can later be used by specifying {"custom_model": model_name}
in the model config.
Args:
model_name (str): Name to register the model under.
model_class (type): Python class of the model.
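Examples:
Illustrative sketch (`MyModelClass` is a hypothetical TFModelV2 or
TorchModelV2 subclass):
>>> ModelCatalog.register_custom_model("my_model", MyModelClass)
>>> # Then, in the trainer config:
>>> # config = {"model": {"custom_model": "my_model"}}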
"""
_global_registry.register(RLLIB_MODEL, model_name, model_class)
@staticmethod
@PublicAPI
def register_custom_action_dist(action_dist_name, action_dist_class):
"""Register a custom action distribution class by name.
The action distribution can later be used by specifying
{"custom_action_dist": action_dist_name} in the model config.
Args:
action_dist_name (str): Name to register the action distribution under.
action_dist_class (type): Python class of the action distribution.
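Examples:
Illustrative sketch (`MyActionDist` is a hypothetical
ActionDistribution subclass):
>>> ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)
>>> # Then, in the trainer config:
>>> # config = {"model": {"custom_action_dist": "my_dist"}}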
"""
_global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
action_dist_class)
@staticmethod
def _wrap_if_needed(model_cls, model_interface):
assert issubclass(model_cls, (TFModelV2, TorchModelV2)), model_cls
if not model_interface or issubclass(model_cls, model_interface):
return model_cls
class wrapper(model_interface, model_cls):
pass
name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
wrapper.__name__ = name
wrapper.__qualname__ = name
return wrapper
@staticmethod
def get_model(input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=None,
seq_lens=None):
"""Deprecated: use get_model_v2() instead."""
deprecation_warning("get_model", "get_model_v2", error=False)
assert isinstance(input_dict, dict)
options = options or MODEL_DEFAULTS
model = ModelCatalog._get_model(input_dict, obs_space, action_space,
num_outputs, options, state_in,
seq_lens)
if options.get("use_lstm"):
copy = dict(input_dict)
copy["obs"] = model.last_layer
feature_space = gym.spaces.Box(
-1, 1, shape=(model.last_layer.shape[1], ))
model = LSTM(copy, feature_space, action_space, num_outputs,
options, state_in, seq_lens)
logger.debug(
"Created model {}: ({} of {}, {}, {}, {}) -> {}, {}".format(
model, input_dict, obs_space, action_space, state_in, seq_lens,
model.outputs, model.state_out))
model._validate_output_shape()
return model
@staticmethod
def _get_model(input_dict, obs_space, action_space, num_outputs, options,
state_in, seq_lens):
deprecation_warning("_get_model", "get_model_v2", error=False)
if options.get("custom_model"):
model = options["custom_model"]
logger.debug("Using custom model {}".format(model))
return _global_registry.get(RLLIB_MODEL, model)(
input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=state_in,
seq_lens=seq_lens)
obs_rank = len(input_dict["obs"].shape) - 1 # drops batch dim
if obs_rank > 2:
return VisionNetwork(input_dict, obs_space, action_space,
num_outputs, options)
return FullyConnectedNetwork(input_dict, obs_space, action_space,
num_outputs, options)
@staticmethod
def _get_v2_model_class(obs_space, model_config, framework="tf"):
model_config = model_config or MODEL_DEFAULTS
if framework == "torch":
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
FCNet)
from ray.rllib.models.torch.visionnet import (VisionNetwork as
VisionNet)
else:
from ray.rllib.models.tf.fcnet import \
FullyConnectedNetwork as FCNet
from ray.rllib.models.tf.visionnet import \
VisionNetwork as VisionNet
# Discrete/1D obs-spaces.
if isinstance(obs_space, gym.spaces.Discrete) or \
len(obs_space.shape) <= 2:
return FCNet
# Default Conv2D net.
else:
return VisionNet
@staticmethod
def get_torch_model(obs_space,
num_outputs,
options=None,
default_model_cls=None):
raise DeprecationWarning("Please use get_model_v2() instead.")