mirror of
https://github.com/vale981/ray
synced 2025-03-09 12:56:46 -04:00
52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
from ray.rllib.utils.framework import try_import_torch
|
|
|
|
torch, nn = try_import_torch()
|
|
|
|
|
|
class GRUGate(nn.Module):
    """Gated recurrent unit (GRU-style) gating layer for use in AttentionNet.

    Combines a previous state ``h`` with a new input ``X`` through a reset
    gate ``r``, an update gate ``z`` and a candidate state ``h_next``:

        r      = sigmoid(X W_r + h U_r)
        z      = sigmoid(X W_z + h U_z - b_z)
        h_next = tanh(X W_h + (r * h) U_h)
        out    = (1 - z) * h + z * h_next

    All weight matrices and the update-gate bias are ``nn.Parameter``s, so
    they show up in ``parameters()`` / ``state_dict()`` and follow the module
    through ``.to()`` / ``.cuda()``.
    """

    def __init__(self, dim, init_bias=0., **kwargs):
        """Create the gate's trainable parameters.

        Args:
            dim (int): Size of the last axis of both the state and the input;
                every weight matrix has shape ``(dim, dim)``.
            init_bias (float): Initial value of the update-gate bias ``b_z``.
                It is subtracted inside the update-gate sigmoid, so a larger
                value pushes ``z`` towards 0 and the output towards the
                previous state ``h``, which helps stabilize training.
            **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)

        self._init_bias = init_bias

        # Input-to-gate weights: reset, update, candidate (Xavier-uniform).
        self._w_r = self._xavier_matrix(dim)
        self._w_z = self._xavier_matrix(dim)
        self._w_h = self._xavier_matrix(dim)

        # State-to-gate weights: reset, update, candidate (Xavier-uniform).
        self._u_r = self._xavier_matrix(dim)
        self._u_z = self._xavier_matrix(dim)
        self._u_h = self._xavier_matrix(dim)

        # Update-gate bias. zeros(...).fill_ keeps a floating-point tensor
        # even when init_bias is given as an int (an integer tensor could
        # not be a gradient-carrying Parameter).
        self._bias_z = nn.Parameter(torch.zeros(dim).fill_(init_bias))

    @staticmethod
    def _xavier_matrix(dim):
        """Return a trainable ``(dim, dim)`` Xavier-uniform-initialized matrix."""
        weight = nn.Parameter(torch.empty(dim, dim))
        nn.init.xavier_uniform_(weight)
        return weight

    def forward(self, inputs, **kwargs):
        """Apply the gate.

        Args:
            inputs (tuple): ``(h, X)``, the previous (internal) state first
                and then the new input. ``torch.tensordot(..., dims=1)``
                contracts their last axis with the ``(dim, dim)`` weights,
                so that axis must have size ``dim``.
            **kwargs: Unused.

        Returns:
            torch.Tensor: ``(1 - z) * h + z * h_next``.
        """
        h, X = inputs

        # Reset gate: decides how much of h feeds the candidate state.
        r = torch.sigmoid(
            torch.tensordot(X, self._w_r, dims=1) +
            torch.tensordot(h, self._u_r, dims=1))

        # Update gate; the bias is subtracted, so a positive bias favours
        # keeping the previous state.
        z = torch.sigmoid(
            torch.tensordot(X, self._w_z, dims=1) +
            torch.tensordot(h, self._u_z, dims=1) - self._bias_z)

        # Candidate state computed from the input and the reset-gated state.
        h_next = torch.tanh(
            torch.tensordot(X, self._w_h, dims=1) +
            torch.tensordot(h * r, self._u_h, dims=1))

        return (1 - z) * h + z * h_next