ray/rllib/models/torch/misc.py

188 lines
7.1 KiB
Python
Raw Normal View History

""" Code adapted from https://github.com/ikostrikov/pytorch-a3c"""
import numpy as np
from typing import Union, Tuple, Any, List
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType
torch, nn = try_import_torch()
def normc_initializer(std: float = 1.0) -> Any:
    """Builds an in-place weight initializer.

    The returned callable fills its tensor argument with standard-normal
    samples. It then rescales the tensor so that the L2 norm taken over dim 1
    equals `std` (for a 2D tensor: every row has norm `std`).

    Args:
        std (float): Target norm for each slice along dim 1.

    Returns:
        Callable taking a tensor and modifying its `.data` in place.
    """

    def _init_rows(weight):
        buf = weight.data
        buf.normal_(0, 1)
        # Norm over dim 1; keepdim so the division broadcasts back.
        row_norm = buf.pow(2).sum(1, keepdim=True).sqrt()
        buf.mul_(std / row_norm)

    return _init_rows
def same_padding(in_size: Tuple[int, int],
                 filter_size: Union[int, Tuple[int, int]],
                 stride_size: Union[int, Tuple[int, int]]
                 ) -> Tuple[Tuple[int, int, int, int], Tuple[int, int]]:
    """Computes explicit padding that mimics TF conv2d `same` padding.

    See www.tensorflow.org/versions/r0.12/api_docs/python/nn/convolution

    Args:
        in_size (tuple): Rows (Height), Column (Width) for input.
        filter_size (Union[int, Tuple[int, int]]): Rows (Height), column
            (Width) for filter. If int, height == width.
        stride_size (Union[int, Tuple[int, int]]): Rows (Height), column
            (Width) for stride. If int, height == width.

    Returns:
        padding (tuple): (left, right, top, bottom) for input into
            torch.nn.ZeroPad2d. Entries are never negative.
        output (tuple): (height, width) as ints after padding and
            convolution.
    """
    in_height, in_width = in_size
    if isinstance(filter_size, int):
        filter_height, filter_width = filter_size, filter_size
    else:
        filter_height, filter_width = filter_size
    if isinstance(stride_size, (int, float)):
        stride_height, stride_width = int(stride_size), int(stride_size)
    else:
        stride_height, stride_width = int(stride_size[0]), int(stride_size[1])

    # With `same` padding the output size depends only on input and stride.
    out_height = int(np.ceil(float(in_height) / float(stride_height)))
    out_width = int(np.ceil(float(in_width) / float(stride_width)))

    # Clamp at 0 like TF does. When the filter is smaller than the stride,
    # the raw value is negative, and a negative ZeroPad2d would crop the input.
    pad_along_height = max(
        int((out_height - 1) * stride_height + filter_height - in_height), 0)
    pad_along_width = max(
        int((out_width - 1) * stride_width + filter_width - in_width), 0)

    # An odd total puts the extra pixel on the bottom/right, as in TF.
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    output = (out_height, out_width)
    return padding, output
class SlimConv2d(nn.Module):
    """Simple mock of tf.slim Conv2d.

    Chains an optional nn.ZeroPad2d, an nn.Conv2d and an optional activation
    module inside a single nn.Sequential.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]],
            padding: Union[int, Tuple[int, int]],
            # Defaulting these to nn.[..] will break soft torch import.
            initializer: Any = "default",
            activation_fn: Any = "default",
            bias_init: float = 0):
        """Creates a standard Conv2d layer, similar to torch.nn.Conv2d

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel (Union[int, Tuple[int, int]]): If int, the kernel is
                a tuple(x,x). Otherwise, the tuple can be specified.
            stride (Union[int, Tuple[int, int]]): Controls the stride
                for the cross-correlation. If int, the stride is a
                tuple(x,x). Otherwise, the tuple can be specified.
            padding (Union[int, Tuple[int, int]]): Passed unchanged to an
                nn.ZeroPad2d placed before the conv. A falsy value means
                no padding layer is added.
            initializer (Any): Initializer function for kernel weights.
                "default" selects nn.init.xavier_uniform_. A falsy value
                keeps torch's default init, in which case `bias_init` is
                ignored.
            activation_fn (Any): Activation at the end of the layer.
                "default" selects nn.ReLU. Other strings are resolved via
                get_activation_fn. None disables the activation.
            bias_init (float): Initialize bias weights to bias_init const.
        """
        super().__init__()

        stack = [nn.ZeroPad2d(padding)] if padding else []

        conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
        if initializer:
            init_fn = (nn.init.xavier_uniform_
                       if initializer == "default" else initializer)
            init_fn(conv.weight)
            # Bias is only overridden together with a custom weight init.
            nn.init.constant_(conv.bias, bias_init)
        stack.append(conv)

        act_cls = activation_fn
        if isinstance(act_cls, str):
            act_cls = (nn.ReLU if act_cls == "default" else
                       get_activation_fn(act_cls, "torch"))
        if act_cls is not None:
            stack.append(act_cls())

        self._model = nn.Sequential(*stack)

    def forward(self, x: TensorType) -> TensorType:
        """Runs `x` through the padding/conv/activation stack."""
        return self._model(x)
class SlimFC(nn.Module):
    """Simple PyTorch version of `linear` function.

    Chains an nn.Linear and an optional activation module inside a single
    nn.Sequential.
    """

    def __init__(self,
                 in_size: int,
                 out_size: int,
                 initializer: Any = None,
                 activation_fn: Any = None,
                 use_bias: bool = True,
                 bias_init: float = 0.0):
        """Creates a standard FC layer, similar to torch.nn.Linear

        Args:
            in_size (int): Input size for FC Layer.
            out_size (int): Output size for FC Layer.
            initializer (Any): Initializer function for FC layer weights.
                If None, nn.init.xavier_uniform_ is used.
            activation_fn (Any): Activation function at the end of layer.
                A str is resolved via get_activation_fn. None means no
                activation (linear output).
            use_bias (bool): Whether to add bias weights or not.
            bias_init (float): Initialize bias weights to bias_init const.
        """
        super(SlimFC, self).__init__()
        layers = []
        # Actual nn.Linear layer (including correct initialization logic).
        linear = nn.Linear(in_size, out_size, bias=use_bias)
        if initializer is None:
            initializer = nn.init.xavier_uniform_
        initializer(linear.weight)
        # nn.Linear creates a bias whenever `bias` is truthy. The same test
        # is used here so every created bias actually receives bias_init.
        if use_bias:
            nn.init.constant_(linear.bias, bias_init)
        layers.append(linear)
        # Activation function (if any; default=None (linear)).
        if isinstance(activation_fn, str):
            activation_fn = get_activation_fn(activation_fn, "torch")
        if activation_fn is not None:
            layers.append(activation_fn())
        # Put everything in sequence.
        self._model = nn.Sequential(*layers)

    def forward(self, x: TensorType) -> TensorType:
        """Runs `x` through the linear layer and optional activation."""
        return self._model(x)
class AppendBiasLayer(nn.Module):
    """Simple bias appending layer for free_log_std."""

    def __init__(self, num_bias_vars: int):
        """Creates `num_bias_vars` trainable values, initialized to 0.0.

        Args:
            num_bias_vars (int): Number of values appended to each input row.
        """
        super().__init__()
        # Assigning an nn.Parameter attribute registers it under "log_std".
        # A separate register_parameter call would be redundant.
        self.log_std = torch.nn.Parameter(
            torch.as_tensor([0.0] * num_bias_vars))

    def forward(self, x: TensorType) -> TensorType:
        """Appends `log_std` to every row of `x` along dim 1.

        `x` must be 2D for the concatenation with the broadcast
        (batch, num_bias_vars) tensor to succeed.
        """
        batch_size = x.shape[0]
        # expand() broadcasts without copying; torch.cat does the one copy.
        log_std = self.log_std.unsqueeze(0).expand(batch_size, -1)
        return torch.cat([x, log_std], dim=1)
class Reshape(nn.Module):
    """Standard module that reshapes/views a tensor to a fixed shape."""

    def __init__(self, shape: List):
        """Stores the target shape.

        Args:
            shape (List): Target shape, unpacked into `Tensor.reshape`.
                A -1 entry is inferred from the input size.
        """
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # reshape() returns a view whenever .view() would have succeeded.
        # It also accepts non-contiguous inputs, where .view() raises.
        return x.reshape(*self.shape)