from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils import try_import_torch

_, nn = try_import_torch()

@PublicAPI
class TorchModelV2(ModelV2):
    """Torch version of ModelV2.

    Note that this class by itself is not a valid model unless you
    inherit from nn.Module and implement forward() in a subclass."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initialize a TorchModelV2.

        Here is an example implementation for a subclass
        ``MyModelClass(TorchModelV2, nn.Module)``::

            def __init__(self, *args, **kwargs):
                TorchModelV2.__init__(self, *args, **kwargs)
                nn.Module.__init__(self)
                self._hidden_layers = nn.Sequential(...)
                self._logits = ...
                self._value_branch = ...
        """

        if not isinstance(self, nn.Module):
            raise ValueError(
                "Subclasses of TorchModelV2 must also inherit from "
                "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")

        ModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            framework="torch")
    def forward(self, input_dict, state, seq_lens):
        """Call the model with the given input tensors and state.

        Any complex observations (dicts, tuples, etc.) will be unpacked by
        __call__ before being passed to forward(). To access the flattened
        observation tensor, refer to input_dict["obs_flat"].

        This method can be called any number of times. In eager execution,
        each call to forward() will eagerly evaluate the model. In symbolic
        execution, each call to forward creates a computation graph that
        operates over the variables of this model (i.e., shares weights).

        Custom models should override this instead of __call__.

        Args:
            input_dict (dict): dictionary of input tensors, including "obs",
                "obs_flat", "prev_action", "prev_reward", "is_training"
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension
            seq_lens (Tensor): 1d tensor holding input sequence lengths

        Returns:
            (outputs, state): The model output tensor of size
                [BATCH, num_outputs], and the list of new state tensors.

        Examples:
            >>> def forward(self, input_dict, state, seq_lens):
            >>>     features = self._hidden_layers(input_dict["obs"])
            >>>     self._value_out = self._value_branch(features)
            >>>     return self._logits(features), state
        """
        raise NotImplementedError
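
    # The docstring example above caches the value-branch output in
    # self._value_out so that a matching value_function() override can
    # return it without recomputing the hidden layers. This is the usual
    # pattern for actor-critic models built on ModelV2.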

    @override(ModelV2)
    def variables(self, as_dict=False):
        # nn.Module already tracks all parameters; return them either as a
        # name -> tensor dict (state_dict) or as a flat list.
        if as_dict:
            return self.state_dict()
        return list(self.parameters())

    @override(ModelV2)
    def trainable_variables(self, as_dict=False):
        # Filter the full variable set down to tensors with requires_grad.
        if as_dict:
            return {
                k: v
                for k, v in self.variables(as_dict=True).items()
                if v.requires_grad
            }
        return [v for v in self.variables() if v.requires_grad]
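

# A minimal, self-contained sketch of a concrete subclass, following the
# pattern described in the docstrings above. Everything below (the class
# name, layer sizes, and the flat Box observation assumption) is
# illustrative, not part of the RLlib API.
class MyFCModel(TorchModelV2, nn.Module):
    """Example: one hidden layer with separate logits and value heads."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # Assumes a flat (1D) Box observation space.
        in_size = obs_space.shape[0]
        self._hidden_layers = nn.Sequential(
            nn.Linear(in_size, 256), nn.ReLU())
        self._logits = nn.Linear(256, num_outputs)
        self._value_branch = nn.Linear(256, 1)
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        features = self._hidden_layers(input_dict["obs_flat"])
        # Cache the value output for value_function() (see note above).
        self._value_out = self._value_branch(features)
        return self._logits(features), state

    @override(ModelV2)
    def value_function(self):
        # Reshape [BATCH, 1] -> [BATCH], as RLlib expects.
        return self._value_out.squeeze(1)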