2020-07-24 12:01:46 -07:00
|
|
|
import gym
|
2021-01-14 14:44:33 +01:00
|
|
|
from typing import Dict, List, Union
|
2020-07-24 12:01:46 -07:00
|
|
|
|
2019-07-03 15:59:47 -07:00
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
2020-04-06 20:56:16 +02:00
|
|
|
from ray.rllib.utils.annotations import override, PublicAPI
|
2020-06-16 08:52:20 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_torch
|
2020-08-15 13:24:22 +02:00
|
|
|
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
2019-12-30 15:27:32 -05:00
|
|
|
|
|
|
|
_, nn = try_import_torch()
|
2019-07-25 11:02:53 -07:00
|
|
|
|
2019-07-03 15:59:47 -07:00
|
|
|
|
2019-07-25 11:02:53 -07:00
|
|
|
@PublicAPI
class TorchModelV2(ModelV2):
    """Torch version of ModelV2.

    Note that this class by itself is not a valid model unless you
    inherit from nn.Module and implement forward() in a subclass."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Initialize a TorchModelV2.

        Here is an example implementation for a subclass
        ``MyModelClass(TorchModelV2, nn.Module)``::

            def __init__(self, *args, **kwargs):
                TorchModelV2.__init__(self, *args, **kwargs)
                nn.Module.__init__(self)
                self._hidden_layers = nn.Sequential(...)
                self._logits = ...
                self._value_branch = ...

        Args:
            obs_space: Observation space, forwarded to ``ModelV2.__init__``.
            action_space: Action space, forwarded to ``ModelV2.__init__``.
            num_outputs: Number of outputs, forwarded to
                ``ModelV2.__init__``.
            model_config: Model config dict, forwarded to
                ``ModelV2.__init__``.
            name: Name of the model, forwarded to ``ModelV2.__init__``.

        Raises:
            ValueError: If the instance is not also an ``nn.Module``.
        """
        # The methods below rely on the nn.Module API (`parameters()`,
        # `named_parameters()`), so fail early on a wrong class hierarchy.
        if not isinstance(self, nn.Module):
            raise ValueError(
                "Subclasses of TorchModelV2 must also inherit from "
                "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)"
            )

        ModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            framework="torch",
        )

        # Dict to store per multi-gpu tower stats into.
        # In PyTorch multi-GPU, we use a single TorchPolicy and copy
        # its Model(s) n times (1 copy for each GPU). When computing the loss
        # on each tower, we cannot store the stats (e.g. `entropy`) inside the
        # policy object as this would lead to race conditions between the
        # different towers all accessing the same property at the same time.
        self.tower_stats = {}

    @override(ModelV2)
    def variables(
        self, as_dict: bool = False
    ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Return the parameters of this model.

        Args:
            as_dict: If True, return a dict mapping each parameter's name to
                the parameter. Otherwise return a plain list of parameters.

        Returns:
            The model's parameters, as a list or as a name-keyed dict.
        """
        if as_dict:
            # Take names and tensors from the same iterator.
            # `state_dict()` also lists buffers (e.g. BatchNorm running
            # stats) and repeats tied parameters, so pairing its keys
            # positionally with `parameters()` mislabels tensors or raises
            # an IndexError for such models.
            return dict(self.named_parameters())
        return list(self.parameters())

    @override(ModelV2)
    def trainable_variables(
        self, as_dict: bool = False
    ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Return only the parameters that have ``requires_grad`` set.

        Args:
            as_dict: If True, return a dict mapping each parameter's name to
                the parameter. Otherwise return a plain list of parameters.

        Returns:
            The subset of ``self.variables(...)`` whose ``requires_grad``
            is truthy, in the same container type.
        """
        if as_dict:
            return {
                k: v for k, v in self.variables(as_dict=True).items() if v.requires_grad
            }
        return [v for v in self.variables() if v.requires_grad]
|