import numpy as np

from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

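# Note: try_import_torch() does not raise if PyTorch is missing; it hands back
# a falsy `torch` handle, which is why the torch-dependent definitions below
# are guarded by `if torch:`.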
# Custom initialization for different types of layers
if torch:

    class Linear(nn.Linear):
        """nn.Linear with Xavier-uniform weights and zero-initialized bias."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def reset_parameters(self):
            nn.init.xavier_uniform_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    class Conv2d(nn.Conv2d):
        """nn.Conv2d with Xavier-uniform weights and zero-initialized bias."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def reset_parameters(self):
            nn.init.xavier_uniform_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    class ConvTranspose2d(nn.ConvTranspose2d):
        """nn.ConvTranspose2d with Xavier-uniform weights and zero-init bias."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def reset_parameters(self):
            nn.init.xavier_uniform_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    class GRUCell(nn.GRUCell):
        """nn.GRUCell with Xavier-uniform input weights, orthogonal
        recurrent weights, and zero-initialized biases."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def reset_parameters(self):
            nn.init.xavier_uniform_(self.weight_ih)
            nn.init.orthogonal_(self.weight_hh)
            if self.bias_ih is not None:
                nn.init.zeros_(self.bias_ih)
            if self.bias_hh is not None:
                nn.init.zeros_(self.bias_hh)
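    # Illustrative sketch (not part of the original module): PyTorch layer
    # constructors call reset_parameters(), so simply instantiating these
    # subclasses yields Xavier-initialized weights and zero biases, e.g.:
    #
    #   layer = Linear(in_features=32, out_features=16)
    #   conv = Conv2d(in_channels=3, out_channels=8, kernel_size=4, stride=2)
    #   cell = GRUCell(input_size=16, hidden_size=32)
    #   assert float(layer.bias.abs().sum()) == 0.0  # bias starts at zero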
    # Custom Tanh Bijector due to big gradients through Dreamer Actor
    class TanhBijector(torch.distributions.Transform):
        def __init__(self):
            super().__init__()

            self.bijective = True
            self.domain = torch.distributions.constraints.real
            self.codomain = torch.distributions.constraints.interval(-1.0, 1.0)

        def atanh(self, x):
            return 0.5 * torch.log((1 + x) / (1 - x))

        def sign(self):
            return 1.0

        def _call(self, x):
            return torch.tanh(x)

        def _inverse(self, y):
            # Clamp inputs just inside (-1, 1) so atanh stays finite.
            y = torch.where(
                (torch.abs(y) <= 1.0), torch.clamp(y, -0.99999997, 0.99999997), y
            )
            y = self.atanh(y)
            return y

        def log_abs_det_jacobian(self, x, y):
            # Numerically stable form of log(1 - tanh(x)^2), i.e. log|dy/dx|.
            return 2.0 * (np.log(2) - x - nn.functional.softplus(-2.0 * x))
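    # Illustrative usage sketch (an assumption, not shown in this file): the
    # bijector can squash a Normal action distribution into (-1, 1) via
    # torch.distributions.TransformedDistribution:
    #
    #   base = torch.distributions.Normal(torch.zeros(3), torch.ones(3))
    #   squashed = torch.distributions.TransformedDistribution(
    #       base, [TanhBijector()]
    #   )
    #   action = squashed.rsample()       # values lie in (-1, 1)
    #   logp = squashed.log_prob(action)  # uses log_abs_det_jacobian() above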

# Modified from https://github.com/juliusfrost/dreamer-pytorch
class FreezeParameters:
    """Context manager that temporarily disables gradients for `parameters`.

    On entry, every given parameter's `requires_grad` flag is set to False;
    on exit, the original flags are restored.
    """

    def __init__(self, parameters):
        self.parameters = parameters
        self.param_states = [p.requires_grad for p in self.parameters]

    def __enter__(self):
        for param in self.parameters:
            param.requires_grad = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        for i, param in enumerate(self.parameters):
            param.requires_grad = self.param_states[i]
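# Illustrative usage sketch (an assumption, not part of the original file): in
# a Dreamer-style actor update, the world-model weights can be frozen so that
# only the actor receives gradients. `world_model`, `actor`, `actor_optim`,
# and `compute_actor_loss` are hypothetical names:
#
#   actor_optim.zero_grad()
#   with FreezeParameters(list(world_model.parameters())):
#       loss = compute_actor_loss(world_model, actor)
#   loss.backward()   # world-model parameters get no gradient updates
#   actor_optim.step()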