from ray.rllib.utils.framework import try_import_torch import numpy as np torch, nn = try_import_torch() # Custom initialization for different types of layers if torch: class Linear(nn.Linear): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def reset_parameters(self): nn.init.xavier_uniform_(self.weight) if self.bias is not None: nn.init.zeros_(self.bias) if torch: class Conv2d(nn.Conv2d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def reset_parameters(self): nn.init.xavier_uniform_(self.weight) if self.bias is not None: nn.init.zeros_(self.bias) if torch: class ConvTranspose2d(nn.ConvTranspose2d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def reset_parameters(self): nn.init.xavier_uniform_(self.weight) if self.bias is not None: nn.init.zeros_(self.bias) if torch: class GRUCell(nn.GRUCell): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def reset_parameters(self): nn.init.xavier_uniform_(self.weight_ih) nn.init.orthogonal_(self.weight_hh) nn.init.zeros_(self.bias_ih) nn.init.zeros_(self.bias_hh) # Custom Tanh Bijector due to big gradients through Dreamer Actor if torch: class TanhBijector(torch.distributions.Transform): def __init__(self): super().__init__() def atanh(self, x): return 0.5 * torch.log((1 + x) / (1 - x)) def sign(self): return 1. def _call(self, x): return torch.tanh(x) def _inverse(self, y): y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y) y = self.atanh(y) return y def log_abs_det_jacobian(self, x, y): return 2. * (np.log(2) - x - nn.functional.softplus(-2. * x)) # Modified from https://github.com/juliusfrost/dreamer-pytorch class FreezeParameters: def __init__(self, parameters): self.parameters = parameters self.param_states = [p.requires_grad for p in self.parameters] def __enter__(self): for param in self.parameters: param.requires_grad = False def __exit__(self, exc_type, exc_val, exc_tb): for i, param in enumerate(self.parameters): param.requires_grad = self.param_states[i]