# ray/rllib/models/torch/visionnet.py
import numpy as np
from typing import Dict, List
import gym
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import (
normc_initializer,
same_padding,
SlimConv2d,
SlimFC,
)
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
torch, nn = try_import_torch()
class VisionNetwork(TorchModelV2, nn.Module):
    """Generic vision network: a Conv2D stack with optional post-FC layers.

    Builds a policy ("logits") head and a value head from `model_config`:

    - `conv_filters`: list of `(out_channels, kernel, stride)` triples. If not
      given, a default is derived from the observation shape via
      `get_filter_config()`.
    - `post_fcnet_hiddens` / `post_fcnet_activation`: optional fully-connected
      stack appended after the conv stack.
    - `no_final_linear`: if True (and `num_outputs` is given), the last conv
      layer itself produces `num_outputs` channels instead of adding a final
      (1,1)-conv logits layer.
    - `vf_share_layers`: if True, the value branch is a single FC layer on top
      of the shared conv features; otherwise a full, separate conv stack is
      built for the value function.

    Observations are expected channels-last (the constructor unpacks
    `obs_space.shape` as `(w, h, in_channels)`); `forward()` permutes them to
    channels-first before running the conv stack.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        # Fall back to a default conv-filter setup matching the obs shape.
        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        activation = self.model_config.get("conv_activation")
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0, "Must provide at least 1 entry in `conv_filters`!"

        # Post FC net config.
        post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
        post_fcnet_activation = get_activation_fn(
            model_config.get("post_fcnet_activation"), framework="torch"
        )

        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        # Whether the last layer is the output of a Flattened (rather than
        # a n x (1,1) Conv2D).
        self.last_layer_is_flattened = False
        self._logits = None

        layers = []
        # NOTE: obs space is assumed channels-last here (w, h, in_channels).
        (w, h, in_channels) = obs_space.shape
        in_size = [w, h]
        # All but the last filter spec use "same" padding so spatial dims
        # shrink only via stride.
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, stride)
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    padding,
                    activation_fn=activation,
                )
            )
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = filters[-1]

        # No final linear: Last layer has activation function and exits with
        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
        # on `post_fcnet_...` settings).
        if no_final_linear and num_outputs:
            # If a post-FC stack follows, keep the conv's own channel count;
            # the FC stack will produce `num_outputs` instead.
            out_channels = out_channels if post_fcnet_hiddens else num_outputs
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation,
                )
            )

            # Add (optional) post-fc-stack after last Conv2D layer.
            layer_sizes = post_fcnet_hiddens[:-1] + (
                [num_outputs] if post_fcnet_hiddens else []
            )
            for i, out_size in enumerate(layer_sizes):
                layers.append(
                    SlimFC(
                        in_size=out_channels,
                        out_size=out_size,
                        activation_fn=post_fcnet_activation,
                        initializer=normc_initializer(1.0),
                    )
                )
                out_channels = out_size

        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation,
                )
            )

            # num_outputs defined. Use that to create an exact
            # `num_output`-sized (1,1)-Conv2D.
            if num_outputs:
                # Spatial size after the "valid"-padded last conv layer.
                in_size = [
                    np.ceil((in_size[0] - kernel[0]) / stride),
                    np.ceil((in_size[1] - kernel[1]) / stride),
                ]
                padding, _ = same_padding(in_size, [1, 1], [1, 1])
                if post_fcnet_hiddens:
                    layers.append(nn.Flatten())
                    in_size = out_channels
                    # Add (optional) post-fc-stack after last Conv2D layer.
                    for i, out_size in enumerate(post_fcnet_hiddens + [num_outputs]):
                        layers.append(
                            SlimFC(
                                in_size=in_size,
                                out_size=out_size,
                                # Final FC layer (the logits) gets no
                                # activation.
                                activation_fn=post_fcnet_activation
                                if i < len(post_fcnet_hiddens) - 1
                                else None,
                                initializer=normc_initializer(1.0),
                            )
                        )
                        in_size = out_size
                    # Last layer is logits layer.
                    self._logits = layers.pop()
                else:
                    # (1,1)-conv acting as the logits layer.
                    self._logits = SlimConv2d(
                        out_channels,
                        num_outputs,
                        [1, 1],
                        1,
                        padding,
                        activation_fn=None,
                    )
            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                layers.append(nn.Flatten())

        self._convs = nn.Sequential(*layers)

        # If our num_outputs still unknown, we need to do a test pass to
        # figure out the output dimensions. This could be the case, if we have
        # the Flatten layer at the end.
        if self.num_outputs is None:
            # Create a B=1 dummy sample and push it through out conv-net.
            dummy_in = (
                torch.from_numpy(self.obs_space.sample())
                .permute(2, 0, 1)
                .unsqueeze(0)
                .float()
            )
            dummy_out = self._convs(dummy_in)
            self.num_outputs = dummy_out.shape[1]

        # Build the value layers.
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
            # Single FC head on top of the shared conv features.
            self._value_branch = SlimFC(
                out_channels, 1, initializer=normc_initializer(0.01), activation_fn=None
            )
        else:
            # Separate conv stack for the value function, mirroring the
            # policy's filter specs, ending in a (1,1)-conv producing 1 value.
            vf_layers = []
            (w, h, in_channels) = obs_space.shape
            in_size = [w, h]
            for out_channels, kernel, stride in filters[:-1]:
                padding, out_size = same_padding(in_size, kernel, stride)
                vf_layers.append(
                    SlimConv2d(
                        in_channels,
                        out_channels,
                        kernel,
                        stride,
                        padding,
                        activation_fn=activation,
                    )
                )
                in_channels = out_channels
                in_size = out_size

            out_channels, kernel, stride = filters[-1]
            vf_layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,
                    activation_fn=activation,
                )
            )

            vf_layers.append(
                SlimConv2d(
                    in_channels=out_channels,
                    out_channels=1,
                    kernel=1,
                    stride=1,
                    padding=None,
                    activation_fn=None,
                )
            )
            self._value_branch_separate = nn.Sequential(*vf_layers)

        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        """Runs the conv (+ optional logits) stack on `input_dict["obs"]`.

        Returns a `(logits, state)` tuple. If the last layer is a Flatten,
        the flattened conv output is returned directly.
        """
        self._features = input_dict["obs"].float()
        # Permute b/c data comes in as [B, dim, dim, channels]:
        self._features = self._features.permute(0, 3, 1, 2)
        conv_out = self._convs(self._features)
        # Store features to save forward pass when getting value_function out.
        if not self._value_branch_separate:
            self._features = conv_out

        if not self.last_layer_is_flattened:
            if self._logits:
                conv_out = self._logits(conv_out)
            if len(conv_out.shape) == 4:
                # A 4D output must be [B, num_outputs, 1, 1] so it can be
                # squeezed down to [B, num_outputs].
                if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
                    raise ValueError(
                        "Given `conv_filters` ({}) do not result in a [B, {} "
                        "(`num_outputs`), 1, 1] shape (but in {})! Please "
                        "adjust your Conv2D stack such that the last 2 dims "
                        "are both 1.".format(
                            self.model_config["conv_filters"],
                            self.num_outputs,
                            list(conv_out.shape),
                        )
                    )
                logits = conv_out.squeeze(3)
                logits = logits.squeeze(2)
            else:
                logits = conv_out
            return logits, state
        else:
            return conv_out, state

    @override(TorchModelV2)
    def value_function(self) -> TensorType:
        """Returns the value estimate [B] for the last `forward()` input."""
        assert self._features is not None, "must call forward() first"
        if self._value_branch_separate:
            # Separate stack: re-run the VF convs on the raw (permuted) obs.
            value = self._value_branch_separate(self._features)
            value = value.squeeze(3)
            value = value.squeeze(2)
            return value.squeeze(1)
        else:
            # Shared stack: `self._features` already holds the conv output.
            if not self.last_layer_is_flattened:
                features = self._features.squeeze(3)
                features = features.squeeze(2)
            else:
                features = self._features
            return self._value_branch(features).squeeze(1)

    def _hidden_layers(self, obs: TensorType) -> TensorType:
        """Runs the conv stack on channels-last `obs`; returns [B, channels]."""
        res = self._convs(obs.permute(0, 3, 1, 2))  # switch to channel-major
        res = res.squeeze(3)
        res = res.squeeze(2)
        return res