ray/rllib/utils/torch_ops.py
Sven Mika 428516056a
[RLlib] SAC Torch (incl. Atari learning) (#7984)
* Policy-classes cleanup and torch/tf unification:
  - Make Policy abstract.
  - Add `action_dist` to the call to `extra_action_out_fn` (necessary for PPO torch).
  - Move some methods and vars from TFPolicy to the base Policy:
    num_state_tensors, ACTION_PROB, ACTION_LOGP, and some more.

* Fix `clip_action` import from Policy (should probably be moved into utils altogether).

* Move `is_recurrent()` and `num_state_tensors()` from DynamicTFPolicy into TFPolicy.

* Add `config` to all Policy c'tor calls (as 3rd arg after obs and action spaces), including the TFPolicy c'tor calls (e.g. in marvil_policy.py).

* Fix test_rollout_worker.py::MockPolicy and BadPolicy classes (the Policy base class is now abstract).

* Fix LINT errors in the Policy classes (policy.py).

* Implement StatefulPolicy abstract methods in test cases: test_multi_agent_env.py.

* Create a simple TestPolicy to sub-class from when testing Policies (reduces code in some test cases).

* policy.py:
  - Remove abstractmethod from `apply_gradients` and `compute_gradients` (these are not required if `learn_on_batch` is implemented).
  - Fix the docstring of `num_state_tensors`.
  - Make only `compute_actions` an abstractmethod in the base Policy and provide pass-implementations for all other methods where not defined.
  - Fix state_batches=None (most Policies don't have internal states).

* Make the QMIX torch Policy a child of TorchPolicy (instead of Policy) and add empty implementations of the abstract Policy methods.

* Store the Policy's config in self.config in the base Policy c'tor.

* Cartpole learning with tf AND torch (in ~ the same number of timesteps); Pendulum learning with tf.

* SAC torch and tf learning Pendulum and Cartpole after cleanup.

* SAC: Move policy.target_model to policy.device as well.

* Fix data-format of tf keras Conv2d layers (broken for some tf-versions which have data_format="channels_first" as default).

* Various fixes, test fixes, and LINT.
Co-authored-by: Sven Mika <sven@Svens-MacBook-Pro.local>
2020-04-15 13:25:16 +02:00


import logging

import numpy as np

from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()
logger = logging.getLogger(__name__)

try:
    import tree
except (ImportError, ModuleNotFoundError) as e:
    logger.warning("`dm-tree` is not installed! Run `pip install dm-tree`.")
    raise e

def huber_loss(x, delta=1.0):
    """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
    return torch.where(
        torch.abs(x) < delta,
        torch.pow(x, 2.0) * 0.5,
        delta * (torch.abs(x) - 0.5 * delta))
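
# A minimal usage sketch (the `q_t_selected`/`q_t_target` tensors are
# hypothetical, shown for illustration only): quadratic near zero,
# linear beyond `delta`.
#   td_error = q_t_selected - q_t_target  # any float tensor
#   loss = huber_loss(td_error, delta=1.0).mean()
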
def reduce_mean_ignore_inf(x, axis):
    """Same as torch.mean() but ignores -inf values."""
    mask = torch.ne(x, float("-inf"))
    x_zeroed = torch.where(mask, x, torch.zeros_like(x))
    return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis)
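
# Example (a common action-masking convention: invalid actions' logits are
# set to -inf and should not contribute to the mean):
#   logits = torch.tensor([[1.0, float("-inf"), 3.0]])
#   reduce_mean_ignore_inf(logits, axis=1)  # -> tensor([2.])
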
def minimize_and_clip(optimizer, clip_val=10):
    """Clips the gradients of all params found in `optimizer.param_groups`.

    Ensures the norm of the gradients for each variable is clipped to
    `clip_val`. Note that in torch (unlike tf1, where an optimizer exposes
    `compute_gradients()`), gradients are already stored on the parameters
    after `loss.backward()`, so this clips in place; call it between
    `backward()` and `optimizer.step()`.
    """
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            if param.grad is not None:
                torch.nn.utils.clip_grad_norm_([param], clip_val)
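
# A typical update step (`loss` and `optimizer` are assumed to exist and
# are shown for illustration only):
#   optimizer.zero_grad()
#   loss.backward()
#   minimize_and_clip(optimizer, clip_val=10)
#   optimizer.step()
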
def sequence_mask(lengths, maxlen=None, dtype=None):
    """Exact same behavior as tf.sequence_mask.

    Thanks to Dimitris Papatheodorou
    (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
    39036).
    """
    if maxlen is None:
        maxlen = int(lengths.max())
    mask = ~(torch.ones((len(lengths), maxlen)).to(
        lengths.device).cumsum(dim=1).t() > lengths).t()
    mask = mask.type(dtype or torch.bool)
    return mask
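
# Example:
#   sequence_mask(torch.tensor([1, 3]), maxlen=4)
#   -> tensor([[ True, False, False, False],
#              [ True,  True,  True, False]])
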
def convert_to_non_torch_type(stats):
    """Converts values in `stats` to non-Tensor numpy or python types.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will
            be converted and returned as a new struct with all torch tensors
            being converted to numpy types.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to non-torch Tensor types.
    """

    # The mapping function used to numpyize torch Tensors.
    def mapping(item):
        if isinstance(item, torch.Tensor):
            return item.cpu().item() if len(item.size()) == 0 else \
                item.cpu().detach().numpy()
        else:
            return item

    return tree.map_structure(mapping, stats)
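
# Example (numpyizing a nested stats dict, e.g. before shipping learner
# results off the GPU; keys and values are illustrative):
#   stats = {"loss": torch.tensor(0.25), "q": torch.zeros(2, 3)}
#   convert_to_non_torch_type(stats)
#   -> {"loss": 0.25, "q": <np.ndarray of shape (2, 3)>}
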
def convert_to_torch_tensor(stats, device=None):
    """Converts any (possibly nested) struct to torch.Tensors.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will
            be converted and returned as a new struct with all leaves
            converted to torch tensors.
        device (Optional[torch.device]): The device to move the resulting
            tensors to (if not None).

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to torch Tensor types.
    """

    def mapping(item):
        if torch.is_tensor(item):
            return item if device is None else item.to(device)
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors.
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)

    return tree.map_structure(mapping, stats)
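
# Example (tensorizing a sample-batch-like dict; keys are illustrative):
#   batch = {"obs": np.zeros((32, 4)), "actions": np.array([0, 1, 1])}
#   convert_to_torch_tensor(batch, device=torch.device("cpu"))
#   # float64 leaves are downcast to float32 torch tensors on the device.
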
def atanh(x):
    """Inverse of the tanh function: 0.5 * log((1 + x) / (1 - x)).

    Note: Inputs must lie strictly inside (-1.0, 1.0).
    """
    return 0.5 * torch.log((1 + x) / (1 - x))
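
# Example (atanh inverts tanh, e.g. when un-squashing Gaussian actions;
# values are illustrative):
#   y = torch.tanh(torch.tensor([0.5]))
#   atanh(y)  # -> tensor([0.5000])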