[rllib] Type annotations for model classes (#9646)

This commit is contained in:
Eric Liang 2020-07-24 12:01:46 -07:00 committed by GitHub
parent 03709d67cb
commit 590943a499
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 153 additions and 120 deletions

View file

@ -44,7 +44,7 @@ MAX_WORKER_FAILURE_RETRIES = 3
# yapf: disable # yapf: disable
# __sphinx_doc_begin__ # __sphinx_doc_begin__
COMMON_CONFIG = { COMMON_CONFIG: TrainerConfigDict = {
# === Settings for Rollout Worker processes === # === Settings for Rollout Worker processes ===
# Number of rollout worker actors to create for parallel sampling. Setting # Number of rollout worker actors to create for parallel sampling. Setting
# this to 0 will force rollouts to be done in the trainer actor. # this to 0 will force rollouts to be done in the trainer actor.

View file

@ -1,4 +1,9 @@
import numpy as np
import gym
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.types import TensorType, List, Union, ModelConfigDict
@DeveloperAPI @DeveloperAPI
@ -11,7 +16,7 @@ class ActionDistribution:
""" """
@DeveloperAPI @DeveloperAPI
def __init__(self, inputs, model): def __init__(self, inputs: List[TensorType], model: ModelV2):
"""Initialize the action dist. """Initialize the action dist.
Arguments: Arguments:
@ -25,12 +30,12 @@ class ActionDistribution:
self.model = model self.model = model
@DeveloperAPI @DeveloperAPI
def sample(self): def sample(self) -> TensorType:
"""Draw a sample from the action distribution.""" """Draw a sample from the action distribution."""
raise NotImplementedError raise NotImplementedError
@DeveloperAPI @DeveloperAPI
def deterministic_sample(self): def deterministic_sample(self) -> TensorType:
""" """
Get the deterministic "sampling" output from the distribution. Get the deterministic "sampling" output from the distribution.
This is usually the max likelihood output, i.e. mean for Normal, argmax This is usually the max likelihood output, i.e. mean for Normal, argmax
@ -39,26 +44,26 @@ class ActionDistribution:
raise NotImplementedError raise NotImplementedError
@DeveloperAPI @DeveloperAPI
def sampled_action_logp(self): def sampled_action_logp(self) -> TensorType:
"""Returns the log probability of the last sampled action.""" """Returns the log probability of the last sampled action."""
raise NotImplementedError raise NotImplementedError
@DeveloperAPI @DeveloperAPI
def logp(self, x): def logp(self, x: TensorType) -> TensorType:
"""The log-likelihood of the action distribution.""" """The log-likelihood of the action distribution."""
raise NotImplementedError raise NotImplementedError
@DeveloperAPI @DeveloperAPI
def kl(self, other): def kl(self, other: "ActionDistribution") -> TensorType:
"""The KL-divergence between two action distributions.""" """The KL-divergence between two action distributions."""
raise NotImplementedError raise NotImplementedError
@DeveloperAPI @DeveloperAPI
def entropy(self): def entropy(self) -> TensorType:
"""The entropy of the action distribution.""" """The entropy of the action distribution."""
raise NotImplementedError raise NotImplementedError
def multi_kl(self, other): def multi_kl(self, other: "ActionDistribution") -> TensorType:
"""The KL-divergence between two action distributions. """The KL-divergence between two action distributions.
This differs from kl() in that it can return an array for This differs from kl() in that it can return an array for
@ -66,7 +71,7 @@ class ActionDistribution:
""" """
return self.kl(other) return self.kl(other)
def multi_entropy(self): def multi_entropy(self) -> TensorType:
"""The entropy of the action distribution. """The entropy of the action distribution.
This differs from entropy() in that it can return an array for This differs from entropy() in that it can return an array for
@ -76,7 +81,9 @@ class ActionDistribution:
@DeveloperAPI @DeveloperAPI
@staticmethod @staticmethod
def required_model_output_shape(action_space, model_config): def required_model_output_shape(
action_space: gym.Space,
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
"""Returns the required shape of an input parameter tensor for a """Returns the required shape of an input parameter tensor for a
particular action space and an optional dict of distribution-specific particular action space and an optional dict of distribution-specific
options. options.

View file

@ -3,12 +3,13 @@ import gym
import logging import logging
import numpy as np import numpy as np
import tree import tree
from typing import List
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
RLLIB_ACTION_DIST, _global_registry RLLIB_ACTION_DIST, _global_registry
from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.tf.lstm_v1 import LSTM from ray.rllib.models.tf.lstm_v1 import LSTM
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
@ -26,6 +27,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.types import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# yapf: disable # yapf: disable
# __sphinx_doc_begin__ # __sphinx_doc_begin__
MODEL_DEFAULTS = { MODEL_DEFAULTS: ModelConfigDict = {
# === Built-in options === # === Built-in options ===
# Filter config. List of [out_channels, kernel, stride] for each filter # Filter config. List of [out_channels, kernel, stride] for each filter
"conv_filters": None, "conv_filters": None,
@ -114,11 +116,11 @@ class ModelCatalog:
@staticmethod @staticmethod
@DeveloperAPI @DeveloperAPI
def get_action_dist(action_space, def get_action_dist(action_space: gym.Space,
config, config: ModelConfigDict,
dist_type=None, dist_type: str = None,
framework="tf", framework: str = "tf",
**kwargs): **kwargs) -> (type, int):
"""Returns a distribution class and size for the given action space. """Returns a distribution class and size for the given action space.
Args: Args:
@ -209,7 +211,7 @@ class ModelCatalog:
@staticmethod @staticmethod
@DeveloperAPI @DeveloperAPI
def get_action_shape(action_space): def get_action_shape(action_space: gym.Space) -> (np.dtype, List[int]):
"""Returns action tensor dtype and shape for the action space. """Returns action tensor dtype and shape for the action space.
Args: Args:
@ -243,7 +245,8 @@ class ModelCatalog:
@staticmethod @staticmethod
@DeveloperAPI @DeveloperAPI
def get_action_placeholder(action_space, name="action"): def get_action_placeholder(action_space: gym.Space,
name: str = "action") -> TensorType:
"""Returns an action placeholder consistent with the action space """Returns an action placeholder consistent with the action space
Args: Args:
@ -260,15 +263,15 @@ class ModelCatalog:
@staticmethod @staticmethod
@DeveloperAPI @DeveloperAPI
def get_model_v2(obs_space, def get_model_v2(obs_space: gym.Space,
action_space, action_space: gym.Space,
num_outputs, num_outputs: int,
model_config, model_config: ModelConfigDict,
framework="tf", framework: str = "tf",
name="default_model", name: str = "default_model",
model_interface=None, model_interface: type = None,
default_model=None, default_model: type = None,
**model_kwargs): **model_kwargs) -> ModelV2:
"""Returns a suitable model compatible with given spaces and output. """Returns a suitable model compatible with given spaces and output.
Args: Args:
@ -420,7 +423,7 @@ class ModelCatalog:
@staticmethod @staticmethod
@DeveloperAPI @DeveloperAPI
def get_preprocessor(env, options=None): def get_preprocessor(env: gym.Env, options: dict = None) -> Preprocessor:
"""Returns a suitable preprocessor for the given env. """Returns a suitable preprocessor for the given env.
This is a wrapper for get_preprocessor_for_space(). This is a wrapper for get_preprocessor_for_space().
@ -431,7 +434,8 @@ class ModelCatalog:
@staticmethod @staticmethod
@DeveloperAPI @DeveloperAPI
def get_preprocessor_for_space(observation_space, options=None): def get_preprocessor_for_space(observation_space: gym.Space,
options: dict = None) -> Preprocessor:
"""Returns a suitable preprocessor for the given observation space. """Returns a suitable preprocessor for the given observation space.
Args: Args:
@ -469,7 +473,8 @@ class ModelCatalog:
@staticmethod @staticmethod
@PublicAPI @PublicAPI
def register_custom_preprocessor(preprocessor_name, preprocessor_class): def register_custom_preprocessor(preprocessor_name: str,
preprocessor_class: type) -> None:
"""Register a custom preprocessor class by name. """Register a custom preprocessor class by name.
The preprocessor can be later used by specifying The preprocessor can be later used by specifying
@ -484,7 +489,7 @@ class ModelCatalog:
@staticmethod @staticmethod
@PublicAPI @PublicAPI
def register_custom_model(model_name, model_class): def register_custom_model(model_name: str, model_class: type) -> None:
"""Register a custom model class by name. """Register a custom model class by name.
The model can be later used by specifying {"custom_model": model_name} The model can be later used by specifying {"custom_model": model_name}
@ -498,7 +503,8 @@ class ModelCatalog:
@staticmethod @staticmethod
@PublicAPI @PublicAPI
def register_custom_action_dist(action_dist_name, action_dist_class): def register_custom_action_dist(action_dist_name: str,
action_dist_class: type) -> None:
"""Register a custom action distribution class by name. """Register a custom action distribution class by name.
The model can be later used by specifying The model can be later used by specifying
@ -512,7 +518,7 @@ class ModelCatalog:
action_dist_class) action_dist_class)
@staticmethod @staticmethod
def _wrap_if_needed(model_cls, model_interface): def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
assert issubclass(model_cls, ModelV2), model_cls assert issubclass(model_cls, ModelV2), model_cls
if not model_interface or issubclass(model_cls, model_interface): if not model_interface or issubclass(model_cls, model_interface):
@ -608,10 +614,3 @@ class ModelCatalog:
return FullyConnectedNetwork(input_dict, obs_space, action_space, return FullyConnectedNetwork(input_dict, obs_space, action_space,
num_outputs, options) num_outputs, options)
@staticmethod
def get_torch_model(obs_space,
num_outputs,
options=None,
default_model_cls=None):
raise DeprecationWarning("Please use get_model_v2() instead.")

View file

@ -1,6 +1,8 @@
from collections import OrderedDict from collections import OrderedDict
import contextlib
import gym import gym
from typing import Dict import numpy as np
from typing import Dict, List, Any, Union
from ray.rllib.models.preprocessors import get_preprocessor, \ from ray.rllib.models.preprocessors import get_preprocessor, \
RepeatedValuesPreprocessor RepeatedValuesPreprocessor
@ -11,7 +13,7 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
TensorType TensorType
from ray.rllib.utils.spaces.repeated import Repeated from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.utils.types import ModelConfigDict from ray.rllib.utils.types import ModelConfigDict, TensorStructType
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch() torch, _ = try_import_torch()
@ -29,13 +31,9 @@ class ModelV2:
value_function() -> V(s) value_function() -> V(s)
""" """
def __init__(self, def __init__(self, obs_space: gym.spaces.Space,
obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int,
action_space: gym.spaces.Space, model_config: ModelConfigDict, name: str, framework: str):
num_outputs: int,
model_config: ModelConfigDict,
name: str,
framework: str):
"""Initializes a ModelV2 object. """Initializes a ModelV2 object.
This method should create any variables used by the model. This method should create any variables used by the model.
@ -62,7 +60,7 @@ class ModelV2:
self._last_output = None self._last_output = None
@PublicAPI @PublicAPI
def get_initial_state(self): def get_initial_state(self) -> List[np.ndarray]:
"""Get the initial recurrent state values for the model. """Get the initial recurrent state values for the model.
Returns: Returns:
@ -79,7 +77,9 @@ class ModelV2:
return [] return []
@PublicAPI @PublicAPI
def forward(self, input_dict, state, seq_lens): def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
"""Call the model with the given input tensors and state. """Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by Any complex observations (dicts, tuples, etc.) will be unpacked by
@ -103,7 +103,7 @@ class ModelV2:
Returns: Returns:
(outputs, state): The model output tensor of size (outputs, state): The model output tensor of size
[BATCH, num_outputs] [BATCH, num_outputs], and the new RNN state.
Examples: Examples:
>>> def forward(self, input_dict, state, seq_lens): >>> def forward(self, input_dict, state, seq_lens):
@ -114,7 +114,7 @@ class ModelV2:
raise NotImplementedError raise NotImplementedError
@PublicAPI @PublicAPI
def value_function(self): def value_function(self) -> TensorType:
"""Returns the value function output for the most recent forward pass. """Returns the value function output for the most recent forward pass.
Note that a `forward` call has to be performed first, before this Note that a `forward` call has to be performed first, before this
@ -127,7 +127,8 @@ class ModelV2:
raise NotImplementedError raise NotImplementedError
@PublicAPI @PublicAPI
def custom_loss(self, policy_loss, loss_inputs): def custom_loss(self, policy_loss: TensorType,
loss_inputs: Dict[str, TensorType]) -> TensorType:
"""Override to customize the loss function used to optimize this model. """Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining This can be used to incorporate self-supervised losses (by defining
@ -149,7 +150,7 @@ class ModelV2:
return policy_loss return policy_loss
@PublicAPI @PublicAPI
def metrics(self): def metrics(self) -> Dict[str, TensorType]:
"""Override to return custom metrics from your model. """Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e., The stats will be reported as part of the learner stats, i.e.,
@ -164,7 +165,11 @@ class ModelV2:
""" """
return {} return {}
def __call__(self, input_dict, state=None, seq_lens=None): def __call__(
self,
input_dict: Dict[str, TensorType],
state: List[Any] = None,
seq_lens: TensorType = None) -> (TensorType, List[TensorType]):
"""Call the model with the given input tensors and state. """Call the model with the given input tensors and state.
This is the method used by RLlib to execute the forward pass. It calls This is the method used by RLlib to execute the forward pass. It calls
@ -218,7 +223,8 @@ class ModelV2:
return outputs, state return outputs, state
@PublicAPI @PublicAPI
def from_batch(self, train_batch, is_training=True): def from_batch(self, train_batch: SampleBatch,
is_training: bool = True) -> (TensorType, List[TensorType]):
"""Convenience function that calls this model with a tensor batch. """Convenience function that calls this model with a tensor batch.
All this does is unpack the tensor batch to call this model with the All this does is unpack the tensor batch to call this model with the
@ -241,8 +247,7 @@ class ModelV2:
return self.__call__(input_dict, states, train_batch.get("seq_lens")) return self.__call__(input_dict, states, train_batch.get("seq_lens"))
def get_view_requirements( def get_view_requirements(
self, self, is_training: bool = False) -> Dict[str, ViewRequirement]:
is_training: bool = False) -> Dict[str, ViewRequirement]:
"""Returns a list of ViewRequirements for this Model (or None). """Returns a list of ViewRequirements for this Model (or None).
Note: This is an experimental API method. Note: This is an experimental API method.
@ -266,13 +271,13 @@ class ModelV2:
# Single requirement: Pass current obs as input. # Single requirement: Pass current obs as input.
return { return {
SampleBatch.CUR_OBS: ViewRequirement(timesteps=0), SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
SampleBatch.PREV_ACTIONS: SampleBatch.PREV_ACTIONS: ViewRequirement(
ViewRequirement(SampleBatch.ACTIONS, timesteps=-1), SampleBatch.ACTIONS, timesteps=-1),
SampleBatch.PREV_REWARDS: SampleBatch.PREV_REWARDS: ViewRequirement(
ViewRequirement(SampleBatch.REWARDS, timesteps=-1), SampleBatch.REWARDS, timesteps=-1),
} }
def import_from_h5(self, h5_file): def import_from_h5(self, h5_file: str) -> None:
"""Imports weights from an h5 file. """Imports weights from an h5 file.
Args: Args:
@ -287,17 +292,18 @@ class ModelV2:
raise NotImplementedError raise NotImplementedError
@PublicAPI @PublicAPI
def last_output(self): def last_output(self) -> TensorType:
"""Returns the last output returned from calling the model.""" """Returns the last output returned from calling the model."""
return self._last_output return self._last_output
@PublicAPI @PublicAPI
def context(self): def context(self) -> contextlib.AbstractContextManager:
"""Returns a contextmanager for the current forward pass.""" """Returns a contextmanager for the current forward pass."""
return NullContextManager() return NullContextManager()
@PublicAPI @PublicAPI
def variables(self, as_dict=False): def variables(self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Returns the list (or a dict) of variables for this model. """Returns the list (or a dict) of variables for this model.
Args: Args:
@ -311,7 +317,9 @@ class ModelV2:
raise NotImplementedError raise NotImplementedError
@PublicAPI @PublicAPI
def trainable_variables(self, as_dict=False): def trainable_variables(
self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Returns the list of trainable variables for this model. """Returns the list of trainable variables for this model.
Args: Args:
@ -340,7 +348,7 @@ class NullContextManager:
@DeveloperAPI @DeveloperAPI
def flatten(obs, framework): def flatten(obs: TensorType, framework: str) -> TensorType:
"""Flatten the given tensor.""" """Flatten the given tensor."""
if framework in ["tf2", "tf", "tfe"]: if framework in ["tf2", "tf", "tfe"]:
return tf1.keras.layers.Flatten()(obs) return tf1.keras.layers.Flatten()(obs)
@ -354,7 +362,7 @@ def flatten(obs, framework):
@DeveloperAPI @DeveloperAPI
def restore_original_dimensions(obs: TensorType, def restore_original_dimensions(obs: TensorType,
obs_space: gym.spaces.Space, obs_space: gym.spaces.Space,
tensorlib=tf): tensorlib: Any = tf) -> TensorStructType:
"""Unpacks Dict and Tuple space observations into their original form. """Unpacks Dict and Tuple space observations into their original form.
This is needed since we flatten Dict and Tuple observations in transit This is needed since we flatten Dict and Tuple observations in transit
@ -388,7 +396,8 @@ def restore_original_dimensions(obs: TensorType,
_cache = {} _cache = {}
def _unpack_obs(obs, space, tensorlib=tf): def _unpack_obs(obs: TensorType, space: gym.Space,
tensorlib: Any = tf) -> TensorStructType:
"""Unpack a flattened Dict or Tuple observation array/tensor. """Unpack a flattened Dict or Tuple observation array/tensor.
Args: Args:

View file

@ -3,6 +3,7 @@ import cv2
import logging import logging
import numpy as np import numpy as np
import gym import gym
from typing import Any, List
from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.spaces.repeated import Repeated from ray.rllib.utils.spaces.repeated import Repeated
@ -19,11 +20,11 @@ class Preprocessor:
"""Defines an abstract observation preprocessor function. """Defines an abstract observation preprocessor function.
Attributes: Attributes:
shape (obj): Shape of the preprocessed output. shape (List[int]): Shape of the preprocessed output.
""" """
@PublicAPI @PublicAPI
def __init__(self, obs_space, options=None): def __init__(self, obs_space: gym.Space, options: dict = None):
legacy_patch_shapes(obs_space) legacy_patch_shapes(obs_space)
self._obs_space = obs_space self._obs_space = obs_space
if not options: if not options:
@ -36,20 +37,20 @@ class Preprocessor:
self._i = 0 self._i = 0
@PublicAPI @PublicAPI
def _init_shape(self, obs_space, options): def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
"""Returns the shape after preprocessing.""" """Returns the shape after preprocessing."""
raise NotImplementedError raise NotImplementedError
@PublicAPI @PublicAPI
def transform(self, observation): def transform(self, observation: Any) -> np.ndarray:
"""Returns the preprocessed observation.""" """Returns the preprocessed observation."""
raise NotImplementedError raise NotImplementedError
def write(self, observation, array, offset): def write(self, observation: Any, array: np.ndarray, offset: int) -> None:
"""Alternative to transform for more efficient flattening.""" """Alternative to transform for more efficient flattening."""
array[offset:offset + self._size] = self.transform(observation) array[offset:offset + self._size] = self.transform(observation)
def check_shape(self, observation): def check_shape(self, observation: Any) -> None:
"""Checks the shape of the given observation.""" """Checks the shape of the given observation."""
if self._i % VALIDATION_INTERVAL == 0: if self._i % VALIDATION_INTERVAL == 0:
if type(observation) is list and isinstance( if type(observation) is list and isinstance(
@ -69,12 +70,12 @@ class Preprocessor:
@property @property
@PublicAPI @PublicAPI
def size(self): def size(self) -> int:
return self._size return self._size
@property @property
@PublicAPI @PublicAPI
def observation_space(self): def observation_space(self) -> gym.Space:
obs_space = gym.spaces.Box(-1., 1., self.shape, dtype=np.float32) obs_space = gym.spaces.Box(-1., 1., self.shape, dtype=np.float32)
# Stash the unwrapped space so that we can unwrap dict and tuple spaces # Stash the unwrapped space so that we can unwrap dict and tuple spaces
# automatically in model.py # automatically in model.py
@ -286,7 +287,7 @@ class RepeatedValuesPreprocessor(Preprocessor):
@PublicAPI @PublicAPI
def get_preprocessor(space): def get_preprocessor(space: gym.Space) -> type:
"""Returns an appropriate preprocessor class for the given space.""" """Returns an appropriate preprocessor class for the given space."""
legacy_patch_shapes(space) legacy_patch_shapes(space)
@ -310,7 +311,7 @@ def get_preprocessor(space):
return preprocessor return preprocessor
def legacy_patch_shapes(space): def legacy_patch_shapes(space: gym.Space) -> List[int]:
"""Assigns shapes to spaces that don't have shapes. """Assigns shapes to spaces that don't have shapes.
This is only needed for older gym versions that don't set shapes properly This is only needed for older gym versions that don't set shapes properly

View file

@ -1,7 +1,7 @@
from typing import List from typing import List
from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.framework import TensorType from ray.rllib.utils.framework import TensorType, TensorStructType
@PublicAPI @PublicAPI
@ -48,7 +48,7 @@ class RepeatedValues:
self.max_len = max_len self.max_len = max_len
self._unbatched_repr = None self._unbatched_repr = None
def unbatch_all(self): def unbatch_all(self) -> List[List[TensorType]]:
"""Unbatch both the repeat and batch dimensions into Python lists. """Unbatch both the repeat and batch dimensions into Python lists.
This is only supported in PyTorch / TF eager mode. This is only supported in PyTorch / TF eager mode.
@ -64,7 +64,7 @@ class RepeatedValues:
>>> print(max(len(x) for x in items) <= N) >>> print(max(len(x) for x in items) <= N)
True True
>>> print(items) >>> print(items)
... [<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>], ... [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
... ... ... ...
... [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>], ... [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
... ... ... ...
@ -96,7 +96,7 @@ class RepeatedValues:
return self._unbatched_repr return self._unbatched_repr
def unbatch_repeat_dim(self): def unbatch_repeat_dim(self) -> List[TensorType]:
"""Unbatches the repeat dimension (the one `max_len` in size). """Unbatches the repeat dimension (the one `max_len` in size).
This removes the repeat dimension. The result will be a Python list of This removes the repeat dimension. The result will be a Python list of
@ -120,7 +120,7 @@ class RepeatedValues:
return repr(self) return repr(self)
def _get_batch_dim_helper(v): def _get_batch_dim_helper(v: TensorStructType) -> int:
"""Tries to find the batch dimension size of v, or None.""" """Tries to find the batch dimension size of v, or None."""
if isinstance(v, dict): if isinstance(v, dict):
for u in v.values(): for u in v.values():
@ -136,7 +136,7 @@ def _get_batch_dim_helper(v):
return B return B
def _unbatch_helper(v, max_len): def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
"""Recursively unpacks the repeat dimension (max_len).""" """Recursively unpacks the repeat dimension (max_len)."""
if isinstance(v, dict): if isinstance(v, dict):
return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()} return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
@ -152,7 +152,8 @@ def _unbatch_helper(v, max_len):
return [v[:, i, ...] for i in range(max_len)] return [v[:, i, ...] for i in range(max_len)]
def _batch_index_helper(v, i, j): def _batch_index_helper(v: TensorStructType, i: int,
j: int) -> TensorStructType:
"""Selects the item at the ith batch index and jth repetition.""" """Selects the item at the ith batch index and jth repetition."""
if isinstance(v, dict): if isinstance(v, dict):
return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()} return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}

View file

@ -4,11 +4,13 @@ import functools
import tree import tree
from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \ from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
SMALL_NUMBER SMALL_NUMBER
from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.types import TensorType, List
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp() tfp = try_import_tfp()
@ -18,14 +20,14 @@ tfp = try_import_tfp()
class TFActionDistribution(ActionDistribution): class TFActionDistribution(ActionDistribution):
"""TF-specific extensions for building action distributions.""" """TF-specific extensions for building action distributions."""
@DeveloperAPI @override(ActionDistribution)
def __init__(self, inputs, model): def __init__(self, inputs: List[TensorType], model: ModelV2):
super().__init__(inputs, model) super().__init__(inputs, model)
self.sample_op = self._build_sample_op() self.sample_op = self._build_sample_op()
self.sampled_action_logp_op = self.logp(self.sample_op) self.sampled_action_logp_op = self.logp(self.sample_op)
@DeveloperAPI @DeveloperAPI
def _build_sample_op(self): def _build_sample_op(self) -> TensorType:
"""Implement this instead of sample(), to enable op reuse. """Implement this instead of sample(), to enable op reuse.
This is needed since the sample op is non-deterministic and is shared This is needed since the sample op is non-deterministic and is shared
@ -34,12 +36,12 @@ class TFActionDistribution(ActionDistribution):
raise NotImplementedError raise NotImplementedError
@override(ActionDistribution) @override(ActionDistribution)
def sample(self): def sample(self) -> TensorType:
"""Draw a sample from the action distribution.""" """Draw a sample from the action distribution."""
return self.sample_op return self.sample_op
@override(ActionDistribution) @override(ActionDistribution)
def sampled_action_logp(self): def sampled_action_logp(self) -> TensorType:
"""Returns the log probability of the sampled action.""" """Returns the log probability of the sampled action."""
return self.sampled_action_logp_op return self.sampled_action_logp_op
@ -242,9 +244,8 @@ class DiagGaussian(TFActionDistribution):
assert isinstance(other, DiagGaussian) assert isinstance(other, DiagGaussian)
return tf.reduce_sum( return tf.reduce_sum(
other.log_std - self.log_std + other.log_std - self.log_std +
(tf.math.square(self.std) + (tf.math.square(self.std) + tf.math.square(self.mean - other.mean))
tf.math.square(self.mean - other.mean)) / / (2.0 * tf.math.square(other.std)) - 0.5,
(2.0 * tf.math.square(other.std)) - 0.5,
axis=1) axis=1)
@override(ActionDistribution) @override(ActionDistribution)

View file

@ -1,6 +1,11 @@
import contextlib
import gym
from typing import List
from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.types import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
@ -12,8 +17,9 @@ class TFModelV2(ModelV2):
Note that this class by itself is not a valid model unless you Note that this class by itself is not a valid model unless you
implement forward() in a subclass.""" implement forward() in a subclass."""
def __init__(self, obs_space, action_space, num_outputs, model_config, def __init__(self, obs_space: gym.spaces.Space,
name): action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str):
"""Initialize a TFModelV2. """Initialize a TFModelV2.
Here is an example implementation for a subclass Here is an example implementation for a subclass
@ -44,31 +50,31 @@ class TFModelV2(ModelV2):
else: else:
self.graph = tf1.get_default_graph() self.graph = tf1.get_default_graph()
def context(self): def context(self) -> contextlib.AbstractContextManager:
"""Returns a contextmanager for the current TF graph.""" """Returns a contextmanager for the current TF graph."""
if self.graph: if self.graph:
return self.graph.as_default() return self.graph.as_default()
else: else:
return ModelV2.context(self) return ModelV2.context(self)
def update_ops(self): def update_ops(self) -> List[TensorType]:
"""Return the list of update ops for this model. """Return the list of update ops for this model.
For example, this should include any BatchNorm update ops.""" For example, this should include any BatchNorm update ops."""
return [] return []
def register_variables(self, variables): def register_variables(self, variables: List[TensorType]) -> None:
"""Register the given list of variables with this model.""" """Register the given list of variables with this model."""
self.var_list.extend(variables) self.var_list.extend(variables)
@override(ModelV2) @override(ModelV2)
def variables(self, as_dict=False): def variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict: if as_dict:
return {v.name: v for v in self.var_list} return {v.name: v for v in self.var_list}
return list(self.var_list) return list(self.var_list)
@override(ModelV2) @override(ModelV2)
def trainable_variables(self, as_dict=False): def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict: if as_dict:
return { return {
k: v k: v

View file

@ -4,12 +4,14 @@ import numpy as np
import tree import tree
from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \ from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
MAX_LOG_NN_OUTPUT MAX_LOG_NN_OUTPUT
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.torch_ops import atanh from ray.rllib.utils.torch_ops import atanh
from ray.rllib.utils.types import TensorType, List
torch, nn = try_import_torch() torch, nn = try_import_torch()
@ -18,7 +20,7 @@ class TorchDistributionWrapper(ActionDistribution):
"""Wrapper class for torch.distributions.""" """Wrapper class for torch.distributions."""
@override(ActionDistribution) @override(ActionDistribution)
def __init__(self, inputs, model): def __init__(self, inputs: List[TensorType], model: ModelV2):
if not isinstance(inputs, torch.Tensor): if not isinstance(inputs, torch.Tensor):
inputs = torch.Tensor(inputs) inputs = torch.Tensor(inputs)
super().__init__(inputs, model) super().__init__(inputs, model)
@ -26,24 +28,24 @@ class TorchDistributionWrapper(ActionDistribution):
self.last_sample = None self.last_sample = None
@override(ActionDistribution) @override(ActionDistribution)
def logp(self, actions): def logp(self, actions: TensorType) -> TensorType:
return self.dist.log_prob(actions) return self.dist.log_prob(actions)
@override(ActionDistribution) @override(ActionDistribution)
def entropy(self): def entropy(self) -> TensorType:
return self.dist.entropy() return self.dist.entropy()
@override(ActionDistribution) @override(ActionDistribution)
def kl(self, other): def kl(self, other: ActionDistribution) -> TensorType:
return torch.distributions.kl.kl_divergence(self.dist, other.dist) return torch.distributions.kl.kl_divergence(self.dist, other.dist)
@override(ActionDistribution) @override(ActionDistribution)
def sample(self): def sample(self) -> TensorType:
self.last_sample = self.dist.sample() self.last_sample = self.dist.sample()
return self.last_sample return self.last_sample
@override(ActionDistribution) @override(ActionDistribution)
def sampled_action_logp(self): def sampled_action_logp(self) -> TensorType:
assert self.last_sample is not None assert self.last_sample is not None
return self.logp(self.last_sample) return self.logp(self.last_sample)

View file

@ -1,6 +1,10 @@
import gym
from typing import List
from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.types import ModelConfigDict, TensorType
_, nn = try_import_torch() _, nn = try_import_torch()
@ -12,8 +16,9 @@ class TorchModelV2(ModelV2):
Note that this class by itself is not a valid model unless you Note that this class by itself is not a valid model unless you
inherit from nn.Module and implement forward() in a subclass.""" inherit from nn.Module and implement forward() in a subclass."""
def __init__(self, obs_space, action_space, num_outputs, model_config, def __init__(self, obs_space: gym.spaces.Space,
name): action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str):
"""Initialize a TorchModelV2. """Initialize a TorchModelV2.
Here is an example implementation for a subclass Here is an example implementation for a subclass
@ -42,13 +47,13 @@ class TorchModelV2(ModelV2):
framework="torch") framework="torch")
@override(ModelV2) @override(ModelV2)
def variables(self, as_dict=False): def variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict: if as_dict:
return self.state_dict() return self.state_dict()
return list(self.parameters()) return list(self.parameters())
@override(ModelV2) @override(ModelV2)
def trainable_variables(self, as_dict=False): def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict: if as_dict:
return { return {
k: v k: v

View file

@ -1,8 +1,11 @@
import numpy as np import numpy as np
from typing import Dict, Optional from typing import Dict, Optional, List, TYPE_CHECKING
from ray.rllib.utils.types import TensorType from ray.rllib.utils.types import TensorType
if TYPE_CHECKING:
from ray.rllib.models import ModelV2
class ViewRequirement: class ViewRequirement:
"""Single view requirement (for one column in a ModelV2 input_dict). """Single view requirement (for one column in a ModelV2 input_dict).
@ -57,10 +60,9 @@ class ViewRequirement:
# "time_major", # "time_major",
def get_trajectory_view( def get_trajectory_view(model: "ModelV2",
model, trajectories: List["Trajectory"],
trajectories, is_training: bool = False) -> Dict[str, TensorType]:
is_training: bool = False) -> Dict[str, TensorType]:
"""Returns an input_dict for a Model's forward pass given some data. """Returns an input_dict for a Model's forward pass given some data.
Args: Args: