Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[Testing] Fix LINT/sphinx errors. (#8874)
This commit is contained in:
parent ec5ecb661f
commit 0ba7472da9
11 changed files with 29 additions and 21 deletions
@@ -14,7 +14,7 @@ pygments
 pyyaml
 recommonmark
 redis
-sphinx>2
+sphinx==3.0.4
 sphinx-click
 sphinx-copybutton
 sphinx-gallery
@@ -6,7 +6,7 @@ from ray.rllib.utils import try_import_torch
 torch, nn = try_import_torch()


-class DQNTorchModel(TorchModelV2):
+class DQNTorchModel(TorchModelV2, nn.Module):
     """Extension of standard TorchModelV2 to provide dueling-Q functionality.
     """
@@ -46,7 +46,7 @@ class DQNTorchModel(TorchModelV2):
             sigma0 (float): initial value of noisy nets
             add_layer_norm (bool): Enable layer norm (for param noise).
         """

+        nn.Module.__init__(self)
         super(DQNTorchModel, self).__init__(obs_space, action_space,
                                             num_outputs, model_config, name)
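
The pattern above repeats across this commit: concrete Torch models list nn.Module as an explicit base and call nn.Module.__init__ themselves, because TorchModelV2 no longer inherits from nn.Module (see the TorchModelV2 hunks further down). A minimal sketch of a custom model written against the new convention; the class name, layer, and value head below are illustrative only, not part of the commit:

import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class MyTorchModel(TorchModelV2, nn.Module):  # illustrative custom model
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        # Required now that TorchModelV2 is no longer an nn.Module itself.
        nn.Module.__init__(self)
        self._fc = nn.Linear(obs_space.shape[0], num_outputs)
        self._last_out = None

    def forward(self, input_dict, state, seq_lens):
        self._last_out = self._fc(input_dict["obs"].float())
        return self._last_out, state

    def value_function(self):
        return self._last_out.mean(1)

Either initialization order appears in the commit itself: the DQN and SAC models call nn.Module.__init__ before the super() call, while the fully connected and vision networks call it right after TorchModelV2.__init__.
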
@@ -8,7 +8,7 @@ from ray.rllib.utils.framework import get_activation_fn, try_import_torch
 torch, nn = try_import_torch()


-class SACTorchModel(TorchModelV2):
+class SACTorchModel(TorchModelV2, nn.Module):
     """Extension of standard TorchModelV2 for SAC.

     Data flow:
@@ -52,6 +52,7 @@ class SACTorchModel(TorchModelV2):
             only defines the layers for the output heads. Those layers for
             forward() should be defined in subclasses of SACModel.
         """
+        nn.Module.__init__(self)
         super(SACTorchModel, self).__init__(obs_space, action_space,
                                             num_outputs, model_config, name)
@@ -63,7 +63,7 @@ class ParametricActionsModel(DistributionalQTFModel):
         return self.action_embed_model.value_function()


-class TorchParametricActionsModel(DQNTorchModel, nn.Module):
+class TorchParametricActionsModel(DQNTorchModel):
     """PyTorch version of above ParametricActionsModel."""

     def __init__(self,
@@ -75,7 +75,6 @@ class TorchParametricActionsModel(DQNTorchModel, nn.Module):
                  true_obs_shape=(4, ),
                  action_embed_size=2,
                  **kw):
-        nn.Module.__init__(self)
         DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                                model_config, name, **kw)
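
Going the other way, since DQNTorchModel now runs nn.Module.__init__ in its own constructor, classes deriving from it can drop both the extra nn.Module base and the explicit call, which is exactly what the two hunks above do. A short sketch of such a subclass after the change (the class name is illustrative):

class MyParametricModel(DQNTorchModel):  # illustrative subclass, not from the commit
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        # No nn.Module.__init__ here; DQNTorchModel's constructor handles it.
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)
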
@@ -75,7 +75,7 @@ class RNNModel(RecurrentNetwork):
         return tf.reshape(self._value_out, [-1])


-class TorchRNNModel(TorchRNN):
+class TorchRNNModel(TorchRNN, nn.Module):
     def __init__(self,
                  obs_space,
                  action_space,
@@ -84,6 +84,7 @@ class TorchRNNModel(TorchRNN):
                  name,
                  fc_size=64,
                  lstm_state_size=256):
+        nn.Module.__init__(self)
         super().__init__(obs_space, action_space, num_outputs, model_config,
                          name)
@@ -15,14 +15,10 @@ from ray.rllib.models.tf.recurrent_net import LSTMWrapper
 from ray.rllib.models.tf.tf_action_dist import Categorical, \
     Deterministic, DiagGaussian, Dirichlet, \
     MultiActionDistribution, MultiCategorical
-from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
-from ray.rllib.models.torch.recurrent_net import LSTMWrapper as \
-    TorchLSTMWrapper
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
     TorchDeterministic, TorchDiagGaussian, \
     TorchMultiActionDistribution, TorchMultiCategorical
-from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.utils import try_import_tree
 from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
 from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
@@ -403,6 +399,8 @@ class ModelCatalog:
                 default_model or ModelCatalog._get_v2_model_class(
                     obs_space, model_config, framework=framework)
             if model_config.get("use_lstm"):
+                from ray.rllib.models.torch.recurrent_net import LSTMWrapper \
+                    as TorchLSTMWrapper
                 wrapped_cls = v2_class
                 forward = wrapped_cls.forward
                 v2_class = ModelCatalog._wrap_if_needed(
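
The module-level import of the torch LSTMWrapper removed in the import hunk above reappears here as a function-local import, so it is only evaluated when use_lstm is actually set. A small sketch of the same deferred-import idea, using an illustrative helper name rather than the catalog's real code path:

def _torch_lstm_wrapper_if_needed(model_config):  # illustrative helper
    """Return the torch LSTMWrapper class only when the config asks for it."""
    if model_config.get("use_lstm"):
        # Imported lazily so merely importing the catalog does not pull in
        # the torch recurrent model.
        from ray.rllib.models.torch.recurrent_net import LSTMWrapper as \
            TorchLSTMWrapper
        return TorchLSTMWrapper
    return None
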
@@ -511,7 +509,7 @@ class ModelCatalog:

     @staticmethod
     def _wrap_if_needed(model_cls, model_interface):
-        assert issubclass(model_cls, (TFModelV2, TorchModelV2)), model_cls
+        assert issubclass(model_cls, ModelV2), model_cls

         if not model_interface or issubclass(model_cls, model_interface):
             return model_cls
@@ -13,13 +13,14 @@ torch, nn = try_import_torch()
 logger = logging.getLogger(__name__)


-class FullyConnectedNetwork(TorchModelV2):
+class FullyConnectedNetwork(TorchModelV2, nn.Module):
     """Generic fully connected network."""

     def __init__(self, obs_space, action_space, num_outputs, model_config,
                  name):
         TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name)
+        nn.Module.__init__(self)

         activation = get_activation_fn(
             model_config.get("fcnet_activation"), framework="torch")
@@ -18,9 +18,12 @@ class RecurrentNetwork(TorchModelV2):
     takes batches with the time dimension added already.

     Here is an example implementation for a subclass
-    ``MyRNNClass(nn.Module, RecurrentNetwork)``::
+    ``MyRNNClass(RecurrentNetwork, nn.Module)``::

         def __init__(self, obs_space, num_outputs):
+            nn.Module.__init__(self)
+            super().__init__(obs_space, action_space, num_outputs,
+                             model_config, name)
             self.obs_size = _get_size(obs_space)
             self.rnn_hidden_dim = model_config["lstm_cell_size"]
             self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
@@ -87,15 +90,15 @@ class RecurrentNetwork(TorchModelV2):
         raise NotImplementedError("You must implement this for an RNN model")


-class LSTMWrapper(RecurrentNetwork):
+class LSTMWrapper(RecurrentNetwork, nn.Module):
     """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.
     """

     def __init__(self, obs_space, action_space, num_outputs, model_config,
                  name):
-
-        super(LSTMWrapper, self).__init__(obs_space, action_space, None,
-                                          model_config, name)
+        nn.Module.__init__(self)
+        super().__init__(obs_space, action_space, None, model_config, name)

         self.cell_size = model_config["lstm_cell_size"]
         self.lstm = nn.LSTM(self.num_outputs, self.cell_size, batch_first=True)
@@ -6,7 +6,7 @@ _, nn = try_import_torch()


 @PublicAPI
-class TorchModelV2(ModelV2, nn.Module):
+class TorchModelV2(ModelV2):
     """Torch version of ModelV2.

     Note that this class by itself is not a valid model unless you
@@ -27,6 +27,11 @@ class TorchModelV2(ModelV2, nn.Module):
             self._value_branch = ...
         """

+        if not isinstance(self, nn.Module):
+            raise ValueError(
+                "Subclasses of TorchModelV2 must also inherit from "
+                "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")
+
         ModelV2.__init__(
             self,
             obs_space,
@@ -35,7 +40,6 @@ class TorchModelV2(ModelV2, nn.Module):
             model_config,
             name,
             framework="torch")
-        nn.Module.__init__(self)

     @override(ModelV2)
     def variables(self, as_dict=False):
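
Taken together, these TorchModelV2 hunks move the nn.Module base out of the framework class and into user code, failing fast when a subclass forgets it. A rough sketch of the behavior under the new guard; the model classes and the dummy constructor arguments are illustrative, and it assumes torch is importable:

import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class BadModel(TorchModelV2):  # illustrative: the nn.Module base is missing
    pass


class GoodModel(TorchModelV2, nn.Module):  # illustrative: the corrected form
    def __init__(self, *args):
        TorchModelV2.__init__(self, *args)
        nn.Module.__init__(self)


try:
    BadModel(None, None, 2, {}, "bad")
except ValueError as e:
    # Raises "Subclasses of TorchModelV2 must also inherit from nn.Module, ..."
    # instead of the old, confusing "optimizer got an empty parameter list".
    print(e)

GoodModel(None, None, 2, {}, "good")  # constructs fine

The test hunk at the end of this commit updates NestedSpacesTest to expect this new error message rather than the old torch optimizer error.
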
@@ -9,13 +9,14 @@ from ray.rllib.utils import try_import_torch
 _, nn = try_import_torch()


-class VisionNetwork(TorchModelV2):
+class VisionNetwork(TorchModelV2, nn.Module):
     """Generic vision network."""

     def __init__(self, obs_space, action_space, num_outputs, model_config,
                  name):
         TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name)
+        nn.Module.__init__(self)

         activation = get_activation_fn(
             model_config.get("conv_activation"), framework="torch")
@@ -299,7 +299,7 @@ class NestedSpacesTest(unittest.TestCase):
         ModelCatalog.register_custom_model("invalid", InvalidModel)
         self.assertRaisesRegexp(
             ValueError,
-            "optimizer got an empty parameter list",
+            "Subclasses of TorchModelV2 must also inherit from nn.Module",
             lambda: PGTrainer(
                 env="CartPole-v0",
                 config={