from functools import partial
import gym
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import logging
import numpy as np
import tree  # pip install dm_tree
from typing import List, Optional, Type, Union

from ray.tune.registry import (
    RLLIB_MODEL,
    RLLIB_PREPROCESSOR,
    RLLIB_ACTION_DIST,
    _global_registry,
)
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models.tf.tf_action_dist import (
    Categorical,
    Deterministic,
    DiagGaussian,
    Dirichlet,
    MultiActionDistribution,
    MultiCategorical,
)
from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical,
    TorchDeterministic,
    TorchDiagGaussian,
    TorchMultiActionDistribution,
    TorchMultiCategorical,
)
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import (
    Deprecated,
    DEPRECATED_VALUE,
    deprecation_warning,
)
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)

# fmt: off
# __sphinx_doc_begin__
MODEL_DEFAULTS: ModelConfigDict = {
    # Experimental flag.
    # If True, try to use a native (tf.keras.Model or torch.Module) default
    # model instead of our built-in ModelV2 defaults.
    # If False (default), use "classic" ModelV2 default models.
    # Note that this currently only works for:
    # 1) framework != torch AND
    # 2) fully connected and CNN default networks as well as
    # auto-wrapped LSTM- and attention nets.
    "_use_default_native_models": False,
    # Experimental flag.
    # If True (set via config._disable_preprocessor_api=True), no
    # preprocessor will be created and observations will arrive in the model
    # exactly as they are returned by the env.
    "_disable_preprocessor_api": False,
    # Experimental flag.
    # If True, RLlib will no longer flatten the policy-computed actions into
    # a single tensor (for storage in SampleCollectors/output files/etc..),
    # but leave (possibly nested) actions as-is. Disabling flattening affects:
    # - SampleCollectors: Have to store possibly nested action structs.
    # - Models that have the previous action(s) as part of their input.
    # - Algorithms reading from offline files (incl. action information).
    "_disable_action_flattening": False,

    # === Built-in options ===
    # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
    # These are used if no custom model is specified and the input space is 1D.
    # Sizes of the hidden layers to be used.
    "fcnet_hiddens": [256, 256],
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu"),
    # "linear" (or None).
    "fcnet_activation": "tanh",

    # VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
    # These are used if no custom model is specified and the input space is 2D.
    # Filter config: List of [out_channels, kernel, stride] for each filter.
    # Example:
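    # (Illustrative) a filter stack that would fit an 84x84 RGB input:
    # [[16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 1]]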
"conv_activation": "relu", # Some default models support a final FC stack of n Dense layers with given # activation: # - Complex observation spaces: Image components are fed through # VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then # everything is concated and pushed through this final FC stack. # - VisionNets (CNNs), e.g. after the CNN stack, there may be # additional Dense layers. # - FullyConnectedNetworks will have this additional FCStack as well # (that's why it's empty by default). "post_fcnet_hiddens": [], "post_fcnet_activation": "relu", # For DiagGaussian action distributions, make the second half of the model # outputs floating bias variables instead of state-dependent. This only # has an effect is using the default fully connected net. "free_log_std": False, # Whether to skip the final linear layer used to resize the hidden layer # outputs to size `num_outputs`. If True, then the last hidden layer # should already match num_outputs. "no_final_linear": False, # Whether layers should be shared for the value function. "vf_share_layers": True, # == LSTM == # Whether to wrap the model with an LSTM. "use_lstm": False, # Max seq len for training the LSTM, defaults to 20. "max_seq_len": 20, # Size of the LSTM cell. "lstm_cell_size": 256, # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete). "lstm_use_prev_action": False, # Whether to feed r_{t-1} to LSTM. "lstm_use_prev_reward": False, # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..). "_time_major": False, # == Attention Nets (experimental: torch-version is untested) == # Whether to use a GTrXL ("Gru transformer XL"; attention net) as the # wrapper Model around the default Model. "use_attention": False, # The number of transformer units within GTrXL. # A transformer unit in GTrXL consists of a) MultiHeadAttention module and # b) a position-wise MLP. "attention_num_transformer_units": 1, # The input and output size of each transformer unit. "attention_dim": 64, # The number of attention heads within the MultiHeadAttention units. "attention_num_heads": 1, # The dim of a single head (within the MultiHeadAttention units). "attention_head_dim": 32, # The memory sizes for inference and training. "attention_memory_inference": 50, "attention_memory_training": 50, # The output dim of the position-wise MLP. "attention_position_wise_mlp_dim": 32, # The initial bias values for the 2 GRU gates within a transformer unit. "attention_init_gru_gate_bias": 2.0, # Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete). "attention_use_n_prev_actions": 0, # Whether to feed r_{t-n:t-1} to GTrXL. "attention_use_n_prev_rewards": 0, # == Atari == # Set to True to enable 4x stacking behavior. "framestack": True, # Final resized frame dimension "dim": 84, # (deprecated) Converts ATARI frame to 1 Channel Grayscale image "grayscale": False, # (deprecated) Changes frame to range from [-1, 1] if true "zero_mean": True, # === Options for custom models === # Name of a custom model to use "custom_model": None, # Extra options to pass to the custom classes. These will be available to # the Model's constructor in the model_config field. Also, they will be # attempted to be passed as **kwargs to ModelV2 models. For an example, # see rllib/models/[tf|torch]/attention_net.py. "custom_model_config": {}, # Name of a custom action distribution to use. "custom_action_dist": None, # Custom preprocessors are deprecated. Please use a wrapper class around # your environment instead to preprocess observations. 
"custom_preprocessor": None, # Deprecated keys: # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead. "lstm_use_prev_action_reward": DEPRECATED_VALUE, } # __sphinx_doc_end__ # fmt: on @PublicAPI class ModelCatalog: """Registry of models, preprocessors, and action distributions for envs. Examples: >>> prep = ModelCatalog.get_preprocessor(env) >>> observation = prep.transform(raw_observation) >>> dist_class, dist_dim = ModelCatalog.get_action_dist( ... env.action_space, {}) >>> model = ModelCatalog.get_model_v2( ... obs_space, action_space, num_outputs, options) >>> dist = dist_class(model.outputs, model) >>> action = dist.sample() """ @staticmethod @DeveloperAPI def get_action_dist( action_space: gym.Space, config: ModelConfigDict, dist_type: Optional[Union[str, Type[ActionDistribution]]] = None, framework: str = "tf", **kwargs ) -> (type, int): """Returns a distribution class and size for the given action space. Args: action_space: Action space of the target gym env. config (Optional[dict]): Optional model config. dist_type (Optional[Union[str, Type[ActionDistribution]]]): Identifier of the action distribution (str) interpreted as a hint or the actual ActionDistribution class to use. framework: One of "tf2", "tf", "tfe", "torch", or "jax". kwargs: Optional kwargs to pass on to the Distribution's constructor. Returns: Tuple: - dist_class (ActionDistribution): Python class of the distribution. - dist_dim (int): The size of the input vector to the distribution. """ dist_cls = None config = config or MODEL_DEFAULTS # Custom distribution given. if config.get("custom_action_dist"): custom_action_config = config.copy() action_dist_name = custom_action_config.pop("custom_action_dist") logger.debug("Using custom action distribution {}".format(action_dist_name)) dist_cls = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name) return ModelCatalog._get_multi_action_distribution( dist_cls, action_space, custom_action_config, framework ) # Dist_type is given directly as a class. elif ( type(dist_type) is type and issubclass(dist_type, ActionDistribution) and dist_type not in (MultiActionDistribution, TorchMultiActionDistribution) ): dist_cls = dist_type # Box space -> DiagGaussian OR Deterministic. elif isinstance(action_space, Box): if action_space.dtype.name.startswith("int"): low_ = np.min(action_space.low) high_ = np.max(action_space.high) dist_cls = ( TorchMultiCategorical if framework == "torch" else MultiCategorical ) num_cats = int(np.product(action_space.shape)) return ( partial( dist_cls, input_lens=[high_ - low_ + 1 for _ in range(num_cats)], action_space=action_space, ), num_cats * (high_ - low_ + 1), ) else: if len(action_space.shape) > 1: raise UnsupportedSpaceException( "Action space has multiple dimensions " "{}. ".format(action_space.shape) + "Consider reshaping this into a single dimension, " "using a custom action distribution, " "using a Tuple action space, or the multi-agent API." ) # TODO(sven): Check for bounds and return SquashedNormal, etc.. if dist_type is None: return ( partial( TorchDiagGaussian if framework == "torch" else DiagGaussian, action_space=action_space, ), DiagGaussian.required_model_output_shape(action_space, config), ) elif dist_type == "deterministic": dist_cls = ( TorchDeterministic if framework == "torch" else Deterministic ) # Discrete Space -> Categorical. 
        # Discrete Space -> Categorical.
        elif isinstance(action_space, Discrete):
            if framework == "torch":
                dist_cls = TorchCategorical
            elif framework == "jax":
                from ray.rllib.models.jax.jax_action_dist import JAXCategorical

                dist_cls = JAXCategorical
            else:
                dist_cls = Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif (
            dist_type
            in (
                MultiActionDistribution,
                TorchMultiActionDistribution,
            )
            or isinstance(action_space, (Tuple, Dict))
        ):
            return ModelCatalog._get_multi_action_distribution(
                (
                    MultiActionDistribution
                    if framework == "tf"
                    else TorchMultiActionDistribution
                ),
                action_space,
                config,
                framework,
            )
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch."
                )
            dist_cls = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, MultiDiscrete):
            dist_cls = (
                TorchMultiCategorical if framework == "torch" else MultiCategorical
            )
            return partial(dist_cls, input_lens=action_space.nvec), int(
                sum(action_space.nvec)
            )
        # Unknown type -> Error.
        else:
            raise NotImplementedError(
                "Unsupported args: {} {}".format(action_space, dist_type)
            )

        return dist_cls, dist_cls.required_model_output_shape(action_space, config)

    @staticmethod
    @DeveloperAPI
    def get_action_shape(
        action_space: gym.Space, framework: str = "tf"
    ) -> (np.dtype, List[int]):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space: Action space of the target gym env.
            framework: The framework identifier. One of "tf" or "torch".

        Returns:
            (dtype, shape): Dtype and shape of the actions tensor.
        """
        dl_lib = torch if framework == "torch" else tf
        if isinstance(action_space, Discrete):
            return action_space.dtype, (None,)
        elif isinstance(action_space, (Box, Simplex)):
            if np.issubdtype(action_space.dtype, np.floating):
                return dl_lib.float32, (None,) + action_space.shape
            elif np.issubdtype(action_space.dtype, np.integer):
                return dl_lib.int32, (None,) + action_space.shape
            else:
                raise ValueError(
                    "RLlib doesn't support Box spaces that are neither int "
                    "nor float!"
                )
        elif isinstance(action_space, MultiDiscrete):
            return action_space.dtype, (None,) + action_space.shape
        elif isinstance(action_space, (Tuple, Dict)):
            flat_action_space = flatten_space(action_space)
            size = 0
            all_discrete = True
            for i in range(len(flat_action_space)):
                if isinstance(flat_action_space[i], Discrete):
                    size += 1
                else:
                    all_discrete = False
                    size += np.product(flat_action_space[i].shape)
            size = int(size)
            return dl_lib.int32 if all_discrete else dl_lib.float32, (None, size)
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space)
            )

    @staticmethod
    @DeveloperAPI
    def get_action_placeholder(
        action_space: gym.Space, name: str = "action"
    ) -> TensorType:
        """Returns an action placeholder consistent with the action space.

        Args:
            action_space: Action space of the target gym env.
            name: An optional string to name the placeholder by.
                Default: "action".

        Returns:
            action_placeholder: A placeholder for the actions.
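
        Example (illustrative; requires tf1/static-graph mode, since a
        placeholder is created):
            >>> ph = ModelCatalog.get_action_placeholder(Discrete(4))
            >>> ph.shape.as_list()
            [None]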
        """
        dtype, shape = ModelCatalog.get_action_shape(action_space, framework="tf")

        return tf1.placeholder(dtype, shape=shape, name=name)

    @staticmethod
    @DeveloperAPI
    def get_model_v2(
        obs_space: gym.Space,
        action_space: gym.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        framework: str = "tf",
        name: str = "default_model",
        model_interface: type = None,
        default_model: type = None,
        **model_kwargs
    ) -> ModelV2:
        """Returns a suitable model compatible with given spaces and output.

        Args:
            obs_space: Observation space of the target gym env. This may have
                an `original_space` attribute that specifies how to unflatten
                the tensor into a ragged tensor.
            action_space: Action space of the target gym env.
            num_outputs: The size of the output vector of the model.
            model_config: The "model" sub-config dict within the Trainer's
                config dict.
            framework: One of "tf2", "tf", "tfe", "torch", or "jax".
            name: Name (scope) for the model.
            model_interface: Interface required for the model.
            default_model: Override the default class for the model. This
                only has an effect when not using a custom model.
            model_kwargs: Args to pass to the ModelV2 constructor.

        Returns:
            model (ModelV2): Model to use for the policy.
        """
        # Validate the given config dict.
        ModelCatalog._validate_config(
            config=model_config, action_space=action_space, framework=framework
        )

        if model_config.get("custom_model"):
            # Allow model kwargs to be overridden / augmented by
            # custom_model_config.
            customized_model_kwargs = dict(
                model_kwargs, **model_config.get("custom_model_config", {})
            )

            if isinstance(model_config["custom_model"], type):
                model_cls = model_config["custom_model"]
            elif (
                isinstance(model_config["custom_model"], str)
                and "." in model_config["custom_model"]
            ):
                return from_config(
                    cls=model_config["custom_model"],
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    model_config=customized_model_kwargs,
                    name=name,
                )
            else:
                model_cls = _global_registry.get(
                    RLLIB_MODEL, model_config["custom_model"]
                )

            # Only allow ModelV2 or native keras Models.
            if not issubclass(model_cls, ModelV2):
                if framework not in ["tf", "tf2", "tfe"] or not issubclass(
                    model_cls, tf.keras.Model
                ):
                    raise ValueError(
                        "`model_cls` must be a ModelV2 sub-class, but is"
                        " {}!".format(model_cls)
                    )

            logger.info("Wrapping {} as {}".format(model_cls, model_interface))
            model_cls = ModelCatalog._wrap_if_needed(model_cls, model_interface)

            if framework in ["tf2", "tf", "tfe"]:
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or model_config.get("use_attention"):
                    from ray.rllib.models.tf.attention_net import (
                        AttentionWrapper,
                        Keras_AttentionWrapper,
                    )
                    from ray.rllib.models.tf.recurrent_net import (
                        LSTMWrapper,
                        Keras_LSTMWrapper,
                    )

                    wrapped_cls = model_cls
                    # Wrapped (custom) model is itself a keras Model ->
                    # wrap with keras LSTM/GTrXL (attention) wrappers.
                    if issubclass(wrapped_cls, tf.keras.Model):
                        model_cls = (
                            Keras_LSTMWrapper
                            if model_config.get("use_lstm")
                            else Keras_AttentionWrapper
                        )
                        model_config["wrapped_cls"] = wrapped_cls
                    # Wrapped (custom) model is ModelV2 ->
                    # wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
                    else:
                        forward = wrapped_cls.forward
                        model_cls = ModelCatalog._wrap_if_needed(
                            wrapped_cls,
                            LSTMWrapper
                            if model_config.get("use_lstm")
                            else AttentionWrapper,
                        )
                        model_cls._wrapped_forward = forward

                # Obsolete: Track and warn if variables were created but not
                # registered. Only do this if users still register their
                # variables; if they don't (which they shouldn't need to
                # anymore), skip this check.
                created = set()

                def track_var_creation(next_creator, **kw):
                    v = next_creator(**kw)
                    created.add(v.ref())
                    return v

                with tf.variable_creator_scope(track_var_creation):
                    if issubclass(model_cls, tf.keras.Model):
                        instance = model_cls(
                            input_space=obs_space,
                            action_space=action_space,
                            num_outputs=num_outputs,
                            name=name,
                            **customized_model_kwargs,
                        )
                    else:
                        # Try calling with kwargs first (custom ModelV2 should
                        # accept these as kwargs, not get them from
                        # config["custom_model_config"] anymore).
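                        # (Illustrative) a custom ModelV2 constructor matching
                        # this call would look like:
                        #     def __init__(self, obs_space, action_space,
                        #                  num_outputs, model_config, name,
                        #                  **custom_kwargs): ...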
                        try:
                            instance = model_cls(
                                obs_space,
                                action_space,
                                num_outputs,
                                model_config,
                                name,
                                **customized_model_kwargs,
                            )
                        except TypeError as e:
                            # Keyword error: Try old way w/o kwargs.
                            if "__init__() got an unexpected " in e.args[0]:
                                instance = model_cls(
                                    obs_space,
                                    action_space,
                                    num_outputs,
                                    model_config,
                                    name,
                                    **model_kwargs,
                                )
                                logger.warning(
                                    "Custom ModelV2 should accept all custom "
                                    "options as **kwargs, instead of expecting"
                                    " them in config['custom_model_config']!"
                                )
                            # Other error -> re-raise.
                            else:
                                raise e

                # User still registered TFModelV2's variables: Check whether
                # all created variables were properly registered.
                registered = []
                if not isinstance(instance, tf.keras.Model):
                    registered = set(instance.var_list)
                if len(registered) > 0:
                    not_registered = set()
                    for var in created:
                        if var not in registered:
                            not_registered.add(var)
                    if not_registered:
                        raise ValueError(
                            "It looks like you are still using "
                            "`{}.register_variables()` to register your "
                            "model's weights. This is no longer required, but "
                            "if you are still calling this method at least "
                            "once, you must make sure to register all created "
                            "variables properly. The missing variables are {},"
                            " and you only registered {}. "
                            "Did you forget to call `register_variables()` on "
                            "some of the variables in question?".format(
                                instance, not_registered, registered
                            )
                        )
            elif framework == "torch":
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or model_config.get("use_attention"):
                    from ray.rllib.models.torch.attention_net import AttentionWrapper
                    from ray.rllib.models.torch.recurrent_net import LSTMWrapper

                    wrapped_cls = model_cls
                    forward = wrapped_cls.forward
                    model_cls = ModelCatalog._wrap_if_needed(
                        wrapped_cls,
                        LSTMWrapper
                        if model_config.get("use_lstm")
                        else AttentionWrapper,
                    )
                    model_cls._wrapped_forward = forward

                # PyTorch automatically tracks nn.Modules inside the parent
                # nn.Module's constructor.
                # Try calling with kwargs first (custom ModelV2 should
                # accept these as kwargs, not get them from
                # config["custom_model_config"] anymore).
                try:
                    instance = model_cls(
                        obs_space,
                        action_space,
                        num_outputs,
                        model_config,
                        name,
                        **customized_model_kwargs,
                    )
                except TypeError as e:
                    # Keyword error: Try old way w/o kwargs.
                    if "__init__() got an unexpected " in e.args[0]:
                        instance = model_cls(
                            obs_space,
                            action_space,
                            num_outputs,
                            model_config,
                            name,
                            **model_kwargs,
                        )
                        logger.warning(
                            "Custom ModelV2 should accept all custom "
                            "options as **kwargs, instead of expecting"
                            " them in config['custom_model_config']!"
                        )
                    # Other error -> re-raise.
                    else:
                        raise e
            else:
                raise NotImplementedError(
                    "`framework` must be 'tf2|tf|tfe|torch', but is "
                    "{}!".format(framework)
                )

            return instance

        # Find a default TFModelV2 and wrap with model_interface.
        if framework in ["tf", "tfe", "tf2"]:
            v2_class = None
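            # (Illustrative) which default gets picked depends on the
            # observation space: a Box(0.0, 1.0, shape=(84, 84, 3)) resolves
            # to VisionNetwork, a Box(-1.0, 1.0, shape=(4,)) to
            # FullyConnectedNetwork, and a Dict/Tuple space with image
            # (>=2D Box) components to ComplexInputNetwork.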
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework
                )

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or model_config.get("use_attention"):
                from ray.rllib.models.tf.attention_net import (
                    AttentionWrapper,
                    Keras_AttentionWrapper,
                )
                from ray.rllib.models.tf.recurrent_net import (
                    LSTMWrapper,
                    Keras_LSTMWrapper,
                )

                wrapped_cls = v2_class
                if model_config.get("use_lstm"):
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_LSTMWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, LSTMWrapper
                        )
                        v2_class._wrapped_forward = wrapped_cls.forward
                else:
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_AttentionWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, AttentionWrapper
                        )
                        v2_class._wrapped_forward = wrapped_cls.forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

            if issubclass(wrapper, tf.keras.Model):
                model = wrapper(
                    input_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    name=name,
                    **dict(model_kwargs, **model_config),
                )
                return model

            return wrapper(
                obs_space, action_space, num_outputs, model_config, name,
                **model_kwargs
            )
        # Find a default TorchModelV2 and wrap with model_interface.
        elif framework == "torch":
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework
                )

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or model_config.get("use_attention"):
                from ray.rllib.models.torch.attention_net import AttentionWrapper
                from ray.rllib.models.torch.recurrent_net import LSTMWrapper

                wrapped_cls = v2_class
                forward = wrapped_cls.forward
                if model_config.get("use_lstm"):
                    v2_class = ModelCatalog._wrap_if_needed(wrapped_cls, LSTMWrapper)
                else:
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, AttentionWrapper
                    )
                v2_class._wrapped_forward = forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(
                obs_space, action_space, num_outputs, model_config, name,
                **model_kwargs
            )
        # Find a default JAXModelV2 and wrap with model_interface.
        elif framework == "jax":
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework
            )
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(
                obs_space, action_space, num_outputs, model_config, name,
                **model_kwargs
            )
        else:
            raise NotImplementedError(
                "`framework` must be 'tf2|tf|tfe|torch', but is "
                "{}!".format(framework)
            )

    @staticmethod
    @DeveloperAPI
    def get_preprocessor(env: gym.Env, options: Optional[dict] = None) -> Preprocessor:
        """Returns a suitable preprocessor for the given env.

        This is a wrapper for get_preprocessor_for_space().
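
        Example (illustrative; CartPole's observation space is a flat
        Box(4,), for which a NoPreprocessor is returned):
            >>> env = gym.make("CartPole-v0")
            >>> prep = ModelCatalog.get_preprocessor(env)
            >>> prep.transform(env.reset()).shape
            (4,)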
        """
        return ModelCatalog.get_preprocessor_for_space(env.observation_space, options)

    @staticmethod
    @DeveloperAPI
    def get_preprocessor_for_space(
        observation_space: gym.Space, options: dict = None
    ) -> Preprocessor:
        """Returns a suitable preprocessor for the given observation space.

        Args:
            observation_space: The input observation space.
            options: Options to pass to the preprocessor.

        Returns:
            preprocessor: Preprocessor for the observations.
        """
        options = options or MODEL_DEFAULTS
        for k in options.keys():
            if k not in MODEL_DEFAULTS:
                raise Exception(
                    "Unknown config key `{}`, all keys: {}".format(
                        k, list(MODEL_DEFAULTS)
                    )
                )

        if options.get("custom_preprocessor"):
            preprocessor = options["custom_preprocessor"]
            logger.info("Using custom preprocessor {}".format(preprocessor))
            logger.warning(
                "DeprecationWarning: Custom preprocessors are deprecated, "
                "since they sometimes conflict with the built-in "
                "preprocessors for handling complex observation spaces. "
                "Please use wrapper classes around your environment "
                "instead of preprocessors."
            )
            prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
                observation_space, options
            )
        else:
            cls = get_preprocessor(observation_space)
            prep = cls(observation_space, options)

        if prep is not None:
            logger.debug(
                "Created preprocessor {}: {} -> {}".format(
                    prep, observation_space, prep.shape
                )
            )
        return prep

    @staticmethod
    @Deprecated(error=False)
    def register_custom_preprocessor(
        preprocessor_name: str, preprocessor_class: type
    ) -> None:
        """Register a custom preprocessor class by name.

        The preprocessor can later be used by specifying
        {"custom_preprocessor": preprocessor_name} in the model config.

        Args:
            preprocessor_name: Name to register the preprocessor under.
            preprocessor_class: Python class of the preprocessor.
        """
        _global_registry.register(
            RLLIB_PREPROCESSOR, preprocessor_name, preprocessor_class
        )

    @staticmethod
    @PublicAPI
    def register_custom_model(model_name: str, model_class: type) -> None:
        """Register a custom model class by name.

        The model can later be used by specifying
        {"custom_model": model_name} in the model config.

        Args:
            model_name: Name to register the model under.
            model_class: Python class of the model.
        """
        if tf is not None:
            if issubclass(model_class, tf.keras.Model):
                deprecation_warning(old="register_custom_model", error=False)
        _global_registry.register(RLLIB_MODEL, model_name, model_class)

    @staticmethod
    @PublicAPI
    def register_custom_action_dist(
        action_dist_name: str, action_dist_class: type
    ) -> None:
        """Register a custom action distribution class by name.

        The action distribution can later be used by specifying
        {"custom_action_dist": action_dist_name} in the model config.

        Args:
            action_dist_name: Name to register the action distribution under.
            action_dist_class: Python class of the action distribution.
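
        Example (illustrative; `MyActionDist` stands in for a user-defined
        ActionDistribution subclass):
            >>> ModelCatalog.register_custom_action_dist(
            ...     "my_dist", MyActionDist)

        The distribution can then be selected via
        {"custom_action_dist": "my_dist"} in the model config.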
""" _global_registry.register( RLLIB_ACTION_DIST, action_dist_name, action_dist_class ) @staticmethod def _wrap_if_needed(model_cls: type, model_interface: type) -> type: if not model_interface or issubclass(model_cls, model_interface): return model_cls assert issubclass(model_cls, ModelV2), model_cls class wrapper(model_interface, model_cls): pass name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__) wrapper.__name__ = name wrapper.__qualname__ = name return wrapper @staticmethod def _get_v2_model_class( input_space: gym.Space, model_config: ModelConfigDict, framework: str = "tf" ) -> Type[ModelV2]: VisionNet = None ComplexNet = None Keras_FCNet = None Keras_VisionNet = None if framework in ["tf2", "tf", "tfe"]: from ray.rllib.models.tf.fcnet import ( FullyConnectedNetwork as FCNet, Keras_FullyConnectedNetwork as Keras_FCNet, ) from ray.rllib.models.tf.visionnet import ( VisionNetwork as VisionNet, Keras_VisionNetwork as Keras_VisionNet, ) from ray.rllib.models.tf.complex_input_net import ( ComplexInputNetwork as ComplexNet, ) elif framework == "torch": from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as FCNet from ray.rllib.models.torch.visionnet import VisionNetwork as VisionNet from ray.rllib.models.torch.complex_input_net import ( ComplexInputNetwork as ComplexNet, ) elif framework == "jax": from ray.rllib.models.jax.fcnet import FullyConnectedNetwork as FCNet else: raise ValueError( "framework={} not supported in `ModelCatalog._get_v2_model_" "class`!".format(framework) ) orig_space = ( input_space if not hasattr(input_space, "original_space") else input_space.original_space ) # `input_space` is 3D Box -> VisionNet. if isinstance(input_space, Box) and len(input_space.shape) == 3: if framework == "jax": raise NotImplementedError("No non-FC default net for JAX yet!") elif model_config.get("_use_default_native_models") and Keras_VisionNet: return Keras_VisionNet return VisionNet # `input_space` is 1D Box -> FCNet. elif ( isinstance(input_space, Box) and len(input_space.shape) == 1 and ( not isinstance(orig_space, (Dict, Tuple)) or not any( isinstance(s, Box) and len(s.shape) >= 2 for s in flatten_space(orig_space) ) ) ): # Keras native requested AND no auto-rnn-wrapping. if model_config.get("_use_default_native_models") and Keras_FCNet: return Keras_FCNet # Classic ModelV2 FCNet. else: return FCNet # Complex (Dict, Tuple, 2D Box (flatten), Discrete, MultiDiscrete). else: if framework == "jax": raise NotImplementedError("No non-FC default net for JAX yet!") return ComplexNet @staticmethod def _get_multi_action_distribution(dist_class, action_space, config, framework): # In case the custom distribution is a child of MultiActionDistr. # If users want to completely ignore the suggested child # distributions, they should simply do so in their custom class' # constructor. 
        if issubclass(
            dist_class, (MultiActionDistribution, TorchMultiActionDistribution)
        ):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(s, config, framework=framework),
                flat_action_space,
            )
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return (
                partial(
                    dist_class,
                    action_space=action_space,
                    child_distributions=child_dists,
                    input_lens=input_lens,
                ),
                int(sum(input_lens)),
            )

        return dist_class, dist_class.required_model_output_shape(action_space, config)

    @staticmethod
    def _validate_config(
        config: ModelConfigDict, action_space: gym.spaces.Space, framework: str
    ) -> None:
        """Validates a given model config dict.

        Args:
            config: The "model" sub-config dict within the Trainer's config
                dict.
            action_space: The action space of the model, whose config is
                being validated.
            framework: One of "jax", "tf2", "tf", "tfe", or "torch".

        Raises:
            ValueError: If something is wrong with the given config.
        """
        # Soft-deprecate custom preprocessors.
        if config.get("custom_preprocessor") is not None:
            deprecation_warning(
                old="model.custom_preprocessor",
                new="gym.ObservationWrapper around your env or handle complex "
                "inputs inside your Model",
                error=False,
            )

        if config.get("use_attention") and config.get("use_lstm"):
            raise ValueError(
                "Only one of `use_lstm` or `use_attention` may be set to True!"
            )

        # For complex action spaces, only allow prev action inputs to
        # LSTMs and attention nets iff `_disable_action_flattening=True`.
        # TODO: `_disable_action_flattening=True` will be the default in
        #  the future.
        if (
            (
                config.get("lstm_use_prev_action")
                or config.get("attention_use_n_prev_actions", 0) > 0
            )
            and not config.get("_disable_action_flattening")
            and isinstance(action_space, (Tuple, Dict))
        ):
            raise ValueError(
                "For your complex action space (Tuple|Dict) and your model's "
                "`prev-actions` setup, you must set "
                "`_disable_action_flattening=True` in your main config dict!"
            )

        if framework == "jax":
            if config.get("use_attention"):
                raise ValueError(
                    "`use_attention` not available for framework=jax so far!"
                )
            elif config.get("use_lstm"):
                raise ValueError(
                    "`use_lstm` not available for framework=jax so far!"
                )