ray/rllib/models/torch/misc.py
Sven Mika 428516056a
[RLlib] SAC Torch (incl. Atari learning) (#7984)
* Policy-classes cleanup and torch/tf unification.
- Make Policy abstract.
- Add `action_dist` to call to `extra_action_out_fn` (necessary for PPO torch).
- Move some methods and vars to base Policy
  (from TFPolicy): num_state_tensors, ACTION_PROB, ACTION_LOGP and some more.

* Fix `clip_action` import from Policy (should probably be moved into utils altogether).

* - Move `is_recurrent()` and `num_state_tensors()` into TFPolicy (from DynamicTFPolicy).
- Add config to all Policy c'tor calls (as 3rd arg after obs and action spaces).

* Add `config` to c'tor call to TFPolicy.

* Add missing `config` to c'tor call to TFPolicy in marvil_policy.py.

* Fix test_rollout_worker.py::MockPolicy and BadPolicy classes (Policy base class is now abstract).

* Fix LINT errors in Policy classes.

* Implement StatefulPolicy abstract methods in test cases: test_multi_agent_env.py.

* policy.py LINT errors.

* Create a simple TestPolicy to sub-class from when testing Policies (reduces code in some test cases).

* policy.py
- Remove abstractmethod from `apply_gradients` and `compute_gradients` (these are not required if `learn_on_batch` is implemented).
- Fix docstring of `num_state_tensors`.

* Make QMIX torch Policy a child of TorchPolicy (instead of Policy).

* QMixPolicy add empty implementations of abstract Policy methods.

* Store Policy's config in self.config in base Policy c'tor.

* - Make only `compute_actions` an abstractmethod in the base Policy and provide `pass` implementations for all other methods (so subclasses need not define them).
- Fix state_batches=None (most Policies don't have internal states).

* Cartpole tf learning.

* Cartpole tf AND torch learning (in ~ same ts).

* Cartpole tf AND torch learning (in ~ same ts). 2

* Cartpole tf (torch syntax-broken) learning (in ~ same ts). 3

* Cartpole tf AND torch learning (in ~ same ts). 4

* Cartpole tf AND torch learning (in ~ same ts). 5

* Cartpole tf AND torch learning (in ~ same ts). 6

* Cartpole tf AND torch learning (in ~ same ts). Pendulum tf learning.

* WIP.

* WIP.

* SAC torch learning Pendulum.

* WIP.

* SAC torch and tf learning Pendulum and Cartpole after cleanup.

* WIP.

* LINT.

* LINT.

* SAC: Move policy.target_model to policy.device as well.

* Fixes and cleanup.

* Fix data-format of tf keras Conv2d layers (broken for some tf-versions which have data_format="channels_first" as default).

* Fixes and LINT.

* Fixes and LINT.

* Fix and LINT.

* WIP.

* Test fixes and LINT.

* Fixes and LINT.

Co-authored-by: Sven Mika <sven@Svens-MacBook-Pro.local>
2020-04-15 13:25:16 +02:00


""" Code adapted from https://github.com/ikostrikov/pytorch-a3c"""
import numpy as np
from ray.rllib.utils import try_import_torch
torch, nn = try_import_torch()
def normc_initializer(std=1.0):
def initializer(tensor):
tensor.data.normal_(0, 1)
tensor.data *= std / torch.sqrt(
tensor.data.pow(2).sum(1, keepdim=True))
return initializer
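

# Illustrative usage sketch (not part of the original module; layer sizes are
# arbitrary example values): `normc_initializer` returns a callable that can
# be applied to a 2-D weight tensor, or passed to `SlimFC`/`SlimConv2d` below.
def _example_normc_usage():
    layer = nn.Linear(4, 16)
    normc_initializer(std=1.0)(layer.weight)
    # Each row of the (out_size x in_size) weight matrix now has norm == std
    # (up to float precision), since rows are rescaled by std / row-norm.
    return layer.weight.data.norm(dim=1)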


def valid_padding(in_size, filter_size, stride_size):
    """Note: Padding is added to match TF conv2d `same` padding. See
    www.tensorflow.org/versions/r0.12/api_docs/python/nn/convolution

    Params:
        in_size (tuple): Rows (Height), Column (Width) for input
        filter_size (tuple): Rows (Height), Column (Width) for filter
        stride_size (tuple): Rows (Height), Column (Width) for stride
    Output:
        padding (tuple): For input into torch.nn.ZeroPad2d
        output (tuple): Output shape after padding and convolution
    """
    in_height, in_width = in_size
    filter_height, filter_width = filter_size
    stride_height, stride_width = stride_size

    out_height = np.ceil(float(in_height) / float(stride_height))
    out_width = np.ceil(float(in_width) / float(stride_width))

    pad_along_height = int(
        ((out_height - 1) * stride_height + filter_height - in_height))
    pad_along_width = int(
        ((out_width - 1) * stride_width + filter_width - in_width))
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    output = (out_height, out_width)
    return padding, output
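

# Illustrative usage sketch (not part of the original module; the 84x84 input,
# 8x8 filter, and stride 4 are example values only): `valid_padding`
# reproduces TF "same" padding for a torch conv layer.
def _example_valid_padding():
    padding, output = valid_padding((84, 84), (8, 8), (4, 4))
    # (left, right, top, bottom) for nn.ZeroPad2d; output dims are np floats.
    assert padding == (2, 2, 2, 2)
    assert output == (21.0, 21.0)
    return padding, output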


class SlimConv2d(nn.Module):
    """Simple mock of tf.slim Conv2d."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel,
            stride,
            padding,
            # Defaulting these to nn.[..] will break soft torch import.
            initializer="default",
            activation_fn="default",
            bias_init=0):
        super(SlimConv2d, self).__init__()
        layers = []
        # Optional padding layer (e.g. to reproduce TF "same" padding).
        if padding:
            layers.append(nn.ZeroPad2d(padding))
        # Actual Conv2d layer, incl. weight/bias initialization.
        conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
        if initializer:
            if initializer == "default":
                initializer = nn.init.xavier_uniform_
            initializer(conv.weight)
        nn.init.constant_(conv.bias, bias_init)
        layers.append(conv)
        # Optional activation function (default: ReLU).
        if activation_fn:
            if activation_fn == "default":
                activation_fn = nn.ReLU
            layers.append(activation_fn())
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        return self._model(x)
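

# Illustrative usage sketch (not part of the original module): chaining
# `valid_padding` with `SlimConv2d`. The shapes resemble the first layer of a
# typical Atari conv net, but are example values, not a prescribed setup.
def _example_slim_conv2d():
    padding, _ = valid_padding((84, 84), (8, 8), (4, 4))
    conv = SlimConv2d(
        in_channels=4,
        out_channels=16,
        kernel=(8, 8),
        stride=(4, 4),
        padding=padding)  # defaults: xavier_uniform_ init + ReLU activation
    dummy = torch.zeros(1, 4, 84, 84)  # (batch, channels, height, width)
    out = conv(dummy)
    assert out.shape == (1, 16, 21, 21)
    return out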


class SlimFC(nn.Module):
    """Simple PyTorch version of a fully connected (`linear`) layer."""

    def __init__(self,
                 in_size,
                 out_size,
                 initializer=None,
                 activation_fn=None,
                 bias_init=0.0):
        super(SlimFC, self).__init__()
        layers = []
        # Linear layer, incl. optional custom weight init and constant bias.
        linear = nn.Linear(in_size, out_size)
        if initializer:
            initializer(linear.weight)
        nn.init.constant_(linear.bias, bias_init)
        layers.append(linear)
        # Optional activation function.
        if activation_fn:
            layers.append(activation_fn())
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        return self._model(x)
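

# Illustrative usage sketch (not part of the original module; layer sizes are
# arbitrary example values): a small two-layer MLP built from `SlimFC`, with
# `normc_initializer` supplying the weight init as in the examples above.
def _example_slim_fc():
    model = nn.Sequential(
        SlimFC(4, 32, initializer=normc_initializer(1.0),
               activation_fn=nn.Tanh),
        SlimFC(32, 2, initializer=normc_initializer(0.01)))
    dummy = torch.zeros(5, 4)  # (batch, obs_dim)
    out = model(dummy)
    assert out.shape == (5, 2)
    return out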