[rllib] Type annotations for model classes (#9646)

Eric Liang authored on 2020-07-24 12:01:46 -07:00 (committed by GitHub)
parent 03709d67cb
commit 590943a499
11 changed files with 153 additions and 120 deletions


@@ -44,7 +44,7 @@ MAX_WORKER_FAILURE_RETRIES = 3
# yapf: disable
# __sphinx_doc_begin__
COMMON_CONFIG = {
COMMON_CONFIG: TrainerConfigDict = {
# === Settings for Rollout Worker processes ===
# Number of rollout worker actors to create for parallel sampling. Setting
# this to 0 will force rollouts to be done in the trainer actor.
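As a usage sketch of the new alias (not part of this diff; it assumes TrainerConfigDict is exported from ray.rllib.utils.types, as this hunk implies):

    from ray.rllib.utils.types import TrainerConfigDict

    # TrainerConfigDict is a plain dict alias, so annotated configs stay
    # ordinary dicts at runtime; these keys are standard COMMON_CONFIG keys.
    overrides: TrainerConfigDict = {
        "num_workers": 2,          # rollout workers for parallel sampling
        "train_batch_size": 4000,
    }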


@@ -1,4 +1,9 @@
import numpy as np
import gym
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.types import TensorType, List, Union, ModelConfigDict
@DeveloperAPI
@@ -11,7 +16,7 @@ class ActionDistribution:
"""
@DeveloperAPI
def __init__(self, inputs, model):
def __init__(self, inputs: List[TensorType], model: ModelV2):
"""Initialize the action dist.
Arguments:
@@ -25,12 +30,12 @@ class ActionDistribution:
self.model = model
@DeveloperAPI
def sample(self):
def sample(self) -> TensorType:
"""Draw a sample from the action distribution."""
raise NotImplementedError
@DeveloperAPI
def deterministic_sample(self):
def deterministic_sample(self) -> TensorType:
"""
Get the deterministic "sampling" output from the distribution.
This is usually the max likelihood output, i.e. mean for Normal, argmax
@@ -39,26 +44,26 @@ class ActionDistribution:
raise NotImplementedError
@DeveloperAPI
def sampled_action_logp(self):
def sampled_action_logp(self) -> TensorType:
"""Returns the log probability of the last sampled action."""
raise NotImplementedError
@DeveloperAPI
def logp(self, x):
def logp(self, x: TensorType) -> TensorType:
"""The log-likelihood of the action distribution."""
raise NotImplementedError
@DeveloperAPI
def kl(self, other):
def kl(self, other: "ActionDistribution") -> TensorType:
"""The KL-divergence between two action distributions."""
raise NotImplementedError
@DeveloperAPI
def entropy(self):
def entropy(self) -> TensorType:
"""The entropy of the action distribution."""
raise NotImplementedError
def multi_kl(self, other):
def multi_kl(self, other: "ActionDistribution") -> TensorType:
"""The KL-divergence between two action distributions.
This differs from kl() in that it can return an array for
@@ -66,7 +71,7 @@ class ActionDistribution:
"""
return self.kl(other)
def multi_entropy(self):
def multi_entropy(self) -> TensorType:
"""The entropy of the action distribution.
This differs from entropy() in that it can return an array for
@@ -76,7 +81,9 @@ class ActionDistribution:
@DeveloperAPI
@staticmethod
def required_model_output_shape(action_space, model_config):
def required_model_output_shape(
action_space: gym.Space,
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
"""Returns the required shape of an input parameter tensor for a
particular action space and an optional dict of distribution-specific
options.
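As a usage sketch of the annotated interface, here is a toy numpy-backed subclass; FixedCategorical and its math are illustrative, not part of this commit:

    import numpy as np

    from ray.rllib.models.action_dist import ActionDistribution
    from ray.rllib.utils.types import TensorType


    class FixedCategorical(ActionDistribution):
        """Toy categorical; `inputs` is a [batch, num_actions] logits array."""

        def sample(self) -> TensorType:
            # Softmax the logits row-wise, then draw one action per row.
            z = np.exp(self.inputs - self.inputs.max(-1, keepdims=True))
            probs = z / z.sum(-1, keepdims=True)
            return np.array([np.random.choice(len(p), p=p) for p in probs])

        def logp(self, x: TensorType) -> TensorType:
            # Log-softmax, then pick out the log-prob of each given action.
            log_z = np.log(np.exp(self.inputs).sum(-1, keepdims=True))
            log_probs = self.inputs - log_z
            return log_probs[np.arange(len(x)), x.astype(int)]

        def entropy(self) -> TensorType:
            z = np.exp(self.inputs - self.inputs.max(-1, keepdims=True))
            probs = z / z.sum(-1, keepdims=True)
            return -(probs * np.log(probs)).sum(-1)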


@@ -3,12 +3,13 @@ import gym
import logging
import numpy as np
import tree
from typing import List
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
RLLIB_ACTION_DIST, _global_registry
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.tf.lstm_v1 import LSTM
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
@@ -26,6 +27,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.types import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
MODEL_DEFAULTS = {
MODEL_DEFAULTS: ModelConfigDict = {
# === Built-in options ===
# Filter config. List of [out_channels, kernel, stride] for each filter
"conv_filters": None,
@@ -114,11 +116,11 @@ class ModelCatalog:
@staticmethod
@DeveloperAPI
def get_action_dist(action_space,
config,
dist_type=None,
framework="tf",
**kwargs):
def get_action_dist(action_space: gym.Space,
config: ModelConfigDict,
dist_type: str = None,
framework: str = "tf",
**kwargs) -> (type, int):
"""Returns a distribution class and size for the given action space.
Args:
@@ -209,7 +211,7 @@ class ModelCatalog:
@staticmethod
@DeveloperAPI
def get_action_shape(action_space):
def get_action_shape(action_space: gym.Space) -> (np.dtype, List[int]):
"""Returns action tensor dtype and shape for the action space.
Args:
@@ -243,7 +245,8 @@ class ModelCatalog:
@staticmethod
@DeveloperAPI
def get_action_placeholder(action_space, name="action"):
def get_action_placeholder(action_space: gym.Space,
name: str = "action") -> TensorType:
"""Returns an action placeholder consistent with the action space
Args:
@@ -260,15 +263,15 @@ class ModelCatalog:
@staticmethod
@DeveloperAPI
def get_model_v2(obs_space,
action_space,
num_outputs,
model_config,
framework="tf",
name="default_model",
model_interface=None,
default_model=None,
**model_kwargs):
def get_model_v2(obs_space: gym.Space,
action_space: gym.Space,
num_outputs: int,
model_config: ModelConfigDict,
framework: str = "tf",
name: str = "default_model",
model_interface: type = None,
default_model: type = None,
**model_kwargs) -> ModelV2:
"""Returns a suitable model compatible with given spaces and output.
Args:
@@ -420,7 +423,7 @@ class ModelCatalog:
@staticmethod
@DeveloperAPI
def get_preprocessor(env, options=None):
def get_preprocessor(env: gym.Env, options: dict = None) -> Preprocessor:
"""Returns a suitable preprocessor for the given env.
This is a wrapper for get_preprocessor_for_space().
@@ -431,7 +434,8 @@ class ModelCatalog:
@staticmethod
@DeveloperAPI
def get_preprocessor_for_space(observation_space, options=None):
def get_preprocessor_for_space(observation_space: gym.Space,
options: dict = None) -> Preprocessor:
"""Returns a suitable preprocessor for the given observation space.
Args:
@@ -469,7 +473,8 @@ class ModelCatalog:
@staticmethod
@PublicAPI
def register_custom_preprocessor(preprocessor_name, preprocessor_class):
def register_custom_preprocessor(preprocessor_name: str,
preprocessor_class: type) -> None:
"""Register a custom preprocessor class by name.
The preprocessor can be later used by specifying
@@ -484,7 +489,7 @@ class ModelCatalog:
@staticmethod
@PublicAPI
def register_custom_model(model_name, model_class):
def register_custom_model(model_name: str, model_class: type) -> None:
"""Register a custom model class by name.
The model can be later used by specifying {"custom_model": model_name}
@@ -498,7 +503,8 @@ class ModelCatalog:
@staticmethod
@PublicAPI
def register_custom_action_dist(action_dist_name, action_dist_class):
def register_custom_action_dist(action_dist_name: str,
action_dist_class: type) -> None:
"""Register a custom action distribution class by name.
The model can be later used by specifying
@@ -512,7 +518,7 @@ class ModelCatalog:
action_dist_class)
@staticmethod
def _wrap_if_needed(model_cls, model_interface):
def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
assert issubclass(model_cls, ModelV2), model_cls
if not model_interface or issubclass(model_cls, model_interface):
@@ -608,10 +614,3 @@ class ModelCatalog:
return FullyConnectedNetwork(input_dict, obs_space, action_space,
num_outputs, options)
@staticmethod
def get_torch_model(obs_space,
num_outputs,
options=None,
default_model_cls=None):
raise DeprecationWarning("Please use get_model_v2() instead.")
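A usage sketch of the annotated catalog entry points; the spaces and framework choice here are illustrative, not from the diff:

    import gym
    from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS

    obs_space = gym.spaces.Box(-1.0, 1.0, (4,))
    action_space = gym.spaces.Discrete(2)

    # get_action_dist() advertises its (type, int) return: the distribution
    # class plus the number of model outputs it needs as its input size.
    dist_cls, dist_dim = ModelCatalog.get_action_dist(
        action_space, MODEL_DEFAULTS, framework="tf")

    # get_model_v2() is annotated to return a ModelV2 instance.
    model = ModelCatalog.get_model_v2(
        obs_space, action_space, dist_dim, MODEL_DEFAULTS, framework="tf")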


@@ -1,6 +1,8 @@
from collections import OrderedDict
import contextlib
import gym
from typing import Dict
import numpy as np
from typing import Dict, List, Any, Union
from ray.rllib.models.preprocessors import get_preprocessor, \
RepeatedValuesPreprocessor
@@ -11,7 +13,7 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
TensorType
from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.utils.types import ModelConfigDict
from ray.rllib.utils.types import ModelConfigDict, TensorStructType
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -29,13 +31,9 @@ class ModelV2:
value_function() -> V(s)
"""
def __init__(self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
framework: str):
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str, framework: str):
"""Initializes a ModelV2 object.
This method should create any variables used by the model.
@@ -62,7 +60,7 @@ class ModelV2:
self._last_output = None
@PublicAPI
def get_initial_state(self):
def get_initial_state(self) -> List[np.ndarray]:
"""Get the initial recurrent state values for the model.
Returns:
@@ -79,7 +77,9 @@ class ModelV2:
return []
@PublicAPI
def forward(self, input_dict, state, seq_lens):
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]):
"""Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by
@@ -103,7 +103,7 @@
Returns:
(outputs, state): The model output tensor of size
[BATCH, num_outputs]
[BATCH, num_outputs], and the new RNN state.
Examples:
>>> def forward(self, input_dict, state, seq_lens):
@@ -114,7 +114,7 @@
raise NotImplementedError
@PublicAPI
def value_function(self):
def value_function(self) -> TensorType:
"""Returns the value function output for the most recent forward pass.
Note that a `forward` call has to be performed first, before this
@@ -127,7 +127,8 @@
raise NotImplementedError
@PublicAPI
def custom_loss(self, policy_loss, loss_inputs):
def custom_loss(self, policy_loss: TensorType,
loss_inputs: Dict[str, TensorType]) -> TensorType:
"""Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining
@@ -149,7 +150,7 @@
return policy_loss
@PublicAPI
def metrics(self):
def metrics(self) -> Dict[str, TensorType]:
"""Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e.,
@@ -164,7 +165,11 @@
"""
return {}
def __call__(self, input_dict, state=None, seq_lens=None):
def __call__(
self,
input_dict: Dict[str, TensorType],
state: List[Any] = None,
seq_lens: TensorType = None) -> (TensorType, List[TensorType]):
"""Call the model with the given input tensors and state.
This is the method used by RLlib to execute the forward pass. It calls
@@ -218,7 +223,8 @@
return outputs, state
@PublicAPI
def from_batch(self, train_batch, is_training=True):
def from_batch(self, train_batch: SampleBatch,
is_training: bool = True) -> (TensorType, List[TensorType]):
"""Convenience function that calls this model with a tensor batch.
All this does is unpack the tensor batch to call this model with the
@@ -241,8 +247,7 @@
return self.__call__(input_dict, states, train_batch.get("seq_lens"))
def get_view_requirements(
self,
is_training: bool = False) -> Dict[str, ViewRequirement]:
self, is_training: bool = False) -> Dict[str, ViewRequirement]:
"""Returns a list of ViewRequirements for this Model (or None).
Note: This is an experimental API method.
@@ -266,13 +271,13 @@
# Single requirement: Pass current obs as input.
return {
SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
SampleBatch.PREV_ACTIONS:
ViewRequirement(SampleBatch.ACTIONS, timesteps=-1),
SampleBatch.PREV_REWARDS:
ViewRequirement(SampleBatch.REWARDS, timesteps=-1),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, timesteps=-1),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, timesteps=-1),
}
def import_from_h5(self, h5_file):
def import_from_h5(self, h5_file: str) -> None:
"""Imports weights from an h5 file.
Args:
@@ -287,17 +292,18 @@
raise NotImplementedError
@PublicAPI
def last_output(self):
def last_output(self) -> TensorType:
"""Returns the last output returned from calling the model."""
return self._last_output
@PublicAPI
def context(self):
def context(self) -> contextlib.AbstractContextManager:
"""Returns a contextmanager for the current forward pass."""
return NullContextManager()
@PublicAPI
def variables(self, as_dict=False):
def variables(self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Returns the list (or a dict) of variables for this model.
Args:
@@ -311,7 +317,9 @@
raise NotImplementedError
@PublicAPI
def trainable_variables(self, as_dict=False):
def trainable_variables(
self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
"""Returns the list of trainable variables for this model.
Args:
@@ -340,7 +348,7 @@ class NullContextManager:
@DeveloperAPI
def flatten(obs, framework):
def flatten(obs: TensorType, framework: str) -> TensorType:
"""Flatten the given tensor."""
if framework in ["tf2", "tf", "tfe"]:
return tf1.keras.layers.Flatten()(obs)
@@ -354,7 +362,7 @@ def flatten(obs, framework):
@DeveloperAPI
def restore_original_dimensions(obs: TensorType,
obs_space: gym.spaces.Space,
tensorlib=tf):
tensorlib: Any = tf) -> TensorStructType:
"""Unpacks Dict and Tuple space observations into their original form.
This is needed since we flatten Dict and Tuple observations in transit
@@ -388,7 +396,8 @@ def restore_original_dimensions(obs: TensorType,
_cache = {}
def _unpack_obs(obs, space, tensorlib=tf):
def _unpack_obs(obs: TensorType, space: gym.Space,
tensorlib: Any = tf) -> TensorStructType:
"""Unpack a flattened Dict or Tuple observation array/tensor.
Args:
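An illustrative round trip through these helpers (that numpy is accepted as `tensorlib` is an assumption on my part; the diff only types it as Any):

    import gym
    import numpy as np

    from ray.rllib.models.modelv2 import restore_original_dimensions
    from ray.rllib.models.preprocessors import get_preprocessor

    space = gym.spaces.Dict({"pos": gym.spaces.Box(-1.0, 1.0, (2,)),
                             "vel": gym.spaces.Box(-1.0, 1.0, (3,))})
    prep = get_preprocessor(space)(space)
    flat = np.stack([prep.transform(space.sample()) for _ in range(8)])

    # prep.observation_space stashes the original space, so the flat [8, 5]
    # batch unpacks back into a TensorStructType dict of arrays.
    unpacked = restore_original_dimensions(
        flat, prep.observation_space, tensorlib=np)
    assert unpacked["pos"].shape == (8, 2)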


@@ -3,6 +3,7 @@ import cv2
import logging
import numpy as np
import gym
from typing import Any, List
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.spaces.repeated import Repeated
@@ -19,11 +20,11 @@ class Preprocessor:
"""Defines an abstract observation preprocessor function.
Attributes:
shape (obj): Shape of the preprocessed output.
shape (List[int]): Shape of the preprocessed output.
"""
@PublicAPI
def __init__(self, obs_space, options=None):
def __init__(self, obs_space: gym.Space, options: dict = None):
legacy_patch_shapes(obs_space)
self._obs_space = obs_space
if not options:
@@ -36,20 +37,20 @@
self._i = 0
@PublicAPI
def _init_shape(self, obs_space, options):
def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
"""Returns the shape after preprocessing."""
raise NotImplementedError
@PublicAPI
def transform(self, observation):
def transform(self, observation: Any) -> np.ndarray:
"""Returns the preprocessed observation."""
raise NotImplementedError
def write(self, observation, array, offset):
def write(self, observation: Any, array: np.ndarray, offset: int) -> None:
"""Alternative to transform for more efficient flattening."""
array[offset:offset + self._size] = self.transform(observation)
def check_shape(self, observation):
def check_shape(self, observation: Any) -> None:
"""Checks the shape of the given observation."""
if self._i % VALIDATION_INTERVAL == 0:
if type(observation) is list and isinstance(
@@ -69,12 +70,12 @@ class Preprocessor:
@property
@PublicAPI
def size(self):
def size(self) -> int:
return self._size
@property
@PublicAPI
def observation_space(self):
def observation_space(self) -> gym.Space:
obs_space = gym.spaces.Box(-1., 1., self.shape, dtype=np.float32)
# Stash the unwrapped space so that we can unwrap dict and tuple spaces
# automatically in model.py
@@ -286,7 +287,7 @@ class RepeatedValuesPreprocessor(Preprocessor):
@PublicAPI
def get_preprocessor(space):
def get_preprocessor(space: gym.Space) -> type:
"""Returns an appropriate preprocessor class for the given space."""
legacy_patch_shapes(space)
@@ -310,7 +311,7 @@ def get_preprocessor(space):
return preprocessor
def legacy_patch_shapes(space):
def legacy_patch_shapes(space: gym.Space) -> List[int]:
"""Assigns shapes to spaces that don't have shapes.
This is only needed for older gym versions that don't set shapes properly
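A short usage sketch of the now-annotated helpers (the Tuple space is illustrative):

    import gym
    from ray.rllib.models.preprocessors import get_preprocessor

    space = gym.spaces.Tuple(
        [gym.spaces.Discrete(3), gym.spaces.Box(-1.0, 1.0, (2,))])
    prep_cls = get_preprocessor(space)     # returns a Preprocessor subclass
    prep = prep_cls(space)
    flat = prep.transform(space.sample())  # np.ndarray: one-hot + box values
    assert flat.shape == tuple(prep.shape) == (5,)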


@@ -1,7 +1,7 @@
from typing import List
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.framework import TensorType
from ray.rllib.utils.framework import TensorType, TensorStructType
@PublicAPI
@@ -48,7 +48,7 @@ class RepeatedValues:
self.max_len = max_len
self._unbatched_repr = None
def unbatch_all(self):
def unbatch_all(self) -> List[List[TensorType]]:
"""Unbatch both the repeat and batch dimensions into Python lists.
This is only supported in PyTorch / TF eager mode.
@@ -64,7 +64,7 @@
>>> print(max(len(x) for x in items) <= N)
True
>>> print(items)
... [<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
... [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
... ...
... [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
... ...
@@ -96,7 +96,7 @@
return self._unbatched_repr
def unbatch_repeat_dim(self):
def unbatch_repeat_dim(self) -> List[TensorType]:
"""Unbatches the repeat dimension (the one `max_len` in size).
This removes the repeat dimension. The result will be a Python list of
@@ -120,7 +120,7 @@
return repr(self)
def _get_batch_dim_helper(v):
def _get_batch_dim_helper(v: TensorStructType) -> int:
"""Tries to find the batch dimension size of v, or None."""
if isinstance(v, dict):
for u in v.values():
@@ -136,7 +136,7 @@ def _get_batch_dim_helper(v):
return B
def _unbatch_helper(v, max_len):
def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
"""Recursively unpacks the repeat dimension (max_len)."""
if isinstance(v, dict):
return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
@@ -152,7 +152,8 @@ def _unbatch_helper(v, max_len):
return [v[:, i, ...] for i in range(max_len)]
def _batch_index_helper(v, i, j):
def _batch_index_helper(v: TensorStructType, i: int,
j: int) -> TensorStructType:
"""Selects the item at the ith batch index and jth repetition."""
if isinstance(v, dict):
return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}
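A usage sketch of the annotated unbatch_all() (shapes and data made up; eager-mode torch assumed, as the docstring requires):

    import torch
    from ray.rllib.models.repeated_values import RepeatedValues

    B, N, K = 2, 3, 4                    # batch size, max repeats, features
    values = torch.zeros(B, N, K)        # padded [B, N, K] tensor
    lengths = torch.tensor([1, 3])       # true repeat count per batch row
    rv = RepeatedValues(values, lengths, max_len=N)

    rows = rv.unbatch_all()              # -> List[List[TensorType]]
    assert [len(r) for r in rows] == [1, 3]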


@@ -4,11 +4,13 @@ import functools
import tree
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
SMALL_NUMBER
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.types import TensorType, List
tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
@@ -18,14 +20,14 @@ tfp = try_import_tfp()
class TFActionDistribution(ActionDistribution):
"""TF-specific extensions for building action distributions."""
@DeveloperAPI
def __init__(self, inputs, model):
@override(ActionDistribution)
def __init__(self, inputs: List[TensorType], model: ModelV2):
super().__init__(inputs, model)
self.sample_op = self._build_sample_op()
self.sampled_action_logp_op = self.logp(self.sample_op)
@DeveloperAPI
def _build_sample_op(self):
def _build_sample_op(self) -> TensorType:
"""Implement this instead of sample(), to enable op reuse.
This is needed since the sample op is non-deterministic and is shared
@@ -34,12 +36,12 @@ class TFActionDistribution(ActionDistribution):
raise NotImplementedError
@override(ActionDistribution)
def sample(self):
def sample(self) -> TensorType:
"""Draw a sample from the action distribution."""
return self.sample_op
@override(ActionDistribution)
def sampled_action_logp(self):
def sampled_action_logp(self) -> TensorType:
"""Returns the log probability of the sampled action."""
return self.sampled_action_logp_op
@@ -242,9 +244,8 @@ class DiagGaussian(TFActionDistribution):
assert isinstance(other, DiagGaussian)
return tf.reduce_sum(
other.log_std - self.log_std +
(tf.math.square(self.std) +
tf.math.square(self.mean - other.mean)) /
(2.0 * tf.math.square(other.std)) - 0.5,
(tf.math.square(self.std) + tf.math.square(self.mean - other.mean))
/ (2.0 * tf.math.square(other.std)) - 0.5,
axis=1)
@override(ActionDistribution)
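As a sketch of the op-reuse pattern these annotations document, a hypothetical unit-variance Gaussian subclass (not part of the diff; logp is only correct up to an additive constant):

    from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
    from ray.rllib.utils.framework import try_import_tf
    from ray.rllib.utils.types import TensorType

    tf1, tf, tfv = try_import_tf()


    class UnitGaussian(TFActionDistribution):
        """Mean given by `inputs`, fixed unit variance (illustrative)."""

        def _build_sample_op(self) -> TensorType:
            # Built once in __init__; sample() and sampled_action_logp()
            # both reuse this single stochastic op.
            return self.inputs + tf.random.normal(tf.shape(self.inputs))

        def logp(self, x: TensorType) -> TensorType:
            # Needed by __init__, which computes logp(self.sample_op).
            return -0.5 * tf.reduce_sum(tf.math.square(x - self.inputs),
                                        axis=1)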


@@ -1,6 +1,11 @@
import contextlib
import gym
from typing import List
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.types import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
@@ -12,8 +17,9 @@ class TFModelV2(ModelV2):
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str):
"""Initialize a TFModelV2.
Here is an example implementation for a subclass
@@ -44,31 +50,31 @@
else:
self.graph = tf1.get_default_graph()
def context(self):
def context(self) -> contextlib.AbstractContextManager:
"""Returns a contextmanager for the current TF graph."""
if self.graph:
return self.graph.as_default()
else:
return ModelV2.context(self)
def update_ops(self):
def update_ops(self) -> List[TensorType]:
"""Return the list of update ops for this model.
For example, this should include any BatchNorm update ops."""
return []
def register_variables(self, variables):
def register_variables(self, variables: List[TensorType]) -> None:
"""Register the given list of variables with this model."""
self.var_list.extend(variables)
@override(ModelV2)
def variables(self, as_dict=False):
def variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict:
return {v.name: v for v in self.var_list}
return list(self.var_list)
@override(ModelV2)
def trainable_variables(self, as_dict=False):
def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict:
return {
k: v
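A sketch of a Keras-based subclass wired through the annotated register_variables(); the layer sizes are arbitrary, following the pattern the class docstring describes:

    import gym
    from ray.rllib.models.tf.tf_modelv2 import TFModelV2
    from ray.rllib.utils.framework import try_import_tf
    from ray.rllib.utils.types import ModelConfigDict

    tf1, tf, tfv = try_import_tf()


    class MyTFModel(TFModelV2):
        def __init__(self, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space, num_outputs: int,
                     model_config: ModelConfigDict, name: str):
            super().__init__(obs_space, action_space, num_outputs,
                             model_config, name)
            inputs = tf.keras.layers.Input(shape=obs_space.shape)
            hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
            logits = tf.keras.layers.Dense(num_outputs)(hidden)
            value = tf.keras.layers.Dense(1)(hidden)
            self.base_model = tf.keras.Model(inputs, [logits, value])
            self.register_variables(self.base_model.variables)

        def forward(self, input_dict, state, seq_lens):
            logits, self._value = self.base_model(input_dict["obs_flat"])
            return logits, state

        def value_function(self):
            return tf.reshape(self._value, [-1])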


@@ -4,12 +4,14 @@ import numpy as np
import tree
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
MAX_LOG_NN_OUTPUT
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.torch_ops import atanh
from ray.rllib.utils.types import TensorType, List
torch, nn = try_import_torch()
@@ -18,7 +20,7 @@ class TorchDistributionWrapper(ActionDistribution):
"""Wrapper class for torch.distributions."""
@override(ActionDistribution)
def __init__(self, inputs, model):
def __init__(self, inputs: List[TensorType], model: ModelV2):
if not isinstance(inputs, torch.Tensor):
inputs = torch.Tensor(inputs)
super().__init__(inputs, model)
@@ -26,24 +28,24 @@ class TorchDistributionWrapper(ActionDistribution):
self.last_sample = None
@override(ActionDistribution)
def logp(self, actions):
def logp(self, actions: TensorType) -> TensorType:
return self.dist.log_prob(actions)
@override(ActionDistribution)
def entropy(self):
def entropy(self) -> TensorType:
return self.dist.entropy()
@override(ActionDistribution)
def kl(self, other):
def kl(self, other: ActionDistribution) -> TensorType:
return torch.distributions.kl.kl_divergence(self.dist, other.dist)
@override(ActionDistribution)
def sample(self):
def sample(self) -> TensorType:
self.last_sample = self.dist.sample()
return self.last_sample
@override(ActionDistribution)
def sampled_action_logp(self):
def sampled_action_logp(self) -> TensorType:
assert self.last_sample is not None
return self.logp(self.last_sample)
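A usage sketch with the built-in TorchCategorical wrapper subclass (the logits values and model=None are illustrative):

    import torch
    from ray.rllib.models.torch.torch_action_dist import TorchCategorical

    logits = torch.zeros(4, 3)   # batch of 4 rows, 3 actions each
    dist = TorchCategorical(logits, None)
    a = dist.sample()            # TensorType of shape [4]
    print(dist.logp(a))          # log-prob of each sampled action
    print(dist.entropy())        # per-row entropy, shape [4]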


@@ -1,6 +1,10 @@
import gym
from typing import List
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.types import ModelConfigDict, TensorType
_, nn = try_import_torch()
@@ -12,8 +16,9 @@ class TorchModelV2(ModelV2):
Note that this class by itself is not a valid model unless you
inherit from nn.Module and implement forward() in a subclass."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
def __init__(self, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, num_outputs: int,
model_config: ModelConfigDict, name: str):
"""Initialize a TorchModelV2.
Here is an example implementation for a subclass
@@ -42,13 +47,13 @@
framework="torch")
@override(ModelV2)
def variables(self, as_dict=False):
def variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict:
return self.state_dict()
return list(self.parameters())
@override(ModelV2)
def trainable_variables(self, as_dict=False):
def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict:
return {
k: v
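And the torch-side analogue: a sketch of the dual-inheritance pattern the class docstring requires (hidden sizes arbitrary, not part of the diff):

    import gym
    import numpy as np

    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
    from ray.rllib.utils.framework import try_import_torch
    from ray.rllib.utils.types import ModelConfigDict

    torch, nn = try_import_torch()


    class MyTorchModel(TorchModelV2, nn.Module):
        def __init__(self, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space, num_outputs: int,
                     model_config: ModelConfigDict, name: str):
            TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                                  model_config, name)
            nn.Module.__init__(self)
            in_size = int(np.product(obs_space.shape))
            self.pi = nn.Linear(in_size, num_outputs)
            self.vf = nn.Linear(in_size, 1)

        def forward(self, input_dict, state, seq_lens):
            obs = input_dict["obs_flat"].float()
            self._value = self.vf(obs).squeeze(1)
            return self.pi(obs), state

        def value_function(self):
            return self._value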


@@ -1,8 +1,11 @@
import numpy as np
from typing import Dict, Optional
from typing import Dict, Optional, List, TYPE_CHECKING
from ray.rllib.utils.types import TensorType
if TYPE_CHECKING:
from ray.rllib.models import ModelV2
class ViewRequirement:
"""Single view requirement (for one column in a ModelV2 input_dict).
@@ -57,9 +60,8 @@ class ViewRequirement:
# "time_major",
def get_trajectory_view(
model,
trajectories,
def get_trajectory_view(model: "ModelV2",
trajectories: List["Trajectory"],
is_training: bool = False) -> Dict[str, TensorType]:
"""Returns an input_dict for a Model's forward pass given some data.