import numpy as np

from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()
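
# NOTE: If torch is not installed, RLlib's `try_import_torch` is designed to
# fail gracefully rather than raise at import time, which is why all the
# torch-dependent definitions below sit behind the `if torch:` guard.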

# Custom initialization for different types of layers
if torch:

    class Linear(nn.Linear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def reset_parameters(self):
            nn.init.xavier_uniform_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)
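
    # NOTE: `nn.Linear.__init__` (and likewise the other torch base classes
    # subclassed here) calls `reset_parameters()`, so overriding that method
    # is sufficient to apply the custom initialization at construction time.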

    class Conv2d(nn.Conv2d):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def reset_parameters(self):
            nn.init.xavier_uniform_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    class ConvTranspose2d(nn.ConvTranspose2d):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def reset_parameters(self):
            nn.init.xavier_uniform_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    class GRUCell(nn.GRUCell):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def reset_parameters(self):
            # Xavier-uniform for input-to-hidden weights, orthogonal for
            # hidden-to-hidden weights (a common choice for recurrent
            # stability), and zero biases.
            nn.init.xavier_uniform_(self.weight_ih)
            nn.init.orthogonal_(self.weight_hh)
            # Guard against `bias=False`, which leaves these set to None.
            if self.bias_ih is not None:
                nn.init.zeros_(self.bias_ih)
                nn.init.zeros_(self.bias_hh)
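
    # Example (sketch): these classes are drop-in replacements for their
    # torch counterparts, differing only in how weights are initialized.
    #
    #     layer = Linear(64, 32)    # Xavier-uniform weight, zero bias
    #     cell = GRUCell(32, 32)    # Xavier / orthogonal weights, zero biases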

    # Custom tanh bijector, used because gradients through the Dreamer actor
    # can be large.
    class TanhBijector(torch.distributions.Transform):
        def __init__(self):
            super().__init__()

            self.bijective = True
            self.domain = torch.distributions.constraints.real
            self.codomain = torch.distributions.constraints.interval(-1.0, 1.0)

        def atanh(self, x):
            return 0.5 * torch.log((1 + x) / (1 - x))

        def sign(self):
            return 1.0

        def _call(self, x):
            return torch.tanh(x)

        def _inverse(self, y):
            # Clamp just inside (-1, 1) so that atanh stays finite.
            y = torch.where(
                (torch.abs(y) <= 1.0), torch.clamp(y, -0.99999997, 0.99999997), y
            )
            y = self.atanh(y)
            return y

        def log_abs_det_jacobian(self, x, y):
            # log|d tanh(x)/dx| = log(1 - tanh(x)^2)
            #                   = 2 * (log 2 - x - softplus(-2x)),
            # a numerically stable rewrite of the naive form.
            return 2.0 * (np.log(2) - x - nn.functional.softplus(-2.0 * x))
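
    # Example (sketch): how a tanh-squashed policy head might use this
    # transform; `mean` and `std` are hypothetical tensors, not defined here.
    #
    #     base = torch.distributions.Normal(mean, std)
    #     dist = torch.distributions.TransformedDistribution(
    #         base, [TanhBijector()]
    #     )
    #     action = dist.rsample()  # reparameterized sample in (-1, 1)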


# Modified from https://github.com/juliusfrost/dreamer-pytorch
class FreezeParameters:
    """Context manager that temporarily disables gradients for `parameters`,
    restoring each parameter's original `requires_grad` flag on exit."""

    def __init__(self, parameters):
        self.parameters = parameters
        self.param_states = [p.requires_grad for p in self.parameters]

    def __enter__(self):
        for param in self.parameters:
            param.requires_grad = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        for i, param in enumerate(self.parameters):
            param.requires_grad = self.param_states[i]
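
# Example (sketch): freeze the world model while backpropagating an actor
# loss, as Dreamer-style training loops typically do; `world_model` and
# `actor_loss` are hypothetical names, not defined in this module.
#
#     with FreezeParameters(list(world_model.parameters())):
#         actor_loss.backward()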


def batchify_states(states_list, batch_size, device=None):
    """Expand each 1-D state tensor in `states_list` into a batch of size
    `batch_size` (a broadcast view via `expand`, no copy), optionally
    moving the result to `device`."""
    state_batches = [s[None, :].expand(batch_size, -1) for s in states_list]
    if device is not None:
        state_batches = [s.to(device) for s in state_batches]

    return state_batches
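
# Example (sketch): broadcast single initial RNN states across a batch.
# The state sizes below are illustrative assumptions.
#
#     h0 = torch.zeros(200)  # deterministic state
#     z0 = torch.zeros(30)   # stochastic state
#     h, z = batchify_states([h0, z0], batch_size=16)
#     # h.shape == (16, 200); z.shape == (16, 30)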