"""
|
|
[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar,
|
|
Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017.
|
|
https://arxiv.org/pdf/1706.03762.pdf
|
|
"""
|
|
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

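
# `tf` is None here if TensorFlow is not installed, hence the fallback
# to a plain `object` base class below.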
class MultiHeadAttention(tf.keras.layers.Layer if tf else object):
    """A multi-head attention layer described in [1]."""

    def __init__(self, out_dim, num_heads, head_dim, **kwargs):
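        """Initializes a MultiHeadAttention layer.

        Args:
            out_dim: Output dimension of the final linear projection.
            num_heads: Number of parallel attention heads.
            head_dim: Size of each head's query/key/value projections.
        """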
        super().__init__(**kwargs)

        # No bias or non-linearity.
        self._num_heads = num_heads
        self._head_dim = head_dim
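        # A single fused Dense layer computes queries, keys and values for
        # all heads in one matmul (output size 3 * num_heads * head_dim);
        # the result is split and reshaped in `call`.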
        self._qkv_layer = tf.keras.layers.Dense(
            3 * num_heads * head_dim, use_bias=False)
        self._linear_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(out_dim, use_bias=False))

    def call(self, inputs):
        L = tf.shape(inputs)[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        qkv = self._qkv_layer(inputs)

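        # Split the fused projection into equal Q-, K- and V-chunks along
        # the last axis; each chunk has shape [B, L, H * D].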
        queries, keys, values = tf.split(qkv, 3, -1)
        queries = queries[:, -L:]  # only query based on the segment
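
        # Fold the heads out of the channel axis:
        # [B, L, H * D] -> [B, L, H, D].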
        queries = tf.reshape(queries, [-1, L, H, D])
        keys = tf.reshape(keys, [-1, L, H, D])
        values = tf.reshape(values, [-1, L, H, D])

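        # Scaled dot-product scores: contract the per-head feature axis (d),
        # leaving one [L x L] attention matrix per batch item and head
        # ("bijh"), then scale by 1/sqrt(D) as in [1].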
        score = tf.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / D**0.5

        # Causal mask of the same length as the sequence: query position i
        # may only attend to key positions j <= i.
        mask = tf.sequence_mask(tf.range(1, L + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]

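        # mask is 1 where attention is allowed and 0 elsewhere; the second
        # term drives disallowed scores to -1e30, so their softmax weights
        # become effectively zero.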
        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = tf.nn.softmax(masked_score, axis=2)

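        # Attention-weighted sum of the values per head, then re-concatenate
        # the heads into a single [B, L, H * D] tensor before the final
        # output projection.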
        out = tf.einsum("bijh,bjhd->bihd", wmat, values)
        shape = tf.concat([tf.shape(out)[:2], [H * D]], axis=0)
        out = tf.reshape(out, shape)
        return self._linear_layer(out)
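

# A minimal usage sketch (assuming TensorFlow is installed); inputs are
# [batch, seq_len, feature] and the output keeps the batch/seq dims:
#
#     layer = MultiHeadAttention(out_dim=64, num_heads=4, head_dim=32)
#     x = tf.random.uniform([2, 10, 64])
#     y = layer(x)  # -> shape [2, 10, 64]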