mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
from ray.rllib.utils.framework import try_import_tf
|
|
|
|
# Resolve the TensorFlow handles through RLlib's framework helper instead of a
# direct `import tensorflow`. `tf` is used below both to pick GRUGate's base
# class (falling back to `object` when `tf` is falsy) and for all tensor ops.
# NOTE(review): the exact meaning of `tf1` / `tfv` is defined in
# ray.rllib.utils.framework (presumably tf.compat.v1 and the major version) —
# neither is used in this module; confirm there.
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
|
# When TensorFlow is unavailable (`tf` is falsy) the class derives from
# `object` so that this module can still be imported.
class GRUGate(tf.keras.layers.Layer if tf else object):
    """GRU-style gating layer that merges an input into a running state.

    Given the internal state ``h`` and an input ``x`` whose last axes have the
    same size ``dim``, ``call`` computes::

        r      = sigmoid(x @ W_r + h @ U_r)
        z      = sigmoid(x @ W_z + h @ U_z - b_z)
        h_cand = tanh(x @ W_h + (h * r) @ U_h)
        out    = (1 - z) * h + z * h_cand

    where each ``@`` contracts the last axis of its left operand with the
    first axis of a ``(dim, dim)`` weight (``tf.tensordot(..., axes=1)``) and
    ``*`` is element-wise. These equations have the form of the GRU-type gate
    described in Parisotto et al., "Stabilizing Transformers for
    Reinforcement Learning" (2019).

    Args:
        init_bias: Initial value of every entry of ``b_z``. Because ``b_z`` is
            subtracted before the sigmoid, a larger value (all else equal)
            pushes ``z`` towards 0, so the output starts out closer to ``h``.
        **kwargs: Forwarded unchanged to the Keras ``Layer`` constructor.
    """

    def __init__(self, init_bias=0., **kwargs):
        super().__init__(**kwargs)
        self._init_bias = init_bias

    @staticmethod
    def _last_dim(shape):
        """Return the size of the last axis of ``shape`` as an int or None.

        Works for tuples/lists of ints as well as TensorShapes whose entries
        are either ints/None or TF1-style ``Dimension`` objects (``.value``).
        """
        last = shape[-1]
        return getattr(last, "value", last)

    def build(self, input_shape):
        """Create the gate weights.

        Args:
            input_shape: Pair ``(h_shape, x_shape)``, in the same order as the
                inputs passed to ``call``.

        Raises:
            ValueError: If the last axes of the two inputs differ in size, or
                if that size is not statically known.
        """
        h_shape, x_shape = input_shape
        h_dim = self._last_dim(h_shape)
        x_dim = self._last_dim(x_shape)
        if x_dim != h_dim:
            raise ValueError(
                "Both inputs to GRUGate must have equal size in last axis!")
        if h_dim is None:
            # The square weight matrices need a concrete size; fail with a
            # clear message instead of an opaque `int(None)` TypeError.
            raise ValueError(
                "The last axis of the inputs to GRUGate must have a "
                "statically known size!")

        dim = int(h_dim)

        # Input-side (W) and state-side (U) matrices for the reset gate (r),
        # the update gate (z) and the candidate state (h). No initializer is
        # passed, so Keras' default for `add_weight` applies. The creation
        # order is kept stable so the layer's weight list order is unchanged.
        self._w_r = self.add_weight(shape=(dim, dim))
        self._w_z = self.add_weight(shape=(dim, dim))
        self._w_h = self.add_weight(shape=(dim, dim))

        self._u_r = self.add_weight(shape=(dim, dim))
        self._u_z = self.add_weight(shape=(dim, dim))
        self._u_h = self.add_weight(shape=(dim, dim))

        def bias_initializer(shape, dtype=None):
            # Constant fill with `init_bias`; fall back to the layer's dtype
            # if the caller does not supply one.
            if dtype is None:
                dtype = self.dtype
            return tf.fill(shape, tf.cast(self._init_bias, dtype=dtype))

        self._bias_z = self.add_weight(
            shape=(dim, ), initializer=bias_initializer)

    def call(self, inputs, **kwargs):
        """Gate ``x`` into ``h`` and return the new state.

        Args:
            inputs: Pair ``(h, x)``; the internal state comes first.
            **kwargs: Accepted for Keras compatibility; unused.

        Returns:
            ``(1 - z) * h + z * h_cand`` as defined in the class docstring.
        """
        # Pass in internal state first.
        h, x = inputs

        r = tf.nn.sigmoid(
            tf.tensordot(x, self._w_r, axes=1)
            + tf.tensordot(h, self._u_r, axes=1))

        z = tf.nn.sigmoid(
            tf.tensordot(x, self._w_z, axes=1)
            + tf.tensordot(h, self._u_z, axes=1)
            - self._bias_z)

        h_cand = tf.nn.tanh(
            tf.tensordot(x, self._w_h, axes=1)
            + tf.tensordot(h * r, self._u_h, axes=1))

        return (1 - z) * h + z * h_cand

    def get_config(self):
        """Return the layer config, including ``init_bias``.

        The base Keras ``Layer.get_config`` does not know about the extra
        ``init_bias`` constructor argument. Overriding it lets
        ``GRUGate.from_config(layer.get_config())`` rebuild an equivalent
        layer and keeps model serialization working.
        """
        config = super().get_config()
        config["init_bias"] = self._init_bias
        return config
|