import logging

import numpy as np

from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()

logger = logging.getLogger(__name__)
try:
    import tree
except (ImportError, ModuleNotFoundError) as e:
    logger.warning("`dm-tree` is not installed! Run `pip install dm-tree`.")
    raise e
def huber_loss(x, delta=1.0):
    """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
    return torch.where(
        torch.abs(x) < delta,
        torch.pow(x, 2.0) * 0.5,
        delta * (torch.abs(x) - 0.5 * delta))
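# A minimal usage sketch (illustrative values, not part of the original
# module). With the default delta=1.0, |x| < 1 falls on the quadratic
# branch, everything else on the linear branch:
#
#   x = torch.tensor([0.5, 2.0])
#   huber_loss(x)  # -> tensor([0.1250, 1.5000])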
def reduce_mean_ignore_inf(x, axis):
    """Same as torch.mean() but ignores -inf values."""
    mask = torch.ne(x, float("-inf"))
    x_zeroed = torch.where(mask, x, torch.zeros_like(x))
    return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis)
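# Usage sketch (hypothetical tensor; illustration only): -inf entries are
# excluded from both the sum and the divisor count:
#
#   x = torch.tensor([[1.0, float("-inf"), 3.0]])
#   reduce_mean_ignore_inf(x, 1)  # -> tensor([2.]), i.e. (1 + 3) / 2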
def minimize_and_clip(optimizer, objective, var_list, clip_val=10):
    """Minimizes `objective` using `optimizer` w.r.t. the variables in
    `var_list`, while ensuring the gradient norm of each variable is
    clipped to `clip_val`.
    """
    # Compute gradients of `objective` w.r.t. all variables in `var_list`.
    optimizer.zero_grad()
    objective.backward()
    # Clip each variable's gradient norm in-place. Note that
    # `clip_grad_norm_` modifies `var.grad` and returns the (pre-clipping)
    # total norm, so its return value must not be stored as the gradient.
    for var in var_list:
        if var.grad is not None:
            torch.nn.utils.clip_grad_norm_(var, clip_val)
    optimizer.step()
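# Usage sketch (hypothetical model/optimizer names; illustration only):
#
#   model = torch.nn.Linear(4, 2)
#   opt = torch.optim.SGD(model.parameters(), lr=0.01)
#   loss = model(torch.randn(8, 4)).pow(2).mean()
#   minimize_and_clip(opt, loss, list(model.parameters()), clip_val=10)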
def sequence_mask(lengths, maxlen=None, dtype=None):
    """Exact same behavior as tf.sequence_mask.

    Thanks to Dimitris Papatheodorou
    (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/39036).
    """
    if maxlen is None:
        maxlen = lengths.max()

    mask = ~(torch.ones((len(lengths), maxlen)).to(
        lengths.device).cumsum(dim=1).t() > lengths).t()
    # `type()` is not in-place; reassign to actually apply the cast.
    mask = mask.type(dtype or torch.bool)

    return mask
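# Usage sketch (illustrative lengths): row i holds lengths[i] True entries
# followed by False padding, matching tf.sequence_mask:
#
#   sequence_mask(torch.tensor([1, 3]), maxlen=4)
#   # -> tensor([[ True, False, False, False],
#   #            [ True,  True,  True, False]])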
def convert_to_non_torch_type(stats):
    """Converts values in `stats` to non-Tensor numpy or python types.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will
            be converted and returned as a new struct with all torch tensors
            being converted to numpy types.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to non-torch Tensor types.
    """

    # The mapping function used to numpyize torch Tensors.
    def mapping(item):
        if isinstance(item, torch.Tensor):
            return item.cpu().item() if len(item.size()) == 0 else \
                item.cpu().detach().numpy()
        else:
            return item

    return tree.map_structure(mapping, stats)
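# Usage sketch (hypothetical stats dict; illustration only): 0-dim tensors
# become python scalars, all other tensors become numpy arrays:
#
#   stats = {"loss": torch.tensor(0.5), "logits": torch.ones(2)}
#   convert_to_non_torch_type(stats)
#   # -> {"loss": 0.5, "logits": array([1., 1.], dtype=float32)}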
def convert_to_torch_tensor(stats, device=None):
    """Converts any struct to torch.Tensors.

    Args:
        stats (any): Any (possibly nested) struct, the values in which will
            be converted and returned as a new struct with all leaves
            converted to torch tensors.

    Returns:
        Any: A new struct with the same structure as `stats`, but with all
            values converted to torch Tensor types.
    """

    def mapping(item):
        if torch.is_tensor(item):
            return item if device is None else item.to(device)
        tensor = torch.from_numpy(np.asarray(item))
        # Floatify all float64 tensors.
        if tensor.dtype == torch.double:
            tensor = tensor.float()
        return tensor if device is None else tensor.to(device)

    return tree.map_structure(mapping, stats)
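# Usage sketch (hypothetical batch dict; illustration only): float64 numpy
# inputs come back as float32 tensors, python lists become tensors too:
#
#   batch = {"obs": np.array([[1.0, 2.0]]), "actions": [0, 1]}
#   convert_to_torch_tensor(batch)
#   # -> {"obs": tensor([[1., 2.]]), "actions": tensor([0, 1])}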
def atanh(x):
    """Inverse hyperbolic tangent: atanh(x) = 0.5 * log((1 + x) / (1 - x))."""
    return 0.5 * torch.log((1 + x) / (1 - x))
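# Usage sketch (illustrative value): matches numpy's arctanh:
#
#   atanh(torch.tensor(0.5))  # -> tensor(0.5493), i.e. 0.5 * ln(3)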