mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Type annotations for model classes (#9646)
This commit is contained in:
parent
03709d67cb
commit
590943a499
11 changed files with 153 additions and 120 deletions
|
@ -44,7 +44,7 @@ MAX_WORKER_FAILURE_RETRIES = 3
|
||||||
|
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
# __sphinx_doc_begin__
|
# __sphinx_doc_begin__
|
||||||
COMMON_CONFIG = {
|
COMMON_CONFIG: TrainerConfigDict = {
|
||||||
# === Settings for Rollout Worker processes ===
|
# === Settings for Rollout Worker processes ===
|
||||||
# Number of rollout worker actors to create for parallel sampling. Setting
|
# Number of rollout worker actors to create for parallel sampling. Setting
|
||||||
# this to 0 will force rollouts to be done in the trainer actor.
|
# this to 0 will force rollouts to be done in the trainer actor.
|
||||||
|
|
|
@ -1,4 +1,9 @@
|
||||||
|
import numpy as np
|
||||||
|
import gym
|
||||||
|
|
||||||
|
from ray.rllib.models.modelv2 import ModelV2
|
||||||
from ray.rllib.utils.annotations import DeveloperAPI
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.types import TensorType, List, Union, ModelConfigDict
|
||||||
|
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
|
@ -11,7 +16,7 @@ class ActionDistribution:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def __init__(self, inputs, model):
|
def __init__(self, inputs: List[TensorType], model: ModelV2):
|
||||||
"""Initialize the action dist.
|
"""Initialize the action dist.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
@ -25,12 +30,12 @@ class ActionDistribution:
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def sample(self):
|
def sample(self) -> TensorType:
|
||||||
"""Draw a sample from the action distribution."""
|
"""Draw a sample from the action distribution."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def deterministic_sample(self):
|
def deterministic_sample(self) -> TensorType:
|
||||||
"""
|
"""
|
||||||
Get the deterministic "sampling" output from the distribution.
|
Get the deterministic "sampling" output from the distribution.
|
||||||
This is usually the max likelihood output, i.e. mean for Normal, argmax
|
This is usually the max likelihood output, i.e. mean for Normal, argmax
|
||||||
|
@ -39,26 +44,26 @@ class ActionDistribution:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def sampled_action_logp(self):
|
def sampled_action_logp(self) -> TensorType:
|
||||||
"""Returns the log probability of the last sampled action."""
|
"""Returns the log probability of the last sampled action."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def logp(self, x):
|
def logp(self, x: TensorType) -> TensorType:
|
||||||
"""The log-likelihood of the action distribution."""
|
"""The log-likelihood of the action distribution."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def kl(self, other):
|
def kl(self, other: "ActionDistribution") -> TensorType:
|
||||||
"""The KL-divergence between two action distributions."""
|
"""The KL-divergence between two action distributions."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def entropy(self):
|
def entropy(self) -> TensorType:
|
||||||
"""The entropy of the action distribution."""
|
"""The entropy of the action distribution."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def multi_kl(self, other):
|
def multi_kl(self, other: "ActionDistribution") -> TensorType:
|
||||||
"""The KL-divergence between two action distributions.
|
"""The KL-divergence between two action distributions.
|
||||||
|
|
||||||
This differs from kl() in that it can return an array for
|
This differs from kl() in that it can return an array for
|
||||||
|
@ -66,7 +71,7 @@ class ActionDistribution:
|
||||||
"""
|
"""
|
||||||
return self.kl(other)
|
return self.kl(other)
|
||||||
|
|
||||||
def multi_entropy(self):
|
def multi_entropy(self) -> TensorType:
|
||||||
"""The entropy of the action distribution.
|
"""The entropy of the action distribution.
|
||||||
|
|
||||||
This differs from entropy() in that it can return an array for
|
This differs from entropy() in that it can return an array for
|
||||||
|
@ -76,7 +81,9 @@ class ActionDistribution:
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def required_model_output_shape(action_space, model_config):
|
def required_model_output_shape(
|
||||||
|
action_space: gym.Space,
|
||||||
|
model_config: ModelConfigDict) -> Union[int, np.ndarray]:
|
||||||
"""Returns the required shape of an input parameter tensor for a
|
"""Returns the required shape of an input parameter tensor for a
|
||||||
particular action space and an optional dict of distribution-specific
|
particular action space and an optional dict of distribution-specific
|
||||||
options.
|
options.
|
||||||
|
|
|
@ -3,12 +3,13 @@ import gym
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tree
|
import tree
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
||||||
RLLIB_ACTION_DIST, _global_registry
|
RLLIB_ACTION_DIST, _global_registry
|
||||||
from ray.rllib.models.action_dist import ActionDistribution
|
from ray.rllib.models.action_dist import ActionDistribution
|
||||||
from ray.rllib.models.modelv2 import ModelV2
|
from ray.rllib.models.modelv2 import ModelV2
|
||||||
from ray.rllib.models.preprocessors import get_preprocessor
|
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
|
||||||
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
|
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
|
||||||
from ray.rllib.models.tf.lstm_v1 import LSTM
|
from ray.rllib.models.tf.lstm_v1 import LSTM
|
||||||
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
|
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
|
||||||
|
@ -26,6 +27,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
|
||||||
from ray.rllib.utils.framework import try_import_tf
|
from ray.rllib.utils.framework import try_import_tf
|
||||||
from ray.rllib.utils.spaces.simplex import Simplex
|
from ray.rllib.utils.spaces.simplex import Simplex
|
||||||
from ray.rllib.utils.spaces.space_utils import flatten_space
|
from ray.rllib.utils.spaces.space_utils import flatten_space
|
||||||
|
from ray.rllib.utils.types import ModelConfigDict, TensorType
|
||||||
|
|
||||||
tf1, tf, tfv = try_import_tf()
|
tf1, tf, tfv = try_import_tf()
|
||||||
|
|
||||||
|
@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
# __sphinx_doc_begin__
|
# __sphinx_doc_begin__
|
||||||
MODEL_DEFAULTS = {
|
MODEL_DEFAULTS: ModelConfigDict = {
|
||||||
# === Built-in options ===
|
# === Built-in options ===
|
||||||
# Filter config. List of [out_channels, kernel, stride] for each filter
|
# Filter config. List of [out_channels, kernel, stride] for each filter
|
||||||
"conv_filters": None,
|
"conv_filters": None,
|
||||||
|
@ -114,11 +116,11 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_action_dist(action_space,
|
def get_action_dist(action_space: gym.Space,
|
||||||
config,
|
config: ModelConfigDict,
|
||||||
dist_type=None,
|
dist_type: str = None,
|
||||||
framework="tf",
|
framework: str = "tf",
|
||||||
**kwargs):
|
**kwargs) -> (type, int):
|
||||||
"""Returns a distribution class and size for the given action space.
|
"""Returns a distribution class and size for the given action space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -209,7 +211,7 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_action_shape(action_space):
|
def get_action_shape(action_space: gym.Space) -> (np.dtype, List[int]):
|
||||||
"""Returns action tensor dtype and shape for the action space.
|
"""Returns action tensor dtype and shape for the action space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -243,7 +245,8 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_action_placeholder(action_space, name="action"):
|
def get_action_placeholder(action_space: gym.Space,
|
||||||
|
name: str = "action") -> TensorType:
|
||||||
"""Returns an action placeholder consistent with the action space
|
"""Returns an action placeholder consistent with the action space
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -260,15 +263,15 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_model_v2(obs_space,
|
def get_model_v2(obs_space: gym.Space,
|
||||||
action_space,
|
action_space: gym.Space,
|
||||||
num_outputs,
|
num_outputs: int,
|
||||||
model_config,
|
model_config: ModelConfigDict,
|
||||||
framework="tf",
|
framework: str = "tf",
|
||||||
name="default_model",
|
name: str = "default_model",
|
||||||
model_interface=None,
|
model_interface: type = None,
|
||||||
default_model=None,
|
default_model: type = None,
|
||||||
**model_kwargs):
|
**model_kwargs) -> ModelV2:
|
||||||
"""Returns a suitable model compatible with given spaces and output.
|
"""Returns a suitable model compatible with given spaces and output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -420,7 +423,7 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_preprocessor(env, options=None):
|
def get_preprocessor(env: gym.Env, options: dict = None) -> Preprocessor:
|
||||||
"""Returns a suitable preprocessor for the given env.
|
"""Returns a suitable preprocessor for the given env.
|
||||||
|
|
||||||
This is a wrapper for get_preprocessor_for_space().
|
This is a wrapper for get_preprocessor_for_space().
|
||||||
|
@ -431,7 +434,8 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_preprocessor_for_space(observation_space, options=None):
|
def get_preprocessor_for_space(observation_space: gym.Space,
|
||||||
|
options: dict = None) -> Preprocessor:
|
||||||
"""Returns a suitable preprocessor for the given observation space.
|
"""Returns a suitable preprocessor for the given observation space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -469,7 +473,8 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def register_custom_preprocessor(preprocessor_name, preprocessor_class):
|
def register_custom_preprocessor(preprocessor_name: str,
|
||||||
|
preprocessor_class: type) -> None:
|
||||||
"""Register a custom preprocessor class by name.
|
"""Register a custom preprocessor class by name.
|
||||||
|
|
||||||
The preprocessor can be later used by specifying
|
The preprocessor can be later used by specifying
|
||||||
|
@ -484,7 +489,7 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def register_custom_model(model_name, model_class):
|
def register_custom_model(model_name: str, model_class: type) -> None:
|
||||||
"""Register a custom model class by name.
|
"""Register a custom model class by name.
|
||||||
|
|
||||||
The model can be later used by specifying {"custom_model": model_name}
|
The model can be later used by specifying {"custom_model": model_name}
|
||||||
|
@ -498,7 +503,8 @@ class ModelCatalog:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def register_custom_action_dist(action_dist_name, action_dist_class):
|
def register_custom_action_dist(action_dist_name: str,
|
||||||
|
action_dist_class: type) -> None:
|
||||||
"""Register a custom action distribution class by name.
|
"""Register a custom action distribution class by name.
|
||||||
|
|
||||||
The model can be later used by specifying
|
The model can be later used by specifying
|
||||||
|
@ -512,7 +518,7 @@ class ModelCatalog:
|
||||||
action_dist_class)
|
action_dist_class)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _wrap_if_needed(model_cls, model_interface):
|
def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
|
||||||
assert issubclass(model_cls, ModelV2), model_cls
|
assert issubclass(model_cls, ModelV2), model_cls
|
||||||
|
|
||||||
if not model_interface or issubclass(model_cls, model_interface):
|
if not model_interface or issubclass(model_cls, model_interface):
|
||||||
|
@ -608,10 +614,3 @@ class ModelCatalog:
|
||||||
|
|
||||||
return FullyConnectedNetwork(input_dict, obs_space, action_space,
|
return FullyConnectedNetwork(input_dict, obs_space, action_space,
|
||||||
num_outputs, options)
|
num_outputs, options)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_torch_model(obs_space,
|
|
||||||
num_outputs,
|
|
||||||
options=None,
|
|
||||||
default_model_cls=None):
|
|
||||||
raise DeprecationWarning("Please use get_model_v2() instead.")
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
import contextlib
|
||||||
import gym
|
import gym
|
||||||
from typing import Dict
|
import numpy as np
|
||||||
|
from typing import Dict, List, Any, Union
|
||||||
|
|
||||||
from ray.rllib.models.preprocessors import get_preprocessor, \
|
from ray.rllib.models.preprocessors import get_preprocessor, \
|
||||||
RepeatedValuesPreprocessor
|
RepeatedValuesPreprocessor
|
||||||
|
@ -11,7 +13,7 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
||||||
TensorType
|
TensorType
|
||||||
from ray.rllib.utils.spaces.repeated import Repeated
|
from ray.rllib.utils.spaces.repeated import Repeated
|
||||||
from ray.rllib.utils.types import ModelConfigDict
|
from ray.rllib.utils.types import ModelConfigDict, TensorStructType
|
||||||
|
|
||||||
tf1, tf, tfv = try_import_tf()
|
tf1, tf, tfv = try_import_tf()
|
||||||
torch, _ = try_import_torch()
|
torch, _ = try_import_torch()
|
||||||
|
@ -29,13 +31,9 @@ class ModelV2:
|
||||||
value_function() -> V(s)
|
value_function() -> V(s)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, obs_space: gym.spaces.Space,
|
||||||
obs_space: gym.spaces.Space,
|
action_space: gym.spaces.Space, num_outputs: int,
|
||||||
action_space: gym.spaces.Space,
|
model_config: ModelConfigDict, name: str, framework: str):
|
||||||
num_outputs: int,
|
|
||||||
model_config: ModelConfigDict,
|
|
||||||
name: str,
|
|
||||||
framework: str):
|
|
||||||
"""Initializes a ModelV2 object.
|
"""Initializes a ModelV2 object.
|
||||||
|
|
||||||
This method should create any variables used by the model.
|
This method should create any variables used by the model.
|
||||||
|
@ -62,7 +60,7 @@ class ModelV2:
|
||||||
self._last_output = None
|
self._last_output = None
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def get_initial_state(self):
|
def get_initial_state(self) -> List[np.ndarray]:
|
||||||
"""Get the initial recurrent state values for the model.
|
"""Get the initial recurrent state values for the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -79,7 +77,9 @@ class ModelV2:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def forward(self, input_dict, state, seq_lens):
|
def forward(self, input_dict: Dict[str, TensorType],
|
||||||
|
state: List[TensorType],
|
||||||
|
seq_lens: TensorType) -> (TensorType, List[TensorType]):
|
||||||
"""Call the model with the given input tensors and state.
|
"""Call the model with the given input tensors and state.
|
||||||
|
|
||||||
Any complex observations (dicts, tuples, etc.) will be unpacked by
|
Any complex observations (dicts, tuples, etc.) will be unpacked by
|
||||||
|
@ -103,7 +103,7 @@ class ModelV2:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(outputs, state): The model output tensor of size
|
(outputs, state): The model output tensor of size
|
||||||
[BATCH, num_outputs]
|
[BATCH, num_outputs], and the new RNN state.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> def forward(self, input_dict, state, seq_lens):
|
>>> def forward(self, input_dict, state, seq_lens):
|
||||||
|
@ -114,7 +114,7 @@ class ModelV2:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def value_function(self):
|
def value_function(self) -> TensorType:
|
||||||
"""Returns the value function output for the most recent forward pass.
|
"""Returns the value function output for the most recent forward pass.
|
||||||
|
|
||||||
Note that a `forward` call has to be performed first, before this
|
Note that a `forward` call has to be performed first, before this
|
||||||
|
@ -127,7 +127,8 @@ class ModelV2:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def custom_loss(self, policy_loss, loss_inputs):
|
def custom_loss(self, policy_loss: TensorType,
|
||||||
|
loss_inputs: Dict[str, TensorType]) -> TensorType:
|
||||||
"""Override to customize the loss function used to optimize this model.
|
"""Override to customize the loss function used to optimize this model.
|
||||||
|
|
||||||
This can be used to incorporate self-supervised losses (by defining
|
This can be used to incorporate self-supervised losses (by defining
|
||||||
|
@ -149,7 +150,7 @@ class ModelV2:
|
||||||
return policy_loss
|
return policy_loss
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def metrics(self):
|
def metrics(self) -> Dict[str, TensorType]:
|
||||||
"""Override to return custom metrics from your model.
|
"""Override to return custom metrics from your model.
|
||||||
|
|
||||||
The stats will be reported as part of the learner stats, i.e.,
|
The stats will be reported as part of the learner stats, i.e.,
|
||||||
|
@ -164,7 +165,11 @@ class ModelV2:
|
||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def __call__(self, input_dict, state=None, seq_lens=None):
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_dict: Dict[str, TensorType],
|
||||||
|
state: List[Any] = None,
|
||||||
|
seq_lens: TensorType = None) -> (TensorType, List[TensorType]):
|
||||||
"""Call the model with the given input tensors and state.
|
"""Call the model with the given input tensors and state.
|
||||||
|
|
||||||
This is the method used by RLlib to execute the forward pass. It calls
|
This is the method used by RLlib to execute the forward pass. It calls
|
||||||
|
@ -218,7 +223,8 @@ class ModelV2:
|
||||||
return outputs, state
|
return outputs, state
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def from_batch(self, train_batch, is_training=True):
|
def from_batch(self, train_batch: SampleBatch,
|
||||||
|
is_training: bool = True) -> (TensorType, List[TensorType]):
|
||||||
"""Convenience function that calls this model with a tensor batch.
|
"""Convenience function that calls this model with a tensor batch.
|
||||||
|
|
||||||
All this does is unpack the tensor batch to call this model with the
|
All this does is unpack the tensor batch to call this model with the
|
||||||
|
@ -241,8 +247,7 @@ class ModelV2:
|
||||||
return self.__call__(input_dict, states, train_batch.get("seq_lens"))
|
return self.__call__(input_dict, states, train_batch.get("seq_lens"))
|
||||||
|
|
||||||
def get_view_requirements(
|
def get_view_requirements(
|
||||||
self,
|
self, is_training: bool = False) -> Dict[str, ViewRequirement]:
|
||||||
is_training: bool = False) -> Dict[str, ViewRequirement]:
|
|
||||||
"""Returns a list of ViewRequirements for this Model (or None).
|
"""Returns a list of ViewRequirements for this Model (or None).
|
||||||
|
|
||||||
Note: This is an experimental API method.
|
Note: This is an experimental API method.
|
||||||
|
@ -266,13 +271,13 @@ class ModelV2:
|
||||||
# Single requirement: Pass current obs as input.
|
# Single requirement: Pass current obs as input.
|
||||||
return {
|
return {
|
||||||
SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
|
SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
|
||||||
SampleBatch.PREV_ACTIONS:
|
SampleBatch.PREV_ACTIONS: ViewRequirement(
|
||||||
ViewRequirement(SampleBatch.ACTIONS, timesteps=-1),
|
SampleBatch.ACTIONS, timesteps=-1),
|
||||||
SampleBatch.PREV_REWARDS:
|
SampleBatch.PREV_REWARDS: ViewRequirement(
|
||||||
ViewRequirement(SampleBatch.REWARDS, timesteps=-1),
|
SampleBatch.REWARDS, timesteps=-1),
|
||||||
}
|
}
|
||||||
|
|
||||||
def import_from_h5(self, h5_file):
|
def import_from_h5(self, h5_file: str) -> None:
|
||||||
"""Imports weights from an h5 file.
|
"""Imports weights from an h5 file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -287,17 +292,18 @@ class ModelV2:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def last_output(self):
|
def last_output(self) -> TensorType:
|
||||||
"""Returns the last output returned from calling the model."""
|
"""Returns the last output returned from calling the model."""
|
||||||
return self._last_output
|
return self._last_output
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def context(self):
|
def context(self) -> contextlib.AbstractContextManager:
|
||||||
"""Returns a contextmanager for the current forward pass."""
|
"""Returns a contextmanager for the current forward pass."""
|
||||||
return NullContextManager()
|
return NullContextManager()
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def variables(self, as_dict=False):
|
def variables(self, as_dict: bool = False
|
||||||
|
) -> Union[List[TensorType], Dict[str, TensorType]]:
|
||||||
"""Returns the list (or a dict) of variables for this model.
|
"""Returns the list (or a dict) of variables for this model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -311,7 +317,9 @@ class ModelV2:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def trainable_variables(self, as_dict=False):
|
def trainable_variables(
|
||||||
|
self, as_dict: bool = False
|
||||||
|
) -> Union[List[TensorType], Dict[str, TensorType]]:
|
||||||
"""Returns the list of trainable variables for this model.
|
"""Returns the list of trainable variables for this model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -340,7 +348,7 @@ class NullContextManager:
|
||||||
|
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def flatten(obs, framework):
|
def flatten(obs: TensorType, framework: str) -> TensorType:
|
||||||
"""Flatten the given tensor."""
|
"""Flatten the given tensor."""
|
||||||
if framework in ["tf2", "tf", "tfe"]:
|
if framework in ["tf2", "tf", "tfe"]:
|
||||||
return tf1.keras.layers.Flatten()(obs)
|
return tf1.keras.layers.Flatten()(obs)
|
||||||
|
@ -354,7 +362,7 @@ def flatten(obs, framework):
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def restore_original_dimensions(obs: TensorType,
|
def restore_original_dimensions(obs: TensorType,
|
||||||
obs_space: gym.spaces.Space,
|
obs_space: gym.spaces.Space,
|
||||||
tensorlib=tf):
|
tensorlib: Any = tf) -> TensorStructType:
|
||||||
"""Unpacks Dict and Tuple space observations into their original form.
|
"""Unpacks Dict and Tuple space observations into their original form.
|
||||||
|
|
||||||
This is needed since we flatten Dict and Tuple observations in transit
|
This is needed since we flatten Dict and Tuple observations in transit
|
||||||
|
@ -388,7 +396,8 @@ def restore_original_dimensions(obs: TensorType,
|
||||||
_cache = {}
|
_cache = {}
|
||||||
|
|
||||||
|
|
||||||
def _unpack_obs(obs, space, tensorlib=tf):
|
def _unpack_obs(obs: TensorType, space: gym.Space,
|
||||||
|
tensorlib: Any = tf) -> TensorStructType:
|
||||||
"""Unpack a flattened Dict or Tuple observation array/tensor.
|
"""Unpack a flattened Dict or Tuple observation array/tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -3,6 +3,7 @@ import cv2
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
from ray.rllib.utils.annotations import override, PublicAPI
|
from ray.rllib.utils.annotations import override, PublicAPI
|
||||||
from ray.rllib.utils.spaces.repeated import Repeated
|
from ray.rllib.utils.spaces.repeated import Repeated
|
||||||
|
@ -19,11 +20,11 @@ class Preprocessor:
|
||||||
"""Defines an abstract observation preprocessor function.
|
"""Defines an abstract observation preprocessor function.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
shape (obj): Shape of the preprocessed output.
|
shape (List[int]): Shape of the preprocessed output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def __init__(self, obs_space, options=None):
|
def __init__(self, obs_space: gym.Space, options: dict = None):
|
||||||
legacy_patch_shapes(obs_space)
|
legacy_patch_shapes(obs_space)
|
||||||
self._obs_space = obs_space
|
self._obs_space = obs_space
|
||||||
if not options:
|
if not options:
|
||||||
|
@ -36,20 +37,20 @@ class Preprocessor:
|
||||||
self._i = 0
|
self._i = 0
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def _init_shape(self, obs_space, options):
|
def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
|
||||||
"""Returns the shape after preprocessing."""
|
"""Returns the shape after preprocessing."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def transform(self, observation):
|
def transform(self, observation: Any) -> np.ndarray:
|
||||||
"""Returns the preprocessed observation."""
|
"""Returns the preprocessed observation."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def write(self, observation, array, offset):
|
def write(self, observation: Any, array: np.ndarray, offset: int) -> None:
|
||||||
"""Alternative to transform for more efficient flattening."""
|
"""Alternative to transform for more efficient flattening."""
|
||||||
array[offset:offset + self._size] = self.transform(observation)
|
array[offset:offset + self._size] = self.transform(observation)
|
||||||
|
|
||||||
def check_shape(self, observation):
|
def check_shape(self, observation: Any) -> None:
|
||||||
"""Checks the shape of the given observation."""
|
"""Checks the shape of the given observation."""
|
||||||
if self._i % VALIDATION_INTERVAL == 0:
|
if self._i % VALIDATION_INTERVAL == 0:
|
||||||
if type(observation) is list and isinstance(
|
if type(observation) is list and isinstance(
|
||||||
|
@ -69,12 +70,12 @@ class Preprocessor:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def size(self):
|
def size(self) -> int:
|
||||||
return self._size
|
return self._size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def observation_space(self):
|
def observation_space(self) -> gym.Space:
|
||||||
obs_space = gym.spaces.Box(-1., 1., self.shape, dtype=np.float32)
|
obs_space = gym.spaces.Box(-1., 1., self.shape, dtype=np.float32)
|
||||||
# Stash the unwrapped space so that we can unwrap dict and tuple spaces
|
# Stash the unwrapped space so that we can unwrap dict and tuple spaces
|
||||||
# automatically in model.py
|
# automatically in model.py
|
||||||
|
@ -286,7 +287,7 @@ class RepeatedValuesPreprocessor(Preprocessor):
|
||||||
|
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def get_preprocessor(space):
|
def get_preprocessor(space: gym.Space) -> type:
|
||||||
"""Returns an appropriate preprocessor class for the given space."""
|
"""Returns an appropriate preprocessor class for the given space."""
|
||||||
|
|
||||||
legacy_patch_shapes(space)
|
legacy_patch_shapes(space)
|
||||||
|
@ -310,7 +311,7 @@ def get_preprocessor(space):
|
||||||
return preprocessor
|
return preprocessor
|
||||||
|
|
||||||
|
|
||||||
def legacy_patch_shapes(space):
|
def legacy_patch_shapes(space: gym.Space) -> List[int]:
|
||||||
"""Assigns shapes to spaces that don't have shapes.
|
"""Assigns shapes to spaces that don't have shapes.
|
||||||
|
|
||||||
This is only needed for older gym versions that don't set shapes properly
|
This is only needed for older gym versions that don't set shapes properly
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from ray.rllib.utils.annotations import PublicAPI
|
from ray.rllib.utils.annotations import PublicAPI
|
||||||
from ray.rllib.utils.framework import TensorType
|
from ray.rllib.utils.framework import TensorType, TensorStructType
|
||||||
|
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
|
@ -48,7 +48,7 @@ class RepeatedValues:
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
self._unbatched_repr = None
|
self._unbatched_repr = None
|
||||||
|
|
||||||
def unbatch_all(self):
|
def unbatch_all(self) -> List[List[TensorType]]:
|
||||||
"""Unbatch both the repeat and batch dimensions into Python lists.
|
"""Unbatch both the repeat and batch dimensions into Python lists.
|
||||||
|
|
||||||
This is only supported in PyTorch / TF eager mode.
|
This is only supported in PyTorch / TF eager mode.
|
||||||
|
@ -64,7 +64,7 @@ class RepeatedValues:
|
||||||
>>> print(max(len(x) for x in items) <= N)
|
>>> print(max(len(x) for x in items) <= N)
|
||||||
True
|
True
|
||||||
>>> print(items)
|
>>> print(items)
|
||||||
... [<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
|
... [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
|
||||||
... ...
|
... ...
|
||||||
... [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
|
... [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
|
||||||
... ...
|
... ...
|
||||||
|
@ -96,7 +96,7 @@ class RepeatedValues:
|
||||||
|
|
||||||
return self._unbatched_repr
|
return self._unbatched_repr
|
||||||
|
|
||||||
def unbatch_repeat_dim(self):
|
def unbatch_repeat_dim(self) -> List[TensorType]:
|
||||||
"""Unbatches the repeat dimension (the one `max_len` in size).
|
"""Unbatches the repeat dimension (the one `max_len` in size).
|
||||||
|
|
||||||
This removes the repeat dimension. The result will be a Python list of
|
This removes the repeat dimension. The result will be a Python list of
|
||||||
|
@ -120,7 +120,7 @@ class RepeatedValues:
|
||||||
return repr(self)
|
return repr(self)
|
||||||
|
|
||||||
|
|
||||||
def _get_batch_dim_helper(v):
|
def _get_batch_dim_helper(v: TensorStructType) -> int:
|
||||||
"""Tries to find the batch dimension size of v, or None."""
|
"""Tries to find the batch dimension size of v, or None."""
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
for u in v.values():
|
for u in v.values():
|
||||||
|
@ -136,7 +136,7 @@ def _get_batch_dim_helper(v):
|
||||||
return B
|
return B
|
||||||
|
|
||||||
|
|
||||||
def _unbatch_helper(v, max_len):
|
def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
|
||||||
"""Recursively unpacks the repeat dimension (max_len)."""
|
"""Recursively unpacks the repeat dimension (max_len)."""
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
|
return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
|
||||||
|
@ -152,7 +152,8 @@ def _unbatch_helper(v, max_len):
|
||||||
return [v[:, i, ...] for i in range(max_len)]
|
return [v[:, i, ...] for i in range(max_len)]
|
||||||
|
|
||||||
|
|
||||||
def _batch_index_helper(v, i, j):
|
def _batch_index_helper(v: TensorStructType, i: int,
|
||||||
|
j: int) -> TensorStructType:
|
||||||
"""Selects the item at the ith batch index and jth repetition."""
|
"""Selects the item at the ith batch index and jth repetition."""
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}
|
return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}
|
||||||
|
|
|
@ -4,11 +4,13 @@ import functools
|
||||||
import tree
|
import tree
|
||||||
|
|
||||||
from ray.rllib.models.action_dist import ActionDistribution
|
from ray.rllib.models.action_dist import ActionDistribution
|
||||||
|
from ray.rllib.models.modelv2 import ModelV2
|
||||||
from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
|
from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
|
||||||
SMALL_NUMBER
|
SMALL_NUMBER
|
||||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||||
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
|
||||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
||||||
|
from ray.rllib.utils.types import TensorType, List
|
||||||
|
|
||||||
tf1, tf, tfv = try_import_tf()
|
tf1, tf, tfv = try_import_tf()
|
||||||
tfp = try_import_tfp()
|
tfp = try_import_tfp()
|
||||||
|
@ -18,14 +20,14 @@ tfp = try_import_tfp()
|
||||||
class TFActionDistribution(ActionDistribution):
|
class TFActionDistribution(ActionDistribution):
|
||||||
"""TF-specific extensions for building action distributions."""
|
"""TF-specific extensions for building action distributions."""
|
||||||
|
|
||||||
@DeveloperAPI
|
@override(ActionDistribution)
|
||||||
def __init__(self, inputs, model):
|
def __init__(self, inputs: List[TensorType], model: ModelV2):
|
||||||
super().__init__(inputs, model)
|
super().__init__(inputs, model)
|
||||||
self.sample_op = self._build_sample_op()
|
self.sample_op = self._build_sample_op()
|
||||||
self.sampled_action_logp_op = self.logp(self.sample_op)
|
self.sampled_action_logp_op = self.logp(self.sample_op)
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def _build_sample_op(self):
|
def _build_sample_op(self) -> TensorType:
|
||||||
"""Implement this instead of sample(), to enable op reuse.
|
"""Implement this instead of sample(), to enable op reuse.
|
||||||
|
|
||||||
This is needed since the sample op is non-deterministic and is shared
|
This is needed since the sample op is non-deterministic and is shared
|
||||||
|
@ -34,12 +36,12 @@ class TFActionDistribution(ActionDistribution):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
def sample(self):
|
def sample(self) -> TensorType:
|
||||||
"""Draw a sample from the action distribution."""
|
"""Draw a sample from the action distribution."""
|
||||||
return self.sample_op
|
return self.sample_op
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
def sampled_action_logp(self):
|
def sampled_action_logp(self) -> TensorType:
|
||||||
"""Returns the log probability of the sampled action."""
|
"""Returns the log probability of the sampled action."""
|
||||||
return self.sampled_action_logp_op
|
return self.sampled_action_logp_op
|
||||||
|
|
||||||
|
@ -242,9 +244,8 @@ class DiagGaussian(TFActionDistribution):
|
||||||
assert isinstance(other, DiagGaussian)
|
assert isinstance(other, DiagGaussian)
|
||||||
return tf.reduce_sum(
|
return tf.reduce_sum(
|
||||||
other.log_std - self.log_std +
|
other.log_std - self.log_std +
|
||||||
(tf.math.square(self.std) +
|
(tf.math.square(self.std) + tf.math.square(self.mean - other.mean))
|
||||||
tf.math.square(self.mean - other.mean)) /
|
/ (2.0 * tf.math.square(other.std)) - 0.5,
|
||||||
(2.0 * tf.math.square(other.std)) - 0.5,
|
|
||||||
axis=1)
|
axis=1)
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
|
import contextlib
|
||||||
|
import gym
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from ray.rllib.models.modelv2 import ModelV2
|
from ray.rllib.models.modelv2 import ModelV2
|
||||||
from ray.rllib.utils.annotations import override, PublicAPI
|
from ray.rllib.utils.annotations import override, PublicAPI
|
||||||
from ray.rllib.utils.framework import try_import_tf
|
from ray.rllib.utils.framework import try_import_tf
|
||||||
|
from ray.rllib.utils.types import ModelConfigDict, TensorType
|
||||||
|
|
||||||
tf1, tf, tfv = try_import_tf()
|
tf1, tf, tfv = try_import_tf()
|
||||||
|
|
||||||
|
@ -12,8 +17,9 @@ class TFModelV2(ModelV2):
|
||||||
Note that this class by itself is not a valid model unless you
|
Note that this class by itself is not a valid model unless you
|
||||||
implement forward() in a subclass."""
|
implement forward() in a subclass."""
|
||||||
|
|
||||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
def __init__(self, obs_space: gym.spaces.Space,
|
||||||
name):
|
action_space: gym.spaces.Space, num_outputs: int,
|
||||||
|
model_config: ModelConfigDict, name: str):
|
||||||
"""Initialize a TFModelV2.
|
"""Initialize a TFModelV2.
|
||||||
|
|
||||||
Here is an example implementation for a subclass
|
Here is an example implementation for a subclass
|
||||||
|
@ -44,31 +50,31 @@ class TFModelV2(ModelV2):
|
||||||
else:
|
else:
|
||||||
self.graph = tf1.get_default_graph()
|
self.graph = tf1.get_default_graph()
|
||||||
|
|
||||||
def context(self):
|
def context(self) -> contextlib.AbstractContextManager:
|
||||||
"""Returns a contextmanager for the current TF graph."""
|
"""Returns a contextmanager for the current TF graph."""
|
||||||
if self.graph:
|
if self.graph:
|
||||||
return self.graph.as_default()
|
return self.graph.as_default()
|
||||||
else:
|
else:
|
||||||
return ModelV2.context(self)
|
return ModelV2.context(self)
|
||||||
|
|
||||||
def update_ops(self):
|
def update_ops(self) -> List[TensorType]:
|
||||||
"""Return the list of update ops for this model.
|
"""Return the list of update ops for this model.
|
||||||
|
|
||||||
For example, this should include any BatchNorm update ops."""
|
For example, this should include any BatchNorm update ops."""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def register_variables(self, variables):
|
def register_variables(self, variables: List[TensorType]) -> None:
|
||||||
"""Register the given list of variables with this model."""
|
"""Register the given list of variables with this model."""
|
||||||
self.var_list.extend(variables)
|
self.var_list.extend(variables)
|
||||||
|
|
||||||
@override(ModelV2)
|
@override(ModelV2)
|
||||||
def variables(self, as_dict=False):
|
def variables(self, as_dict: bool = False) -> List[TensorType]:
|
||||||
if as_dict:
|
if as_dict:
|
||||||
return {v.name: v for v in self.var_list}
|
return {v.name: v for v in self.var_list}
|
||||||
return list(self.var_list)
|
return list(self.var_list)
|
||||||
|
|
||||||
@override(ModelV2)
|
@override(ModelV2)
|
||||||
def trainable_variables(self, as_dict=False):
|
def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
|
||||||
if as_dict:
|
if as_dict:
|
||||||
return {
|
return {
|
||||||
k: v
|
k: v
|
||||||
|
|
|
@ -4,12 +4,14 @@ import numpy as np
|
||||||
import tree
|
import tree
|
||||||
|
|
||||||
from ray.rllib.models.action_dist import ActionDistribution
|
from ray.rllib.models.action_dist import ActionDistribution
|
||||||
|
from ray.rllib.models.modelv2 import ModelV2
|
||||||
from ray.rllib.utils.annotations import override
|
from ray.rllib.utils.annotations import override
|
||||||
from ray.rllib.utils.framework import try_import_torch
|
from ray.rllib.utils.framework import try_import_torch
|
||||||
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
|
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
|
||||||
MAX_LOG_NN_OUTPUT
|
MAX_LOG_NN_OUTPUT
|
||||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
||||||
from ray.rllib.utils.torch_ops import atanh
|
from ray.rllib.utils.torch_ops import atanh
|
||||||
|
from ray.rllib.utils.types import TensorType, List
|
||||||
|
|
||||||
torch, nn = try_import_torch()
|
torch, nn = try_import_torch()
|
||||||
|
|
||||||
|
@ -18,7 +20,7 @@ class TorchDistributionWrapper(ActionDistribution):
|
||||||
"""Wrapper class for torch.distributions."""
|
"""Wrapper class for torch.distributions."""
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
def __init__(self, inputs, model):
|
def __init__(self, inputs: List[TensorType], model: ModelV2):
|
||||||
if not isinstance(inputs, torch.Tensor):
|
if not isinstance(inputs, torch.Tensor):
|
||||||
inputs = torch.Tensor(inputs)
|
inputs = torch.Tensor(inputs)
|
||||||
super().__init__(inputs, model)
|
super().__init__(inputs, model)
|
||||||
|
@ -26,24 +28,24 @@ class TorchDistributionWrapper(ActionDistribution):
|
||||||
self.last_sample = None
|
self.last_sample = None
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
def logp(self, actions):
|
def logp(self, actions: TensorType) -> TensorType:
|
||||||
return self.dist.log_prob(actions)
|
return self.dist.log_prob(actions)
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
def entropy(self):
|
def entropy(self) -> TensorType:
|
||||||
return self.dist.entropy()
|
return self.dist.entropy()
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
def kl(self, other):
|
def kl(self, other: ActionDistribution) -> TensorType:
|
||||||
return torch.distributions.kl.kl_divergence(self.dist, other.dist)
|
return torch.distributions.kl.kl_divergence(self.dist, other.dist)
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
def sample(self):
|
def sample(self) -> TensorType:
|
||||||
self.last_sample = self.dist.sample()
|
self.last_sample = self.dist.sample()
|
||||||
return self.last_sample
|
return self.last_sample
|
||||||
|
|
||||||
@override(ActionDistribution)
|
@override(ActionDistribution)
|
||||||
def sampled_action_logp(self):
|
def sampled_action_logp(self) -> TensorType:
|
||||||
assert self.last_sample is not None
|
assert self.last_sample is not None
|
||||||
return self.logp(self.last_sample)
|
return self.logp(self.last_sample)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
|
import gym
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from ray.rllib.models.modelv2 import ModelV2
|
from ray.rllib.models.modelv2 import ModelV2
|
||||||
from ray.rllib.utils.annotations import override, PublicAPI
|
from ray.rllib.utils.annotations import override, PublicAPI
|
||||||
from ray.rllib.utils.framework import try_import_torch
|
from ray.rllib.utils.framework import try_import_torch
|
||||||
|
from ray.rllib.utils.types import ModelConfigDict, TensorType
|
||||||
|
|
||||||
_, nn = try_import_torch()
|
_, nn = try_import_torch()
|
||||||
|
|
||||||
|
@ -12,8 +16,9 @@ class TorchModelV2(ModelV2):
|
||||||
Note that this class by itself is not a valid model unless you
|
Note that this class by itself is not a valid model unless you
|
||||||
inherit from nn.Module and implement forward() in a subclass."""
|
inherit from nn.Module and implement forward() in a subclass."""
|
||||||
|
|
||||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
def __init__(self, obs_space: gym.spaces.Space,
|
||||||
name):
|
action_space: gym.spaces.Space, num_outputs: int,
|
||||||
|
model_config: ModelConfigDict, name: str):
|
||||||
"""Initialize a TorchModelV2.
|
"""Initialize a TorchModelV2.
|
||||||
|
|
||||||
Here is an example implementation for a subclass
|
Here is an example implementation for a subclass
|
||||||
|
@ -42,13 +47,13 @@ class TorchModelV2(ModelV2):
|
||||||
framework="torch")
|
framework="torch")
|
||||||
|
|
||||||
@override(ModelV2)
|
@override(ModelV2)
|
||||||
def variables(self, as_dict=False):
|
def variables(self, as_dict: bool = False) -> List[TensorType]:
|
||||||
if as_dict:
|
if as_dict:
|
||||||
return self.state_dict()
|
return self.state_dict()
|
||||||
return list(self.parameters())
|
return list(self.parameters())
|
||||||
|
|
||||||
@override(ModelV2)
|
@override(ModelV2)
|
||||||
def trainable_variables(self, as_dict=False):
|
def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
|
||||||
if as_dict:
|
if as_dict:
|
||||||
return {
|
return {
|
||||||
k: v
|
k: v
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, List, TYPE_CHECKING
|
||||||
|
|
||||||
from ray.rllib.utils.types import TensorType
|
from ray.rllib.utils.types import TensorType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ray.rllib.models import ModelV2
|
||||||
|
|
||||||
|
|
||||||
class ViewRequirement:
|
class ViewRequirement:
|
||||||
"""Single view requirement (for one column in a ModelV2 input_dict).
|
"""Single view requirement (for one column in a ModelV2 input_dict).
|
||||||
|
@ -57,10 +60,9 @@ class ViewRequirement:
|
||||||
# "time_major",
|
# "time_major",
|
||||||
|
|
||||||
|
|
||||||
def get_trajectory_view(
|
def get_trajectory_view(model: "ModelV2",
|
||||||
model,
|
trajectories: List["Trajectory"],
|
||||||
trajectories,
|
is_training: bool = False) -> Dict[str, TensorType]:
|
||||||
is_training: bool = False) -> Dict[str, TensorType]:
|
|
||||||
"""Returns an input_dict for a Model's forward pass given some data.
|
"""Returns an input_dict for a Model's forward pass given some data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
Loading…
Add table
Reference in a new issue