from functools import partial
import gym
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import logging
import numpy as np
import tree  # pip install dm_tree
from typing import List, Optional, Type, Union

from ray.tune.registry import (
    RLLIB_MODEL,
    RLLIB_PREPROCESSOR,
    RLLIB_ACTION_DIST,
    _global_registry,
)
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models.tf.tf_action_dist import (
    Categorical,
    Deterministic,
    DiagGaussian,
    Dirichlet,
    MultiActionDistribution,
    MultiCategorical,
)
from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical,
    TorchDeterministic,
    TorchDiagGaussian,
    TorchMultiActionDistribution,
    TorchMultiCategorical,
)
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import (
    Deprecated,
    DEPRECATED_VALUE,
    deprecation_warning,
)
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


# fmt: off
# __sphinx_doc_begin__
MODEL_DEFAULTS: ModelConfigDict = {
    # Experimental flag.
    # If True, try to use a native (tf.keras.Model or torch.Module) default
    # model instead of our built-in ModelV2 defaults.
    # If False (default), use "classic" ModelV2 default models.
    # Note that this currently only works for:
    # 1) framework != torch AND
    # 2) fully connected and CNN default networks as well as
    # auto-wrapped LSTM- and attention nets.
    "_use_default_native_models": False,
    # Experimental flag.
    # If True, no preprocessor will be created (the user set
    # `config._disable_preprocessor_api=True`) and observations arrive at
    # the model exactly as they are returned by the env.
    "_disable_preprocessor_api": False,
    # Experimental flag.
    # If True, RLlib will no longer flatten the policy-computed actions into
    # a single tensor (for storage in SampleCollectors/output files/etc..),
    # but leave (possibly nested) actions as-is. Disabling flattening affects:
    # - SampleCollectors: Have to store possibly nested action structs.
    # - Models that have the previous action(s) as part of their input.
    # - Algorithms reading from offline files (incl. action information).
    "_disable_action_flattening": False,

    # === Built-in options ===
    # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
    # These are used if no custom model is specified and the input space is 1D.
    # Number of hidden layers to be used.
    "fcnet_hiddens": [256, 256],
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu"),
    # "linear" (or None).
    "fcnet_activation": "tanh",

    # VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
    # These are used if no custom model is specified and the input space is 2D.
    # Filter config: List of [out_channels, kernel, stride] for each filter.
    # Use None for making RLlib try to find a default filter setup given the
    # observation space.
    "conv_filters": None,
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu"),
    # "linear" (or None).
    "conv_activation": "relu",

    # Some default models support a final FC stack of n Dense layers with given
    # activation:
    # - Complex observation spaces: Image components are fed through
    #   VisionNets, flat Boxes are left as-is, Discrete are one-hot encoded,
    #   then everything is concatenated and pushed through this final FC stack.
    # - VisionNets (CNNs), e.g. after the CNN stack, there may be
    #   additional Dense layers.
    # - FullyConnectedNetworks will have this additional FCStack as well
    #   (that's why it's empty by default).
    "post_fcnet_hiddens": [],
    "post_fcnet_activation": "relu",

    # For DiagGaussian action distributions, make the second half of the model
    # outputs floating bias variables instead of state-dependent. This only
    # has an effect if using the default fully connected net.
    "free_log_std": False,
    # Whether to skip the final linear layer used to resize the hidden layer
    # outputs to size `num_outputs`. If True, then the last hidden layer
    # should already match num_outputs.
    "no_final_linear": False,
    # Whether layers should be shared for the value function.
    "vf_share_layers": True,

    # == LSTM ==
    # Whether to wrap the model with an LSTM.
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20.
    "max_seq_len": 20,
    # Size of the LSTM cell.
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    "lstm_use_prev_action": False,
    # Whether to feed r_{t-1} to LSTM.
    "lstm_use_prev_reward": False,
    # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    "_time_major": False,

    # == Attention Nets (experimental: torch-version is untested) ==
    # Whether to use a GTrXL ("Gru transformer XL"; attention net) as the
    # wrapper Model around the default Model.
    "use_attention": False,
    # The number of transformer units within GTrXL.
    # A transformer unit in GTrXL consists of a) MultiHeadAttention module and
    # b) a position-wise MLP.
    "attention_num_transformer_units": 1,
    # The input and output size of each transformer unit.
    "attention_dim": 64,
    # The number of attention heads within the MultiHeadAttention units.
    "attention_num_heads": 1,
    # The dim of a single head (within the MultiHeadAttention units).
    "attention_head_dim": 32,
    # The memory sizes for inference and training.
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    # The output dim of the position-wise MLP.
    "attention_position_wise_mlp_dim": 32,
    # The initial bias values for the 2 GRU gates within a transformer unit.
    "attention_init_gru_gate_bias": 2.0,
    # Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
    "attention_use_n_prev_actions": 0,
    # Whether to feed r_{t-n:t-1} to GTrXL.
    "attention_use_n_prev_rewards": 0,

    # == Atari ==
    # Set to True to enable 4x stacking behavior.
    "framestack": True,
    # Final resized frame dimension.
    "dim": 84,
    # (deprecated) Converts ATARI frame to 1 Channel Grayscale image.
    "grayscale": False,
    # (deprecated) Changes frame to range from [-1, 1] if True.
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom model to use.
    "custom_model": None,
    # Extra options to pass to the custom classes. These will be available to
    # the Model's constructor in the model_config field. Also, they will be
    # attempted to be passed as **kwargs to ModelV2 models. For an example,
    # see rllib/models/[tf|torch]/attention_net.py.
    "custom_model_config": {},
    # Name of a custom action distribution to use.
    "custom_action_dist": None,
    # Custom preprocessors are deprecated. Please use a wrapper class around
    # your environment instead to preprocess observations.
    "custom_preprocessor": None,

    # Deprecated keys:
    # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
    "lstm_use_prev_action_reward": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# fmt: on
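
# The keys above are usually not edited here but overridden via the "model"
# sub-dict of a Trainer's config. A minimal sketch (hedged; assumes PPO and a
# CartPole-style env are available in your installation):
#
#   from ray.rllib.agents.ppo import PPOTrainer
#   trainer = PPOTrainer(
#       env="CartPole-v1",
#       config={
#           "framework": "tf",
#           "model": {
#               "fcnet_hiddens": [64, 64],  # shrink the default [256, 256]
#               "use_lstm": True,           # auto-wrap the FCNet in an LSTM
#               "lstm_cell_size": 128,
#           },
#       },
#   )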


@PublicAPI
class ModelCatalog:
    """Registry of models, preprocessors, and action distributions for envs.

    Examples:
        >>> prep = ModelCatalog.get_preprocessor(env)
        >>> observation = prep.transform(raw_observation)

        >>> dist_class, dist_dim = ModelCatalog.get_action_dist(
        ...     env.action_space, {})
        >>> model = ModelCatalog.get_model_v2(
        ...     obs_space, action_space, num_outputs, options)
        >>> dist = dist_class(model.outputs, model)
        >>> action = dist.sample()
    """

    @staticmethod
    @DeveloperAPI
    def get_action_dist(
        action_space: gym.Space,
        config: ModelConfigDict,
        dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
        framework: str = "tf",
        **kwargs
    ) -> (type, int):
        """Returns a distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            config (Optional[dict]): Optional model config.
            dist_type (Optional[Union[str, Type[ActionDistribution]]]):
                Identifier of the action distribution (str) interpreted as a
                hint or the actual ActionDistribution class to use.
            framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
            kwargs (dict): Optional kwargs to pass on to the Distribution's
                constructor.

        Returns:
            Tuple:
                - dist_class (ActionDistribution): Python class of the
                    distribution.
                - dist_dim (int): The size of the input vector to the
                    distribution.
        """

        dist_cls = None
        config = config or MODEL_DEFAULTS
        # Custom distribution given.
        if config.get("custom_action_dist"):
            custom_action_config = config.copy()
            action_dist_name = custom_action_config.pop("custom_action_dist")
            logger.debug("Using custom action distribution {}".format(action_dist_name))
            dist_cls = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
            return ModelCatalog._get_multi_action_distribution(
                dist_cls, action_space, custom_action_config, framework
            )

        # Dist_type is given directly as a class.
        elif (
            type(dist_type) is type
            and issubclass(dist_type, ActionDistribution)
            and dist_type not in (MultiActionDistribution, TorchMultiActionDistribution)
        ):
            dist_cls = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, Box):
            if action_space.dtype.name.startswith("int"):
                low_ = np.min(action_space.low)
                high_ = np.max(action_space.high)
                dist_cls = (
                    TorchMultiCategorical if framework == "torch" else MultiCategorical
                )
                num_cats = int(np.product(action_space.shape))
                return (
                    partial(
                        dist_cls,
                        input_lens=[high_ - low_ + 1 for _ in range(num_cats)],
                        action_space=action_space,
                    ),
                    num_cats * (high_ - low_ + 1),
                )
            else:
                if len(action_space.shape) > 1:
                    raise UnsupportedSpaceException(
                        "Action space has multiple dimensions "
                        "{}. ".format(action_space.shape)
                        + "Consider reshaping this into a single dimension, "
                        "using a custom action distribution, "
                        "using a Tuple action space, or the multi-agent API."
                    )
                # TODO(sven): Check for bounds and return SquashedNormal, etc..
                if dist_type is None:
                    return (
                        partial(
                            TorchDiagGaussian if framework == "torch" else DiagGaussian,
                            action_space=action_space,
                        ),
                        DiagGaussian.required_model_output_shape(action_space, config),
                    )
                elif dist_type == "deterministic":
                    dist_cls = (
                        TorchDeterministic if framework == "torch" else Deterministic
                    )
        # Discrete Space -> Categorical.
        elif isinstance(action_space, Discrete):
            if framework == "torch":
                dist_cls = TorchCategorical
            elif framework == "jax":
                from ray.rllib.models.jax.jax_action_dist import JAXCategorical

                dist_cls = JAXCategorical
            else:
                dist_cls = Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif (
            dist_type
            in (
                MultiActionDistribution,
                TorchMultiActionDistribution,
            )
            or isinstance(action_space, (Tuple, Dict))
        ):
            return ModelCatalog._get_multi_action_distribution(
                (
                    MultiActionDistribution
                    if framework == "tf"
                    else TorchMultiActionDistribution
                ),
                action_space,
                config,
                framework,
            )
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch."
                )
            dist_cls = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, MultiDiscrete):
            dist_cls = (
                TorchMultiCategorical if framework == "torch" else MultiCategorical
            )
            return partial(dist_cls, input_lens=action_space.nvec), int(
                sum(action_space.nvec)
            )
        # Unknown type -> Error.
        else:
            raise NotImplementedError(
                "Unsupported args: {} {}".format(action_space, dist_type)
            )

        return dist_cls, dist_cls.required_model_output_shape(action_space, config)
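
    # Usage sketch for `get_action_dist` (hedged; a Discrete(4) space with the
    # default model config is assumed):
    #
    #   dist_class, dist_dim = ModelCatalog.get_action_dist(
    #       Discrete(4), MODEL_DEFAULTS, framework="torch")
    #   # -> (TorchCategorical, 4): the model must output 4 logits.
    #   dist = dist_class(model_out, model)  # model_out: Tensor of shape [B, 4]
    #   action = dist.sample()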

    @staticmethod
    @DeveloperAPI
    def get_action_shape(
        action_space: gym.Space, framework: str = "tf"
    ) -> (np.dtype, List[int]):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space (Space): Action space of the target gym env.
            framework (str): The framework identifier. One of "tf" or "torch".

        Returns:
            (dtype, shape): Dtype and shape of the actions tensor.
        """
        dl_lib = torch if framework == "torch" else tf
        if isinstance(action_space, Discrete):
            return action_space.dtype, (None,)
        elif isinstance(action_space, (Box, Simplex)):
            if np.issubdtype(action_space.dtype, np.floating):
                return dl_lib.float32, (None,) + action_space.shape
            elif np.issubdtype(action_space.dtype, np.integer):
                return dl_lib.int32, (None,) + action_space.shape
            else:
                raise ValueError("RLlib doesn't support non int or float box spaces")
        elif isinstance(action_space, MultiDiscrete):
            return action_space.dtype, (None,) + action_space.shape
        elif isinstance(action_space, (Tuple, Dict)):
            flat_action_space = flatten_space(action_space)
            size = 0
            all_discrete = True
            for i in range(len(flat_action_space)):
                if isinstance(flat_action_space[i], Discrete):
                    size += 1
                else:
                    all_discrete = False
                    size += np.product(flat_action_space[i].shape)
            size = int(size)
            return dl_lib.int32 if all_discrete else dl_lib.float32, (None, size)
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space)
            )
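
    # Sketch (hedged): for Box(-1.0, 1.0, (3,), dtype=np.float32) and
    # framework="tf", `get_action_shape` returns (tf.float32, (None, 3));
    # the leading None is the batch dimension.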

    @staticmethod
    @DeveloperAPI
    def get_action_placeholder(
        action_space: gym.Space, name: str = "action"
    ) -> TensorType:
        """Returns an action placeholder consistent with the action space.

        Args:
            action_space (Space): Action space of the target gym env.
            name (str): An optional string to name the placeholder by.
                Default: "action".

        Returns:
            action_placeholder (Tensor): A placeholder for the actions.
        """
        dtype, shape = ModelCatalog.get_action_shape(action_space, framework="tf")

        return tf1.placeholder(dtype, shape=shape, name=name)

    @staticmethod
    @DeveloperAPI
    def get_model_v2(
        obs_space: gym.Space,
        action_space: gym.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        framework: str = "tf",
        name: str = "default_model",
        model_interface: type = None,
        default_model: type = None,
        **model_kwargs
    ) -> ModelV2:
        """Returns a suitable model compatible with given spaces and output.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (ModelConfigDict): The "model" sub-config dict
                within the Trainer's config dict.
            framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model.
            default_model (cls): Override the default class for the model.
                This only has an effect when not using a custom model.
            model_kwargs (dict): Args to pass to the ModelV2 constructor.

        Returns:
            model (ModelV2): Model to use for the policy.
        """

        # Validate the given config dict.
        ModelCatalog._validate_config(
            config=model_config, action_space=action_space, framework=framework
        )

        if model_config.get("custom_model"):
            # Allow model kwargs to be overridden / augmented by
            # custom_model_config.
            customized_model_kwargs = dict(
                model_kwargs, **model_config.get("custom_model_config", {})
            )

            if isinstance(model_config["custom_model"], type):
                model_cls = model_config["custom_model"]
            else:
                model_cls = _global_registry.get(
                    RLLIB_MODEL, model_config["custom_model"]
                )

            # Only allow ModelV2 or native keras Models.
            if not issubclass(model_cls, ModelV2):
                if framework not in ["tf", "tf2", "tfe"] or not issubclass(
                    model_cls, tf.keras.Model
                ):
                    raise ValueError(
                        "`model_cls` must be a ModelV2 sub-class, but is"
                        " {}!".format(model_cls)
                    )

            logger.info("Wrapping {} as {}".format(model_cls, model_interface))
            model_cls = ModelCatalog._wrap_if_needed(model_cls, model_interface)

            if framework in ["tf2", "tf", "tfe"]:
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or model_config.get("use_attention"):
                    from ray.rllib.models.tf.attention_net import (
                        AttentionWrapper,
                        Keras_AttentionWrapper,
                    )
                    from ray.rllib.models.tf.recurrent_net import (
                        LSTMWrapper,
                        Keras_LSTMWrapper,
                    )

                    wrapped_cls = model_cls
                    # Wrapped (custom) model is itself a keras Model ->
                    # wrap with keras LSTM/GTrXL (attention) wrappers.
                    if issubclass(wrapped_cls, tf.keras.Model):
                        model_cls = (
                            Keras_LSTMWrapper
                            if model_config.get("use_lstm")
                            else Keras_AttentionWrapper
                        )
                        model_config["wrapped_cls"] = wrapped_cls
                    # Wrapped (custom) model is ModelV2 ->
                    # wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
                    else:
                        forward = wrapped_cls.forward
                        model_cls = ModelCatalog._wrap_if_needed(
                            wrapped_cls,
                            LSTMWrapper
                            if model_config.get("use_lstm")
                            else AttentionWrapper,
                        )
                        model_cls._wrapped_forward = forward

                # Obsolete: Track and warn if vars were created but not
                # registered. Only still do this, if users do register their
                # variables. If not (which they shouldn't), don't check here.
                created = set()

                def track_var_creation(next_creator, **kw):
                    v = next_creator(**kw)
                    created.add(v.ref())
                    return v

                with tf.variable_creator_scope(track_var_creation):
                    if issubclass(model_cls, tf.keras.Model):
                        instance = model_cls(
                            input_space=obs_space,
                            action_space=action_space,
                            num_outputs=num_outputs,
                            name=name,
                            **customized_model_kwargs,
                        )
                    else:
                        # Try calling with kwargs first (custom ModelV2 should
                        # accept these as kwargs, not get them from
                        # config["custom_model_config"] anymore).
                        try:
                            instance = model_cls(
                                obs_space,
                                action_space,
                                num_outputs,
                                model_config,
                                name,
                                **customized_model_kwargs,
                            )
                        except TypeError as e:
                            # Keyword error: Try old way w/o kwargs.
                            if "__init__() got an unexpected " in e.args[0]:
                                instance = model_cls(
                                    obs_space,
                                    action_space,
                                    num_outputs,
                                    model_config,
                                    name,
                                    **model_kwargs,
                                )
                                logger.warning(
                                    "Custom ModelV2 should accept all custom "
                                    "options as **kwargs, instead of expecting"
                                    " them in config['custom_model_config']!"
                                )
                            # Other error -> re-raise.
                            else:
                                raise e

                # User still registered TFModelV2's variables: Check, whether
                # ok.
                registered = []
                if not isinstance(instance, tf.keras.Model):
                    registered = set(instance.var_list)
                if len(registered) > 0:
                    not_registered = set()
                    for var in created:
                        if var not in registered:
                            not_registered.add(var)
                    if not_registered:
                        raise ValueError(
                            "It looks like you are still using "
                            "`{}.register_variables()` to register your "
                            "model's weights. This is no longer required, but "
                            "if you are still calling this method at least "
                            "once, you must make sure to register all created "
                            "variables properly. The missing variables are {},"
                            " and you only registered {}. "
                            "Did you forget to call `register_variables()` on "
                            "some of the variables in question?".format(
                                instance, not_registered, registered
                            )
                        )
            elif framework == "torch":
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or model_config.get("use_attention"):
                    from ray.rllib.models.torch.attention_net import AttentionWrapper
                    from ray.rllib.models.torch.recurrent_net import LSTMWrapper

                    wrapped_cls = model_cls
                    forward = wrapped_cls.forward
                    model_cls = ModelCatalog._wrap_if_needed(
                        wrapped_cls,
                        LSTMWrapper
                        if model_config.get("use_lstm")
                        else AttentionWrapper,
                    )
                    model_cls._wrapped_forward = forward

                # PyTorch automatically tracks nn.Modules inside the parent
                # nn.Module's constructor.
                # Try calling with kwargs first (custom ModelV2 should
                # accept these as kwargs, not get them from
                # config["custom_model_config"] anymore).
                try:
                    instance = model_cls(
                        obs_space,
                        action_space,
                        num_outputs,
                        model_config,
                        name,
                        **customized_model_kwargs,
                    )
                except TypeError as e:
                    # Keyword error: Try old way w/o kwargs.
                    if "__init__() got an unexpected " in e.args[0]:
                        instance = model_cls(
                            obs_space,
                            action_space,
                            num_outputs,
                            model_config,
                            name,
                            **model_kwargs,
                        )
                        logger.warning(
                            "Custom ModelV2 should accept all custom "
                            "options as **kwargs, instead of expecting"
                            " them in config['custom_model_config']!"
                        )
                    # Other error -> re-raise.
                    else:
                        raise e
            else:
                raise NotImplementedError(
                    "`framework` must be 'tf2|tf|tfe|torch', but is "
                    "{}!".format(framework)
                )

            return instance

        # Find a default TFModelV2 and wrap with model_interface.
        if framework in ["tf", "tfe", "tf2"]:
            v2_class = None
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework
                )

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or model_config.get("use_attention"):

                from ray.rllib.models.tf.attention_net import (
                    AttentionWrapper,
                    Keras_AttentionWrapper,
                )
                from ray.rllib.models.tf.recurrent_net import (
                    LSTMWrapper,
                    Keras_LSTMWrapper,
                )

                wrapped_cls = v2_class
                if model_config.get("use_lstm"):
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_LSTMWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, LSTMWrapper
                        )
                        v2_class._wrapped_forward = wrapped_cls.forward
                else:
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_AttentionWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, AttentionWrapper
                        )
                        v2_class._wrapped_forward = wrapped_cls.forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

            if issubclass(wrapper, tf.keras.Model):
                model = wrapper(
                    input_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    name=name,
                    **dict(model_kwargs, **model_config),
                )
                return model

            return wrapper(
                obs_space, action_space, num_outputs, model_config, name, **model_kwargs
            )

        # Find a default TorchModelV2 and wrap with model_interface.
        elif framework == "torch":
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework
                )

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or model_config.get("use_attention"):

                from ray.rllib.models.torch.attention_net import AttentionWrapper
                from ray.rllib.models.torch.recurrent_net import LSTMWrapper

                wrapped_cls = v2_class
                forward = wrapped_cls.forward
                if model_config.get("use_lstm"):
                    v2_class = ModelCatalog._wrap_if_needed(wrapped_cls, LSTMWrapper)
                else:
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, AttentionWrapper
                    )

                v2_class._wrapped_forward = forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(
                obs_space, action_space, num_outputs, model_config, name, **model_kwargs
            )

        # Find a default JAXModelV2 and wrap with model_interface.
        elif framework == "jax":
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework
            )
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(
                obs_space, action_space, num_outputs, model_config, name, **model_kwargs
            )
        else:
            raise NotImplementedError(
                "`framework` must be 'tf2|tf|tfe|torch|jax', but is "
                "{}!".format(framework)
            )
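
    # Usage sketch for `get_model_v2` (hedged; spaces and sizes are made up
    # for illustration):
    #
    #   obs_space = Box(-1.0, 1.0, (4,), dtype=np.float32)
    #   action_space = Discrete(2)
    #   dist_class, num_outputs = ModelCatalog.get_action_dist(
    #       action_space, MODEL_DEFAULTS, framework="torch")
    #   model = ModelCatalog.get_model_v2(
    #       obs_space, action_space, num_outputs, MODEL_DEFAULTS,
    #       framework="torch")
    #   # `model` is the default TorchModelV2 FCNet whose output layer has
    #   # `num_outputs` units, ready to be fed into `dist_class`.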

    @staticmethod
    @DeveloperAPI
    def get_preprocessor(env: gym.Env, options: Optional[dict] = None) -> Preprocessor:
        """Returns a suitable preprocessor for the given env.

        This is a wrapper for get_preprocessor_for_space().
        """
        return ModelCatalog.get_preprocessor_for_space(env.observation_space, options)

    @staticmethod
    @DeveloperAPI
    def get_preprocessor_for_space(
        observation_space: gym.Space, options: dict = None
    ) -> Preprocessor:
        """Returns a suitable preprocessor for the given observation space.

        Args:
            observation_space (Space): The input observation space.
            options (dict): Options to pass to the preprocessor.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the observations.
        """
        options = options or MODEL_DEFAULTS
        for k in options.keys():
            if k not in MODEL_DEFAULTS:
                raise Exception(
                    "Unknown config key `{}`, all keys: {}".format(
                        k, list(MODEL_DEFAULTS)
                    )
                )

        if options.get("custom_preprocessor"):
            preprocessor = options["custom_preprocessor"]
            logger.info("Using custom preprocessor {}".format(preprocessor))
            logger.warning(
                "DeprecationWarning: Custom preprocessors are deprecated, "
                "since they sometimes conflict with the built-in "
                "preprocessors for handling complex observation spaces. "
                "Please use wrapper classes around your environment "
                "instead of preprocessors."
            )
            prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
                observation_space, options
            )
        else:
            cls = get_preprocessor(observation_space)
            prep = cls(observation_space, options)

        if prep is not None:
            logger.debug(
                "Created preprocessor {}: {} -> {}".format(
                    prep, observation_space, prep.shape
                )
            )
        return prep
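
    # Sketch (hedged): the default preprocessor for a Discrete space one-hot
    # encodes observations:
    #
    #   prep = ModelCatalog.get_preprocessor_for_space(Discrete(3))
    #   prep.transform(2)  # -> np.ndarray [0., 0., 1.]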

    @staticmethod
    @Deprecated(error=False)
    def register_custom_preprocessor(
        preprocessor_name: str, preprocessor_class: type
    ) -> None:
        """Register a custom preprocessor class by name.

        The preprocessor can later be used by specifying
        {"custom_preprocessor": preprocessor_name} in the model config.

        Args:
            preprocessor_name (str): Name to register the preprocessor under.
            preprocessor_class (type): Python class of the preprocessor.
        """
        _global_registry.register(
            RLLIB_PREPROCESSOR, preprocessor_name, preprocessor_class
        )

    @staticmethod
    @PublicAPI
    def register_custom_model(model_name: str, model_class: type) -> None:
        """Register a custom model class by name.

        The model can later be used by specifying {"custom_model": model_name}
        in the model config.

        Args:
            model_name (str): Name to register the model under.
            model_class (type): Python class of the model.
        """
        if tf is not None:
            if issubclass(model_class, tf.keras.Model):
                deprecation_warning(old="register_custom_model", error=False)
        _global_registry.register(RLLIB_MODEL, model_name, model_class)
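
    # Registration sketch (hedged; `MyTorchModel` is a hypothetical
    # TorchModelV2 subclass defined elsewhere):
    #
    #   ModelCatalog.register_custom_model("my_model", MyTorchModel)
    #   # ... then, in a Trainer config:
    #   # "model": {"custom_model": "my_model",
    #   #           "custom_model_config": {"hidden_dim": 128}}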

    @staticmethod
    @PublicAPI
    def register_custom_action_dist(
        action_dist_name: str, action_dist_class: type
    ) -> None:
        """Register a custom action distribution class by name.

        The action distribution can later be used by specifying
        {"custom_action_dist": action_dist_name} in the model config.

        Args:
            action_dist_name (str): Name to register the action
                distribution under.
            action_dist_class (type): Python class of the action
                distribution.
        """
        _global_registry.register(
            RLLIB_ACTION_DIST, action_dist_name, action_dist_class
        )
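
    # Registration sketch (hedged; `MyActionDist` is a hypothetical
    # ActionDistribution subclass defined elsewhere):
    #
    #   ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)
    #   # ... then, in a Trainer config:
    #   # "model": {"custom_action_dist": "my_dist"}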

    @staticmethod
    def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
        if not model_interface or issubclass(model_cls, model_interface):
            return model_cls

        assert issubclass(model_cls, ModelV2), model_cls

        # Create a dynamic subclass that mixes the interface into the model
        # class (interface first in the MRO, so its methods take precedence).
        class wrapper(model_interface, model_cls):
            pass

        name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
        wrapper.__name__ = name
        wrapper.__qualname__ = name

        return wrapper

    @staticmethod
    def _get_v2_model_class(
        input_space: gym.Space, model_config: ModelConfigDict, framework: str = "tf"
    ) -> Type[ModelV2]:

        VisionNet = None
        ComplexNet = None
        Keras_FCNet = None
        Keras_VisionNet = None

        if framework in ["tf2", "tf", "tfe"]:
            from ray.rllib.models.tf.fcnet import (
                FullyConnectedNetwork as FCNet,
                Keras_FullyConnectedNetwork as Keras_FCNet,
            )
            from ray.rllib.models.tf.visionnet import (
                VisionNetwork as VisionNet,
                Keras_VisionNetwork as Keras_VisionNet,
            )
            from ray.rllib.models.tf.complex_input_net import (
                ComplexInputNetwork as ComplexNet,
            )
        elif framework == "torch":
            from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as FCNet
            from ray.rllib.models.torch.visionnet import VisionNetwork as VisionNet
            from ray.rllib.models.torch.complex_input_net import (
                ComplexInputNetwork as ComplexNet,
            )
        elif framework == "jax":
            from ray.rllib.models.jax.fcnet import FullyConnectedNetwork as FCNet
        else:
            raise ValueError(
                "framework={} not supported in `ModelCatalog._get_v2_model_"
                "class`!".format(framework)
            )

        orig_space = (
            input_space
            if not hasattr(input_space, "original_space")
            else input_space.original_space
        )

        # `input_space` is 3D Box -> VisionNet.
        if isinstance(input_space, Box) and len(input_space.shape) == 3:
            if framework == "jax":
                raise NotImplementedError("No non-FC default net for JAX yet!")
            elif model_config.get("_use_default_native_models") and Keras_VisionNet:
                return Keras_VisionNet
            return VisionNet
        # `input_space` is 1D Box -> FCNet.
        elif (
            isinstance(input_space, Box)
            and len(input_space.shape) == 1
            and (
                not isinstance(orig_space, (Dict, Tuple))
                or not any(
                    isinstance(s, Box) and len(s.shape) >= 2
                    for s in tree.flatten(orig_space.spaces)
                )
            )
        ):
            # Keras native requested AND no auto-rnn-wrapping.
            if model_config.get("_use_default_native_models") and Keras_FCNet:
                return Keras_FCNet
            # Classic ModelV2 FCNet.
            else:
                return FCNet
        # Complex (Dict, Tuple, 2D Box (flatten), Discrete, MultiDiscrete).
        else:
            if framework == "jax":
                raise NotImplementedError("No non-FC default net for JAX yet!")
            return ComplexNet

    @staticmethod
    def _get_multi_action_distribution(dist_class, action_space, config, framework):
        # In case the custom distribution is a child of MultiActionDistr.
        # If users want to completely ignore the suggested child
        # distributions, they should simply do so in their custom class'
        # constructor.
        if issubclass(
            dist_class, (MultiActionDistribution, TorchMultiActionDistribution)
        ):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(s, config, framework=framework),
                flat_action_space,
            )
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return (
                partial(
                    dist_class,
                    action_space=action_space,
                    child_distributions=child_dists,
                    input_lens=input_lens,
                ),
                int(sum(input_lens)),
            )
        return dist_class, dist_class.required_model_output_shape(action_space, config)

    @staticmethod
    def _validate_config(
        config: ModelConfigDict, action_space: gym.spaces.Space, framework: str
    ) -> None:
        """Validates a given model config dict.

        Args:
            config: The "model" sub-config dict
                within the Trainer's config dict.
            action_space: The action space of the model, against which the
                config is validated.
            framework: One of "jax", "tf2", "tf", "tfe", or "torch".

        Raises:
            ValueError: If something is wrong with the given config.
        """
        # Soft-deprecate custom preprocessors.
        if config.get("custom_preprocessor") is not None:
            deprecation_warning(
                old="model.custom_preprocessor",
                new="gym.ObservationWrapper around your env or handle complex "
                "inputs inside your Model",
                error=False,
            )

        if config.get("use_attention") and config.get("use_lstm"):
            raise ValueError(
                "Only one of `use_lstm` or `use_attention` may be set to True!"
            )

        # For complex action spaces, only allow prev action inputs to
        # LSTMs and attention nets iff `_disable_action_flattening=True`.
        # TODO: `_disable_action_flattening=True` will be the default in
        # the future.
        if (
            (
                config.get("lstm_use_prev_action")
                or config.get("attention_use_n_prev_actions", 0) > 0
            )
            and not config.get("_disable_action_flattening")
            and isinstance(action_space, (Tuple, Dict))
        ):
            raise ValueError(
                "For a complex action space (Tuple|Dict) together with a "
                "`prev-actions` model setup, you must set "
                "`_disable_action_flattening=True` in your main config dict!"
            )

        if framework == "jax":
            if config.get("use_attention"):
                raise ValueError(
                    "`use_attention` not available for framework=jax so far!"
                )
            elif config.get("use_lstm"):
                raise ValueError("`use_lstm` not available for framework=jax so far!")