# ray/rllib/models/torch/modules/gru_gate.py

from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.framework import TensorType
torch, nn = try_import_torch()
class GRUGate(nn.Module):
    """Implements a gated recurrent unit for use in AttentionNet.

    Instead of a plain residual connection, this layer combines an incoming
    stream ``h`` with a layer output ``X`` through GRU-like reset (``r``) and
    update (``z``) gates::

        r      = sigmoid(X @ W_r + h @ U_r)
        z      = sigmoid(X @ W_z + h @ U_z - b_z)
        h_next = tanh(X @ W_h + (h * r) @ U_h)
        out    = (1 - z) * h + z * h_next
    """

    def __init__(self, dim: int, init_bias: float = 0., **kwargs):
        """Create the gate's weight matrices and update-gate bias.

        Args:
            dim: Size of the last axis of both ``h`` and ``X``. Every weight
                matrix has shape ``(dim, dim)`` and the bias has shape
                ``(dim,)``.
            init_bias: Initial value of every entry of the update-gate bias
                ``b_z``. It is *subtracted* from the update gate's
                pre-activation, so a positive value starts ``z`` near 0 and
                the layer output near ``h``. This is intended to stabilize
                training. Ints are accepted and converted to float.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        # Assigning an nn.Parameter to a Module attribute registers it
        # automatically, so no explicit register_parameter() call is needed.
        # Creation order (W_* first, then U_*) is kept so the Xavier draws
        # consume the RNG in the same sequence and the state_dict key order
        # stays _w_r, _w_z, _w_h, _u_r, _u_z, _u_h, _bias_z.
        self._w_r = self._xavier_param(dim)
        self._w_z = self._xavier_param(dim)
        self._w_h = self._xavier_param(dim)
        self._u_r = self._xavier_param(dim)
        self._u_z = self._xavier_param(dim)
        self._u_h = self._xavier_param(dim)
        # float() keeps the bias floating-point even for an int init_bias.
        # An integer tensor could not be a trainable Parameter.
        self._bias_z = nn.Parameter(torch.full((dim, ), float(init_bias)))

    @staticmethod
    def _xavier_param(dim: int) -> "nn.Parameter":
        """Return a new ``(dim, dim)`` Parameter with Xavier-uniform init."""
        weight = torch.empty(dim, dim)
        nn.init.xavier_uniform_(weight)
        return nn.Parameter(weight)

    def forward(self, inputs: TensorType, **kwargs) -> TensorType:
        """Gate the incoming stream ``h`` with the new layer output ``X``.

        Args:
            inputs: A pair ``(h, X)``, with the internal state first. Both
                tensors must have ``dim`` as their last axis. That axis is
                contracted against the ``(dim, dim)`` weights by
                ``tensordot(..., dims=1)``.
            **kwargs: Ignored.

        Returns:
            ``(1 - z) * h + z * h_next``. This has the same shape as ``h``
            when ``h`` and ``X`` share a shape.
        """
        # Pass in internal state first.
        h, x = inputs

        # Reset gate: controls how much of h feeds the candidate state.
        r = torch.sigmoid(
            torch.tensordot(x, self._w_r, dims=1) +
            torch.tensordot(h, self._u_r, dims=1))

        # Update gate. Subtracting the bias pushes z toward 0 (output ~ h)
        # when init_bias > 0.
        z = torch.sigmoid(
            torch.tensordot(x, self._w_z, dims=1) +
            torch.tensordot(h, self._u_z, dims=1) - self._bias_z)

        # Candidate state from x and the reset-scaled h.
        h_next = torch.tanh(
            torch.tensordot(x, self._w_h, dims=1) +
            torch.tensordot(h * r, self._u_h, dims=1))

        return (1 - z) * h + z * h_next