[Testing] Fix LINT/sphinx errors. (#8874)

This commit is contained in:
Sven Mika 2020-06-10 15:41:59 +02:00 committed by GitHub
parent ec5ecb661f
commit 0ba7472da9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 29 additions and 21 deletions

View file

@ -14,7 +14,7 @@ pygments
pyyaml
recommonmark
redis
sphinx>2
sphinx==3.0.4
sphinx-click
sphinx-copybutton
sphinx-gallery

View file

@ -6,7 +6,7 @@ from ray.rllib.utils import try_import_torch
torch, nn = try_import_torch()
class DQNTorchModel(TorchModelV2):
class DQNTorchModel(TorchModelV2, nn.Module):
"""Extension of standard TorchModelV2 to provide dueling-Q functionality.
"""
@ -46,7 +46,7 @@ class DQNTorchModel(TorchModelV2):
sigma0 (float): initial value of noisy nets
add_layer_norm (bool): Enable layer norm (for param noise).
"""
nn.Module.__init__(self)
super(DQNTorchModel, self).__init__(obs_space, action_space,
num_outputs, model_config, name)

View file

@ -8,7 +8,7 @@ from ray.rllib.utils.framework import get_activation_fn, try_import_torch
torch, nn = try_import_torch()
class SACTorchModel(TorchModelV2):
class SACTorchModel(TorchModelV2, nn.Module):
"""Extension of standard TorchModelV2 for SAC.
Data flow:
@ -52,6 +52,7 @@ class SACTorchModel(TorchModelV2):
only defines the layers for the output heads. Those layers for
forward() should be defined in subclasses of SACModel.
"""
nn.Module.__init__(self)
super(SACTorchModel, self).__init__(obs_space, action_space,
num_outputs, model_config, name)

View file

@ -63,7 +63,7 @@ class ParametricActionsModel(DistributionalQTFModel):
return self.action_embed_model.value_function()
class TorchParametricActionsModel(DQNTorchModel, nn.Module):
class TorchParametricActionsModel(DQNTorchModel):
"""PyTorch version of above ParametricActionsModel."""
def __init__(self,
@ -75,7 +75,6 @@ class TorchParametricActionsModel(DQNTorchModel, nn.Module):
true_obs_shape=(4, ),
action_embed_size=2,
**kw):
nn.Module.__init__(self)
DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
model_config, name, **kw)

View file

@ -75,7 +75,7 @@ class RNNModel(RecurrentNetwork):
return tf.reshape(self._value_out, [-1])
class TorchRNNModel(TorchRNN):
class TorchRNNModel(TorchRNN, nn.Module):
def __init__(self,
obs_space,
action_space,
@ -84,6 +84,7 @@ class TorchRNNModel(TorchRNN):
name,
fc_size=64,
lstm_state_size=256):
nn.Module.__init__(self)
super().__init__(obs_space, action_space, num_outputs, model_config,
name)

View file

@ -15,14 +15,10 @@ from ray.rllib.models.tf.recurrent_net import LSTMWrapper
from ray.rllib.models.tf.tf_action_dist import Categorical, \
Deterministic, DiagGaussian, Dirichlet, \
MultiActionDistribution, MultiCategorical
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
from ray.rllib.models.torch.recurrent_net import LSTMWrapper as \
TorchLSTMWrapper
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDeterministic, TorchDiagGaussian, \
TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
@ -403,6 +399,8 @@ class ModelCatalog:
default_model or ModelCatalog._get_v2_model_class(
obs_space, model_config, framework=framework)
if model_config.get("use_lstm"):
from ray.rllib.models.torch.recurrent_net import LSTMWrapper \
as TorchLSTMWrapper
wrapped_cls = v2_class
forward = wrapped_cls.forward
v2_class = ModelCatalog._wrap_if_needed(
@ -511,7 +509,7 @@ class ModelCatalog:
@staticmethod
def _wrap_if_needed(model_cls, model_interface):
assert issubclass(model_cls, (TFModelV2, TorchModelV2)), model_cls
assert issubclass(model_cls, ModelV2), model_cls
if not model_interface or issubclass(model_cls, model_interface):
return model_cls

View file

@ -13,13 +13,14 @@ torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
class FullyConnectedNetwork(TorchModelV2):
class FullyConnectedNetwork(TorchModelV2, nn.Module):
"""Generic fully connected network."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
activation = get_activation_fn(
model_config.get("fcnet_activation"), framework="torch")

View file

@ -18,9 +18,12 @@ class RecurrentNetwork(TorchModelV2):
takes batches with the time dimension added already.
Here is an example implementation for a subclass
``MyRNNClass(nn.Module, RecurrentNetwork)``::
``MyRNNClass(RecurrentNetwork, nn.Module)``::
def __init__(self, obs_space, num_outputs):
nn.Module.__init__(self)
super().__init__(obs_space, action_space, num_outputs,
model_config, name)
self.obs_size = _get_size(obs_space)
self.rnn_hidden_dim = model_config["lstm_cell_size"]
self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
@ -87,15 +90,15 @@ class RecurrentNetwork(TorchModelV2):
raise NotImplementedError("You must implement this for an RNN model")
class LSTMWrapper(RecurrentNetwork):
class LSTMWrapper(RecurrentNetwork, nn.Module):
"""An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(LSTMWrapper, self).__init__(obs_space, action_space, None,
model_config, name)
nn.Module.__init__(self)
super().__init__(obs_space, action_space, None, model_config, name)
self.cell_size = model_config["lstm_cell_size"]
self.lstm = nn.LSTM(self.num_outputs, self.cell_size, batch_first=True)

View file

@ -6,7 +6,7 @@ _, nn = try_import_torch()
@PublicAPI
class TorchModelV2(ModelV2, nn.Module):
class TorchModelV2(ModelV2):
"""Torch version of ModelV2.
Note that this class by itself is not a valid model unless you
@ -27,6 +27,11 @@ class TorchModelV2(ModelV2, nn.Module):
self._value_branch = ...
"""
if not isinstance(self, nn.Module):
raise ValueError(
"Subclasses of TorchModelV2 must also inherit from "
"nn.Module, e.g., MyModel(TorchModelV2, nn.Module)")
ModelV2.__init__(
self,
obs_space,
@ -35,7 +40,6 @@ class TorchModelV2(ModelV2, nn.Module):
model_config,
name,
framework="torch")
nn.Module.__init__(self)
@override(ModelV2)
def variables(self, as_dict=False):

View file

@ -9,13 +9,14 @@ from ray.rllib.utils import try_import_torch
_, nn = try_import_torch()
class VisionNetwork(TorchModelV2):
class VisionNetwork(TorchModelV2, nn.Module):
"""Generic vision network."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
activation = get_activation_fn(
model_config.get("conv_activation"), framework="torch")

View file

@ -299,7 +299,7 @@ class NestedSpacesTest(unittest.TestCase):
ModelCatalog.register_custom_model("invalid", InvalidModel)
self.assertRaisesRegexp(
ValueError,
"optimizer got an empty parameter list",
"Subclasses of TorchModelV2 must also inherit from nn.Module",
lambda: PGTrainer(
env="CartPole-v0",
config={