ray/rllib/models/torch/torch_modelv2.py
Sven Mika 22ccc43670
[RLlib] DQN torch version. (#7597)

Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-04-06 11:56:16 -07:00


from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils import try_import_torch

_, nn = try_import_torch()


@PublicAPI
class TorchModelV2(ModelV2):
    """Torch version of ModelV2.

    Note that this class by itself is not a valid model unless you
    inherit from nn.Module and implement forward() in a subclass."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Initialize a TorchModelV2.

        Here is an example implementation for a subclass
        ``MyModelClass(TorchModelV2, nn.Module)``::

            def __init__(self, *args, **kwargs):
                TorchModelV2.__init__(self, *args, **kwargs)
                nn.Module.__init__(self)
                self._hidden_layers = nn.Sequential(...)
                self._logits = ...
                self._value_branch = ...
        """

        if not isinstance(self, nn.Module):
            raise ValueError(
                "Subclasses of TorchModelV2 must also inherit from "
                "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")

        ModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            framework="torch")

    def forward(self, input_dict, state, seq_lens):
        """Call the model with the given input tensors and state.

        Any complex observations (dicts, tuples, etc.) will be unpacked by
        __call__ before being passed to forward(). To access the flattened
        observation tensor, refer to input_dict["obs_flat"].

        This method can be called any number of times. In eager execution,
        each call to forward() will eagerly evaluate the model. In symbolic
        execution, each call to forward creates a computation graph that
        operates over the variables of this model (i.e., shares weights).

        Custom models should override this instead of __call__.

        Args:
            input_dict (dict): dictionary of input tensors, including "obs",
                "obs_flat", "prev_action", "prev_reward", "is_training"
            state (list): list of state tensors with sizes matching those
                returned by get_initial_state + the batch dimension
            seq_lens (Tensor): 1d tensor holding input sequence lengths

        Returns:
            (outputs, state): The model output tensor of size
                [BATCH, num_outputs]

        Examples:
            >>> def forward(self, input_dict, state, seq_lens):
            >>>     features = self._hidden_layers(input_dict["obs"])
            >>>     self._value_out = self._value_branch(features)
            >>>     return self._logits(features), state
        """
        raise NotImplementedError

    @override(ModelV2)
    def variables(self, as_dict=False):
        if as_dict:
            return self.state_dict()
        return list(self.parameters())
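
    @override(ModelV2)
    def trainable_variables(self, as_dict=False):
        if as_dict:
            # state_dict() returns detached tensors by default, whose
            # requires_grad is always False, so filtering its values would
            # yield an empty dict. Filter the live parameters instead.
            return {
                k: v
                for k, v in self.named_parameters()
                if v.requires_grad
            }
        return [v for v in self.variables() if v.requires_grad]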
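
The two ideas in this file, the guard in `__init__` that insists on a second base class and the `requires_grad` filter in `trainable_variables()`, can be tried out without installing Ray or PyTorch. The following is only a rough sketch under stated assumptions: `FakeParam`, `FakeModule`, `BaseModel`, `GoodModel`, and `BadModel` are hypothetical stand-ins invented for this illustration and are not part of RLlib or torch. `FakeModule` imitates only the small slice of `nn.Module` used here (`parameters()` and `named_parameters()`).

```python
class FakeParam:
    """Hypothetical stand-in for a torch parameter; only tracks requires_grad."""

    def __init__(self, requires_grad=True):
        self.requires_grad = requires_grad


class FakeModule:
    """Hypothetical stand-in for nn.Module; stores named parameters."""

    def __init__(self):
        self._params = {}

    def named_parameters(self):
        return iter(self._params.items())

    def parameters(self):
        return iter(self._params.values())


class BaseModel:
    """Mirrors the TorchModelV2 pattern: it refuses to initialize unless
    the concrete class also inherits from the module base class."""

    def __init__(self, name):
        if not isinstance(self, FakeModule):
            raise ValueError(
                "Subclasses of BaseModel must also inherit from FakeModule")
        self.name = name

    def variables(self):
        return list(self.parameters())

    def trainable_variables(self, as_dict=False):
        # Keep only parameters that would receive gradients.
        if as_dict:
            return {
                k: v
                for k, v in self.named_parameters() if v.requires_grad
            }
        return [v for v in self.variables() if v.requires_grad]


class GoodModel(BaseModel, FakeModule):
    """Inherits from both bases, in the same order as the docstring example."""

    def __init__(self, name):
        BaseModel.__init__(self, name)
        FakeModule.__init__(self)
        self._params["weight"] = FakeParam(requires_grad=True)
        self._params["frozen_bias"] = FakeParam(requires_grad=False)


class BadModel(BaseModel):
    """Forgets the second base class, so construction raises ValueError."""


model = GoodModel("demo")
print(sorted(model.trainable_variables(as_dict=True)))

try:
    BadModel("oops")
except ValueError as err:
    print("rejected:", err)
```

The guard works because `isinstance(self, ...)` inspects the class hierarchy of the concrete subclass, so it can run before the second base class has been initialized. A real model would subclass `TorchModelV2` and `nn.Module` and implement `forward()` as described in the docstrings above.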