From 590943a49951892d2c8d590ebde5d1bcfbbc0292 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Fri, 24 Jul 2020 12:01:46 -0700
Subject: [PATCH] [rllib] Type annotations for model classes (#9646)

---
 rllib/agents/trainer.py                 |  2 +-
 rllib/models/action_dist.py             | 27 ++++++----
 rllib/models/catalog.py                 | 61 +++++++++++----------
 rllib/models/modelv2.py                 | 71 ++++++++++++++-----------
 rllib/models/preprocessors.py           | 21 ++++----
 rllib/models/repeated_values.py         | 15 +++---
 rllib/models/tf/tf_action_dist.py       | 17 +++---
 rllib/models/tf/tf_modelv2.py           | 20 ++++---
 rllib/models/torch/torch_action_dist.py | 14 ++---
 rllib/models/torch/torch_modelv2.py     | 13 +++--
 rllib/policy/trajectory_view.py         | 12 +++--
 11 files changed, 153 insertions(+), 120 deletions(-)

diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 307e16be5..48fcf31be 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -44,7 +44,7 @@ MAX_WORKER_FAILURE_RETRIES = 3

 # yapf: disable
 # __sphinx_doc_begin__
-COMMON_CONFIG = {
+COMMON_CONFIG: TrainerConfigDict = {
     # === Settings for Rollout Worker processes ===
     # Number of rollout worker actors to create for parallel sampling. Setting
     # this to 0 will force rollouts to be done in the trainer actor.
diff --git a/rllib/models/action_dist.py b/rllib/models/action_dist.py
index 5ee4f2e7c..5c99509ad 100644
--- a/rllib/models/action_dist.py
+++ b/rllib/models/action_dist.py
@@ -1,4 +1,9 @@
+import numpy as np
+import gym
+
+from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.types import TensorType, List, Union, ModelConfigDict


 @DeveloperAPI
@@ -11,7 +16,7 @@ class ActionDistribution:
     """

     @DeveloperAPI
-    def __init__(self, inputs, model):
+    def __init__(self, inputs: List[TensorType], model: ModelV2):
         """Initialize the action dist.

         Arguments:
@@ -25,12 +30,12 @@ class ActionDistribution:
         self.model = model

     @DeveloperAPI
-    def sample(self):
+    def sample(self) -> TensorType:
         """Draw a sample from the action distribution."""
         raise NotImplementedError

     @DeveloperAPI
-    def deterministic_sample(self):
+    def deterministic_sample(self) -> TensorType:
         """
         Get the deterministic "sampling" output from the distribution.
         This is usually the max likelihood output, i.e. mean for Normal, argmax
@@ -39,26 +44,26 @@ class ActionDistribution:
         raise NotImplementedError

     @DeveloperAPI
-    def sampled_action_logp(self):
+    def sampled_action_logp(self) -> TensorType:
         """Returns the log probability of the last sampled action."""
         raise NotImplementedError

     @DeveloperAPI
-    def logp(self, x):
+    def logp(self, x: TensorType) -> TensorType:
         """The log-likelihood of the action distribution."""
         raise NotImplementedError

     @DeveloperAPI
-    def kl(self, other):
+    def kl(self, other: "ActionDistribution") -> TensorType:
         """The KL-divergence between two action distributions."""
         raise NotImplementedError

     @DeveloperAPI
-    def entropy(self):
+    def entropy(self) -> TensorType:
         """The entropy of the action distribution."""
         raise NotImplementedError

-    def multi_kl(self, other):
+    def multi_kl(self, other: "ActionDistribution") -> TensorType:
         """The KL-divergence between two action distributions.

         This differs from kl() in that it can return an array for
@@ -66,7 +71,7 @@ class ActionDistribution:
         """
         return self.kl(other)

-    def multi_entropy(self):
+    def multi_entropy(self) -> TensorType:
         """The entropy of the action distribution.

         This differs from entropy() in that it can return an array for
@@ -76,7 +81,9 @@ class ActionDistribution:

     @DeveloperAPI
     @staticmethod
-    def required_model_output_shape(action_space, model_config):
+    def required_model_output_shape(
+            action_space: gym.Space,
+            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
         """Returns the required shape of an input parameter tensor for a
         particular action space and an optional dict of distribution-specific
         options.
diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py
index 425bd623e..afe39e9cf 100644
--- a/rllib/models/catalog.py
+++ b/rllib/models/catalog.py
@@ -3,12 +3,13 @@ import gym
 import logging
 import numpy as np
 import tree
+from typing import List

 from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
     RLLIB_ACTION_DIST, _global_registry
 from ray.rllib.models.action_dist import ActionDistribution
 from ray.rllib.models.modelv2 import ModelV2
-from ray.rllib.models.preprocessors import get_preprocessor
+from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
 from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
 from ray.rllib.models.tf.lstm_v1 import LSTM
 from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
@@ -26,6 +27,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.spaces.simplex import Simplex
 from ray.rllib.utils.spaces.space_utils import flatten_space
+from ray.rllib.utils.types import ModelConfigDict, TensorType

 tf1, tf, tfv = try_import_tf()

@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)

 # yapf: disable
 # __sphinx_doc_begin__
-MODEL_DEFAULTS = {
+MODEL_DEFAULTS: ModelConfigDict = {
     # === Built-in options ===
     # Filter config. List of [out_channels, kernel, stride] for each filter
     "conv_filters": None,
@@ -114,11 +116,11 @@ class ModelCatalog:

     @staticmethod
     @DeveloperAPI
-    def get_action_dist(action_space,
-                        config,
-                        dist_type=None,
-                        framework="tf",
-                        **kwargs):
+    def get_action_dist(action_space: gym.Space,
+                        config: ModelConfigDict,
+                        dist_type: str = None,
+                        framework: str = "tf",
+                        **kwargs) -> (type, int):
         """Returns a distribution class and size for the given action space.

         Args:
@@ -209,7 +211,7 @@ class ModelCatalog:

     @staticmethod
     @DeveloperAPI
-    def get_action_shape(action_space):
+    def get_action_shape(action_space: gym.Space) -> (np.dtype, List[int]):
         """Returns action tensor dtype and shape for the action space.

         Args:
@@ -243,7 +245,8 @@ class ModelCatalog:

     @staticmethod
     @DeveloperAPI
-    def get_action_placeholder(action_space, name="action"):
+    def get_action_placeholder(action_space: gym.Space,
+                               name: str = "action") -> TensorType:
         """Returns an action placeholder consistent with the action space

         Args:
@@ -260,15 +263,15 @@ class ModelCatalog:

     @staticmethod
     @DeveloperAPI
-    def get_model_v2(obs_space,
-                     action_space,
-                     num_outputs,
-                     model_config,
-                     framework="tf",
-                     name="default_model",
-                     model_interface=None,
-                     default_model=None,
-                     **model_kwargs):
+    def get_model_v2(obs_space: gym.Space,
+                     action_space: gym.Space,
+                     num_outputs: int,
+                     model_config: ModelConfigDict,
+                     framework: str = "tf",
+                     name: str = "default_model",
+                     model_interface: type = None,
+                     default_model: type = None,
+                     **model_kwargs) -> ModelV2:
         """Returns a suitable model compatible with given spaces and output.

         Args:
@@ -420,7 +423,7 @@ class ModelCatalog:

     @staticmethod
     @DeveloperAPI
-    def get_preprocessor(env, options=None):
+    def get_preprocessor(env: gym.Env, options: dict = None) -> Preprocessor:
         """Returns a suitable preprocessor for the given env.

         This is a wrapper for get_preprocessor_for_space().
@@ -431,7 +434,8 @@ class ModelCatalog:

     @staticmethod
     @DeveloperAPI
-    def get_preprocessor_for_space(observation_space, options=None):
+    def get_preprocessor_for_space(observation_space: gym.Space,
+                                   options: dict = None) -> Preprocessor:
         """Returns a suitable preprocessor for the given observation space.

         Args:
@@ -469,7 +473,8 @@ class ModelCatalog:

     @staticmethod
     @PublicAPI
-    def register_custom_preprocessor(preprocessor_name, preprocessor_class):
+    def register_custom_preprocessor(preprocessor_name: str,
+                                     preprocessor_class: type) -> None:
         """Register a custom preprocessor class by name.

         The preprocessor can be later used by specifying
@@ -484,7 +489,7 @@ class ModelCatalog:

     @staticmethod
     @PublicAPI
-    def register_custom_model(model_name, model_class):
+    def register_custom_model(model_name: str, model_class: type) -> None:
         """Register a custom model class by name.

         The model can be later used by specifying {"custom_model": model_name}
@@ -498,7 +503,8 @@ class ModelCatalog:

     @staticmethod
     @PublicAPI
-    def register_custom_action_dist(action_dist_name, action_dist_class):
+    def register_custom_action_dist(action_dist_name: str,
+                                    action_dist_class: type) -> None:
         """Register a custom action distribution class by name.

         The model can be later used by specifying
@@ -512,7 +518,7 @@ class ModelCatalog:
                                   action_dist_class)

     @staticmethod
-    def _wrap_if_needed(model_cls, model_interface):
+    def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
         assert issubclass(model_cls, ModelV2), model_cls

         if not model_interface or issubclass(model_cls, model_interface):
@@ -608,10 +614,3 @@ class ModelCatalog:

         return FullyConnectedNetwork(input_dict, obs_space, action_space,
                                      num_outputs, options)
-
-    @staticmethod
-    def get_torch_model(obs_space,
-                        num_outputs,
-                        options=None,
-                        default_model_cls=None):
-        raise DeprecationWarning("Please use get_model_v2() instead.")
diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py
index 0ab3d50da..f66e29cc1 100644
--- a/rllib/models/modelv2.py
+++ b/rllib/models/modelv2.py
@@ -1,6 +1,8 @@
 from collections import OrderedDict
+import contextlib
 import gym
-from typing import Dict
+import numpy as np
+from typing import Dict, List, Any, Union

 from ray.rllib.models.preprocessors import get_preprocessor, \
     RepeatedValuesPreprocessor
@@ -11,7 +13,7 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
 from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
     TensorType
 from ray.rllib.utils.spaces.repeated import Repeated
-from ray.rllib.utils.types import ModelConfigDict
+from ray.rllib.utils.types import ModelConfigDict, TensorStructType

 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -29,13 +31,9 @@ class ModelV2:
         value_function() -> V(s)
     """

-    def __init__(self,
-                 obs_space: gym.spaces.Space,
-                 action_space: gym.spaces.Space,
-                 num_outputs: int,
-                 model_config: ModelConfigDict,
-                 name: str,
-                 framework: str):
+    def __init__(self, obs_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space, num_outputs: int,
+                 model_config: ModelConfigDict, name: str, framework: str):
         """Initializes a ModelV2 object.

         This method should create any variables used by the model.
@@ -62,7 +60,7 @@ class ModelV2:
         self._last_output = None

     @PublicAPI
-    def get_initial_state(self):
+    def get_initial_state(self) -> List[np.ndarray]:
         """Get the initial recurrent state values for the model.

         Returns:
@@ -79,7 +77,9 @@ class ModelV2:
         return []

     @PublicAPI
-    def forward(self, input_dict, state, seq_lens):
+    def forward(self, input_dict: Dict[str, TensorType],
+                state: List[TensorType],
+                seq_lens: TensorType) -> (TensorType, List[TensorType]):
         """Call the model with the given input tensors and state.

         Any complex observations (dicts, tuples, etc.) will be unpacked by
@@ -103,7 +103,7 @@ class ModelV2:

         Returns:
             (outputs, state): The model output tensor of size
-                [BATCH, num_outputs]
+                [BATCH, num_outputs], and the new RNN state.

         Examples:
             >>> def forward(self, input_dict, state, seq_lens):
@@ -114,7 +114,7 @@ class ModelV2:
         raise NotImplementedError

     @PublicAPI
-    def value_function(self):
+    def value_function(self) -> TensorType:
         """Returns the value function output for the most recent forward pass.

         Note that a `forward` call has to be performed first, before this
@@ -127,7 +127,8 @@ class ModelV2:
         raise NotImplementedError

     @PublicAPI
-    def custom_loss(self, policy_loss, loss_inputs):
+    def custom_loss(self, policy_loss: TensorType,
+                    loss_inputs: Dict[str, TensorType]) -> TensorType:
         """Override to customize the loss function used to optimize this model.

         This can be used to incorporate self-supervised losses (by defining
@@ -149,7 +150,7 @@ class ModelV2:
         return policy_loss

     @PublicAPI
-    def metrics(self):
+    def metrics(self) -> Dict[str, TensorType]:
         """Override to return custom metrics from your model.

         The stats will be reported as part of the learner stats, i.e.,
@@ -164,7 +165,11 @@ class ModelV2:
         """
         return {}

-    def __call__(self, input_dict, state=None, seq_lens=None):
+    def __call__(
+            self,
+            input_dict: Dict[str, TensorType],
+            state: List[Any] = None,
+            seq_lens: TensorType = None) -> (TensorType, List[TensorType]):
         """Call the model with the given input tensors and state.

         This is the method used by RLlib to execute the forward pass. It calls
@@ -218,7 +223,8 @@ class ModelV2:
         return outputs, state

     @PublicAPI
-    def from_batch(self, train_batch, is_training=True):
+    def from_batch(self, train_batch: SampleBatch,
+                   is_training: bool = True) -> (TensorType, List[TensorType]):
         """Convenience function that calls this model with a tensor batch.

         All this does is unpack the tensor batch to call this model with the
@@ -241,8 +247,7 @@ class ModelV2:
         return self.__call__(input_dict, states, train_batch.get("seq_lens"))

     def get_view_requirements(
-            self,
-            is_training: bool = False) -> Dict[str, ViewRequirement]:
+            self, is_training: bool = False) -> Dict[str, ViewRequirement]:
         """Returns a list of ViewRequirements for this Model (or None).

         Note: This is an experimental API method.
@@ -266,13 +271,13 @@ class ModelV2:
         # Single requirement: Pass current obs as input.
         return {
             SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
-            SampleBatch.PREV_ACTIONS:
-                ViewRequirement(SampleBatch.ACTIONS, timesteps=-1),
-            SampleBatch.PREV_REWARDS:
-                ViewRequirement(SampleBatch.REWARDS, timesteps=-1),
+            SampleBatch.PREV_ACTIONS: ViewRequirement(
+                SampleBatch.ACTIONS, timesteps=-1),
+            SampleBatch.PREV_REWARDS: ViewRequirement(
+                SampleBatch.REWARDS, timesteps=-1),
         }

-    def import_from_h5(self, h5_file):
+    def import_from_h5(self, h5_file: str) -> None:
         """Imports weights from an h5 file.

         Args:
@@ -287,17 +292,18 @@ class ModelV2:
         raise NotImplementedError

     @PublicAPI
-    def last_output(self):
+    def last_output(self) -> TensorType:
         """Returns the last output returned from calling the model."""
         return self._last_output

     @PublicAPI
-    def context(self):
+    def context(self) -> contextlib.AbstractContextManager:
         """Returns a contextmanager for the current forward pass."""
         return NullContextManager()

     @PublicAPI
-    def variables(self, as_dict=False):
+    def variables(self, as_dict: bool = False
+                  ) -> Union[List[TensorType], Dict[str, TensorType]]:
         """Returns the list (or a dict) of variables for this model.

         Args:
@@ -311,7 +317,9 @@ class ModelV2:
         raise NotImplementedError

     @PublicAPI
-    def trainable_variables(self, as_dict=False):
+    def trainable_variables(
+            self, as_dict: bool = False
+    ) -> Union[List[TensorType], Dict[str, TensorType]]:
         """Returns the list of trainable variables for this model.

         Args:
@@ -340,7 +348,7 @@ class NullContextManager:


 @DeveloperAPI
-def flatten(obs, framework):
+def flatten(obs: TensorType, framework: str) -> TensorType:
     """Flatten the given tensor."""
     if framework in ["tf2", "tf", "tfe"]:
         return tf1.keras.layers.Flatten()(obs)
@@ -354,7 +362,7 @@
 @DeveloperAPI
 def restore_original_dimensions(obs: TensorType,
                                 obs_space: gym.spaces.Space,
-                                tensorlib=tf):
+                                tensorlib: Any = tf) -> TensorStructType:
     """Unpacks Dict and Tuple space observations into their original form.

     This is needed since we flatten Dict and Tuple observations in transit
@@ -388,7 +396,8 @@
 _cache = {}


-def _unpack_obs(obs, space, tensorlib=tf):
+def _unpack_obs(obs: TensorType, space: gym.Space,
+                tensorlib: Any = tf) -> TensorStructType:
     """Unpack a flattened Dict or Tuple observation array/tensor.

     Args:
diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py
index 7c67a34cd..6834ea2d8 100644
--- a/rllib/models/preprocessors.py
+++ b/rllib/models/preprocessors.py
@@ -3,6 +3,7 @@ import cv2
 import logging
 import numpy as np
 import gym
+from typing import Any, List

 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.spaces.repeated import Repeated
@@ -19,11 +20,11 @@ class Preprocessor:
     """Defines an abstract observation preprocessor function.

     Attributes:
-        shape (obj): Shape of the preprocessed output.
+        shape (List[int]): Shape of the preprocessed output.
""" @PublicAPI - def __init__(self, obs_space, options=None): + def __init__(self, obs_space: gym.Space, options: dict = None): legacy_patch_shapes(obs_space) self._obs_space = obs_space if not options: @@ -36,20 +37,20 @@ class Preprocessor: self._i = 0 @PublicAPI - def _init_shape(self, obs_space, options): + def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]: """Returns the shape after preprocessing.""" raise NotImplementedError @PublicAPI - def transform(self, observation): + def transform(self, observation: Any) -> np.ndarray: """Returns the preprocessed observation.""" raise NotImplementedError - def write(self, observation, array, offset): + def write(self, observation: Any, array: np.ndarray, offset: int) -> None: """Alternative to transform for more efficient flattening.""" array[offset:offset + self._size] = self.transform(observation) - def check_shape(self, observation): + def check_shape(self, observation: Any) -> None: """Checks the shape of the given observation.""" if self._i % VALIDATION_INTERVAL == 0: if type(observation) is list and isinstance( @@ -69,12 +70,12 @@ class Preprocessor: @property @PublicAPI - def size(self): + def size(self) -> int: return self._size @property @PublicAPI - def observation_space(self): + def observation_space(self) -> gym.Space: obs_space = gym.spaces.Box(-1., 1., self.shape, dtype=np.float32) # Stash the unwrapped space so that we can unwrap dict and tuple spaces # automatically in model.py @@ -286,7 +287,7 @@ class RepeatedValuesPreprocessor(Preprocessor): @PublicAPI -def get_preprocessor(space): +def get_preprocessor(space: gym.Space) -> type: """Returns an appropriate preprocessor class for the given space.""" legacy_patch_shapes(space) @@ -310,7 +311,7 @@ def get_preprocessor(space): return preprocessor -def legacy_patch_shapes(space): +def legacy_patch_shapes(space: gym.Space) -> List[int]: """Assigns shapes to spaces that don't have shapes. This is only needed for older gym versions that don't set shapes properly diff --git a/rllib/models/repeated_values.py b/rllib/models/repeated_values.py index c042aaf36..d7e592cc0 100644 --- a/rllib/models/repeated_values.py +++ b/rllib/models/repeated_values.py @@ -1,7 +1,7 @@ from typing import List from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.framework import TensorType +from ray.rllib.utils.framework import TensorType, TensorStructType @PublicAPI @@ -48,7 +48,7 @@ class RepeatedValues: self.max_len = max_len self._unbatched_repr = None - def unbatch_all(self): + def unbatch_all(self) -> List[List[TensorType]]: """Unbatch both the repeat and batch dimensions into Python lists. This is only supported in PyTorch / TF eager mode. @@ -64,7 +64,7 @@ class RepeatedValues: >>> print(max(len(x) for x in items) <= N) True >>> print(items) - ... [, ..., ], + ... [[, ..., ], ... ... ... [, ], ... ... @@ -96,7 +96,7 @@ class RepeatedValues: return self._unbatched_repr - def unbatch_repeat_dim(self): + def unbatch_repeat_dim(self) -> List[TensorType]: """Unbatches the repeat dimension (the one `max_len` in size). This removes the repeat dimension. 
@@ -120,7 +120,7 @@
         return repr(self)


-def _get_batch_dim_helper(v):
+def _get_batch_dim_helper(v: TensorStructType) -> int:
     """Tries to find the batch dimension size of v, or None."""
     if isinstance(v, dict):
         for u in v.values():
@@ -136,7 +136,7 @@
     return B


-def _unbatch_helper(v, max_len):
+def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
     """Recursively unpacks the repeat dimension (max_len)."""
     if isinstance(v, dict):
         return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
@@ -152,7 +152,7 @@
     return [v[:, i, ...] for i in range(max_len)]


-def _batch_index_helper(v, i, j):
+def _batch_index_helper(v: TensorStructType, i: int,
+                        j: int) -> TensorStructType:
     """Selects the item at the ith batch index and jth repetition."""
     if isinstance(v, dict):
         return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}
diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py
index d906c83a4..94235e2ec 100644
--- a/rllib/models/tf/tf_action_dist.py
+++ b/rllib/models/tf/tf_action_dist.py
@@ -4,11 +4,13 @@ import functools
 import tree

 from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
     SMALL_NUMBER
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.framework import try_import_tf, try_import_tfp
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
+from ray.rllib.utils.types import TensorType, List

 tf1, tf, tfv = try_import_tf()
 tfp = try_import_tfp()
@@ -18,14 +20,14 @@ tfp = try_import_tfp()
 class TFActionDistribution(ActionDistribution):
     """TF-specific extensions for building action distributions."""

-    @DeveloperAPI
-    def __init__(self, inputs, model):
+    @override(ActionDistribution)
+    def __init__(self, inputs: List[TensorType], model: ModelV2):
         super().__init__(inputs, model)
         self.sample_op = self._build_sample_op()
         self.sampled_action_logp_op = self.logp(self.sample_op)

     @DeveloperAPI
-    def _build_sample_op(self):
+    def _build_sample_op(self) -> TensorType:
         """Implement this instead of sample(), to enable op reuse.

         This is needed since the sample op is non-deterministic and is shared
@@ -34,12 +36,12 @@ class TFActionDistribution(ActionDistribution):
         raise NotImplementedError

     @override(ActionDistribution)
-    def sample(self):
+    def sample(self) -> TensorType:
         """Draw a sample from the action distribution."""
         return self.sample_op

     @override(ActionDistribution)
-    def sampled_action_logp(self):
+    def sampled_action_logp(self) -> TensorType:
         """Returns the log probability of the sampled action."""
         return self.sampled_action_logp_op

@@ -242,9 +244,8 @@ class DiagGaussian(TFActionDistribution):
         assert isinstance(other, DiagGaussian)
         return tf.reduce_sum(
             other.log_std - self.log_std +
-            (tf.math.square(self.std) +
-             tf.math.square(self.mean - other.mean)) /
-            (2.0 * tf.math.square(other.std)) - 0.5,
+            (tf.math.square(self.std) + tf.math.square(self.mean - other.mean))
+            / (2.0 * tf.math.square(other.std)) - 0.5,
             axis=1)

     @override(ActionDistribution)
diff --git a/rllib/models/tf/tf_modelv2.py b/rllib/models/tf/tf_modelv2.py
index 94565286f..1d5408f13 100644
--- a/rllib/models/tf/tf_modelv2.py
+++ b/rllib/models/tf/tf_modelv2.py
@@ -1,6 +1,11 @@
+import contextlib
+import gym
+from typing import List
+
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.types import ModelConfigDict, TensorType

 tf1, tf, tfv = try_import_tf()

@@ -12,8 +17,9 @@ class TFModelV2(ModelV2):
     Note that this class by itself is not a valid model unless you
     implement forward() in a subclass."""

-    def __init__(self, obs_space, action_space, num_outputs, model_config,
-                 name):
+    def __init__(self, obs_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space, num_outputs: int,
+                 model_config: ModelConfigDict, name: str):
         """Initialize a TFModelV2.

         Here is an example implementation for a subclass
@@ -44,31 +50,31 @@ class TFModelV2(ModelV2):
         else:
             self.graph = tf1.get_default_graph()

-    def context(self):
+    def context(self) -> contextlib.AbstractContextManager:
         """Returns a contextmanager for the current TF graph."""
         if self.graph:
             return self.graph.as_default()
         else:
             return ModelV2.context(self)

-    def update_ops(self):
+    def update_ops(self) -> List[TensorType]:
         """Return the list of update ops for this model.

         For example, this should include any BatchNorm update ops."""
         return []

-    def register_variables(self, variables):
+    def register_variables(self, variables: List[TensorType]) -> None:
         """Register the given list of variables with this model."""
         self.var_list.extend(variables)

     @override(ModelV2)
-    def variables(self, as_dict=False):
+    def variables(self, as_dict: bool = False) -> List[TensorType]:
         if as_dict:
             return {v.name: v for v in self.var_list}
         return list(self.var_list)

     @override(ModelV2)
-    def trainable_variables(self, as_dict=False):
+    def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
         if as_dict:
             return {
                 k: v
diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py
index 9605f6e3c..9b7f70cfe 100644
--- a/rllib/models/torch/torch_action_dist.py
+++ b/rllib/models/torch/torch_action_dist.py
@@ -4,12 +4,14 @@ import numpy as np
 import tree

 from ray.rllib.models.action_dist import ActionDistribution
+from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
     MAX_LOG_NN_OUTPUT
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
 from ray.rllib.utils.torch_ops import atanh
+from ray.rllib.utils.types import TensorType, List

 torch, nn = try_import_torch()

@@ -18,7 +20,7 @@ class TorchDistributionWrapper(ActionDistribution):
     """Wrapper class for torch.distributions."""

     @override(ActionDistribution)
-    def __init__(self, inputs, model):
+    def __init__(self, inputs: List[TensorType], model: ModelV2):
         if not isinstance(inputs, torch.Tensor):
             inputs = torch.Tensor(inputs)
         super().__init__(inputs, model)
@@ -26,24 +28,24 @@ class TorchDistributionWrapper(ActionDistribution):
         self.last_sample = None

     @override(ActionDistribution)
-    def logp(self, actions):
+    def logp(self, actions: TensorType) -> TensorType:
         return self.dist.log_prob(actions)

     @override(ActionDistribution)
-    def entropy(self):
+    def entropy(self) -> TensorType:
         return self.dist.entropy()

     @override(ActionDistribution)
-    def kl(self, other):
+    def kl(self, other: ActionDistribution) -> TensorType:
         return torch.distributions.kl.kl_divergence(self.dist, other.dist)

     @override(ActionDistribution)
-    def sample(self):
+    def sample(self) -> TensorType:
         self.last_sample = self.dist.sample()
         return self.last_sample

     @override(ActionDistribution)
-    def sampled_action_logp(self):
+    def sampled_action_logp(self) -> TensorType:
         assert self.last_sample is not None
         return self.logp(self.last_sample)

diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py
index cfbe48ad3..393d33ee4 100644
--- a/rllib/models/torch/torch_modelv2.py
+++ b/rllib/models/torch/torch_modelv2.py
@@ -1,6 +1,10 @@
+import gym
+from typing import List
+
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.types import ModelConfigDict, TensorType

 _, nn = try_import_torch()

@@ -12,8 +16,9 @@ class TorchModelV2(ModelV2):
     Note that this class by itself is not a valid model unless you
     inherit from nn.Module and implement forward() in a subclass."""

-    def __init__(self, obs_space, action_space, num_outputs, model_config,
-                 name):
+    def __init__(self, obs_space: gym.spaces.Space,
+                 action_space: gym.spaces.Space, num_outputs: int,
+                 model_config: ModelConfigDict, name: str):
         """Initialize a TorchModelV2.

         Here is an example implementation for a subclass
@@ -42,13 +47,13 @@ class TorchModelV2(ModelV2):
             framework="torch")

     @override(ModelV2)
-    def variables(self, as_dict=False):
+    def variables(self, as_dict: bool = False) -> List[TensorType]:
         if as_dict:
             return self.state_dict()
         return list(self.parameters())

     @override(ModelV2)
-    def trainable_variables(self, as_dict=False):
+    def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
         if as_dict:
             return {
                 k: v
diff --git a/rllib/policy/trajectory_view.py b/rllib/policy/trajectory_view.py
index f1c8d4e22..9a45b12a3 100644
--- a/rllib/policy/trajectory_view.py
+++ b/rllib/policy/trajectory_view.py
@@ -1,8 +1,11 @@
 import numpy as np
-from typing import Dict, Optional
+from typing import Dict, Optional, List, TYPE_CHECKING

 from ray.rllib.utils.types import TensorType

+if TYPE_CHECKING:
+    from ray.rllib.models import ModelV2
+

 class ViewRequirement:
     """Single view requirement (for one column in a ModelV2 input_dict).
@@ -57,10 +60,9 @@ class ViewRequirement:
     # "time_major",


-def get_trajectory_view(
-        model,
-        trajectories,
-        is_training: bool = False) -> Dict[str, TensorType]:
+def get_trajectory_view(model: "ModelV2",
+                        trajectories: List["Trajectory"],
+                        is_training: bool = False) -> Dict[str, TensorType]:
     """Returns an input_dict for a Model's forward pass given some data.

     Args:
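
The sketch below is not part of the patch; it only illustrates how the annotated API reads from the user side: a minimal custom model written against the typed TorchModelV2/ModelV2 signatures changed above, registered through ModelCatalog.register_custom_model. The model name "my_model", the hidden-layer size, and the use of input_dict["obs_flat"] (the flattened observation that ModelV2.__call__ adds) are illustrative assumptions, not anything prescribed by the patch.

# Illustrative sketch only -- not part of the patch above.
from typing import Dict, List

import gym
import numpy as np
import torch.nn as nn

from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.types import ModelConfigDict, TensorType


class MyTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        obs_size = int(np.product(obs_space.shape))
        hidden = 64  # arbitrary hidden size, chosen for the example
        self.policy_net = nn.Sequential(
            nn.Linear(obs_size, hidden), nn.ReLU(),
            nn.Linear(hidden, num_outputs))
        self.value_net = nn.Linear(obs_size, 1)
        self._last_value = None

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        # "obs_flat" is the flattened observation provided by ModelV2.__call__().
        obs = input_dict["obs_flat"].float()
        self._last_value = self.value_net(obs).squeeze(1)
        return self.policy_net(obs), state

    def value_function(self) -> TensorType:
        # Value of the most recent forward pass, per the ModelV2 contract.
        return self._last_value


# A trainer config could then reference it via
# {"model": {"custom_model": "my_model"}}.
ModelCatalog.register_custom_model("my_model", MyTorchModel)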