[RLlib] Add DTTorchModel (#27872)

This commit is contained in:
Charles Sun 2022-08-16 18:18:29 -07:00 committed by GitHub
parent 87ce8480ff
commit 61880591e9
4 changed files with 846 additions and 0 deletions

rllib/BUILD

@@ -914,6 +914,13 @@ py_test(
srcs = ["algorithms/dt/tests/test_segmentation_buffer.py"]
)
py_test(
name = "test_dt_model",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
srcs = ["algorithms/dt/tests/test_dt_model.py"]
)
# ES
py_test(
name = "test_es",

rllib/algorithms/dt/dt_torch_model.py

@@ -0,0 +1,242 @@
import gym
from gym.spaces import Discrete, Box
import numpy as np
from typing import Dict, List
from ray.rllib import SampleBatch
from ray.rllib.models import ModelV2
from ray.rllib.models.torch.mingpt import (
GPTConfig,
GPT,
)
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import (
ModelConfigDict,
TensorType,
)
torch, nn = try_import_torch()
class DTTorchModel(TorchModelV2, nn.Module):
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
):
TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name
)
nn.Module.__init__(self)
self.obs_dim = num_outputs
if isinstance(action_space, Discrete):
self.action_dim = action_space.n
elif isinstance(action_space, Box):
self.action_dim = np.prod(action_space.shape)
else:
raise NotImplementedError
# Common model parameters
self.embed_dim = self.model_config["embed_dim"]
self.max_seq_len = self.model_config["max_seq_len"]
self.max_ep_len = self.model_config["max_ep_len"]
self.block_size = self.model_config["max_seq_len"] * 3
# Build all the nn modules
self.transformer = self.build_transformer()
self.position_encoder = self.build_position_encoder()
self.action_encoder = self.build_action_encoder()
self.obs_encoder = self.build_obs_encoder()
self.return_encoder = self.build_return_encoder()
self.action_head = self.build_action_head()
self.obs_head = self.build_obs_head()
self.return_head = self.build_return_head()
# Update view requirement
# NOTE: See DTTorchPolicy.action_distribution_fn for an explanation of
# why the ViewRequirements are like this
self.view_requirements = {
SampleBatch.OBS: ViewRequirement(
space=obs_space, shift=f"-{self.max_seq_len-1}:0"
),
SampleBatch.ACTIONS: ViewRequirement(
space=action_space, shift=f"-{self.max_seq_len-1}:-1"
),
SampleBatch.REWARDS: ViewRequirement(shift=-1),
SampleBatch.T: ViewRequirement(shift=f"-{self.max_seq_len-2}:0"),
SampleBatch.RETURNS_TO_GO: ViewRequirement(
shift=f"-{self.max_seq_len-1}:-1"
),
}
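# Illustrative example (assuming max_seq_len == 4): at timestep t the sampler
# hands the model a window of
#   OBS:           o_{t-3} ... o_t      (shift "-3:0",  length 4)
#   ACTIONS:       a_{t-3} ... a_{t-1}  (shift "-3:-1", length 3)
#   RETURNS_TO_GO: R_{t-3} ... R_{t-1}  (shift "-3:-1", length 3)
#   T:             t-2 ... t            (shift "-2:0",  length 3)
# i.e. everything needed to predict a_t without seeing the current action.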
def build_transformer(self):
# build the model
gpt_config = GPTConfig(
block_size=self.block_size,
n_layer=self.model_config["num_layers"],
n_head=self.model_config["num_heads"],
n_embed=self.embed_dim,
embed_pdrop=self.model_config["embed_pdrop"],
resid_pdrop=self.model_config["resid_pdrop"],
attn_pdrop=self.model_config["attn_pdrop"],
)
gpt = GPT(gpt_config)
return gpt
def build_position_encoder(self):
return nn.Embedding(self.max_ep_len, self.embed_dim)
def build_action_encoder(self):
if isinstance(self.action_space, Discrete):
return nn.Embedding(self.action_dim, self.embed_dim)
elif isinstance(self.action_space, Box):
return nn.Linear(self.action_dim, self.embed_dim)
else:
raise NotImplementedError
def build_obs_encoder(self):
return nn.Linear(self.obs_dim, self.embed_dim)
def build_return_encoder(self):
return nn.Linear(1, self.embed_dim)
def build_action_head(self):
return nn.Linear(self.embed_dim, self.action_dim)
def build_obs_head(self):
if not self.model_config["use_obs_output"]:
return None
return nn.Linear(self.embed_dim, self.obs_dim)
def build_return_head(self):
if not self.model_config["use_return_output"]:
return None
return nn.Linear(self.embed_dim, 1)
@override(ModelV2)
def forward(
self,
input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType,
) -> (TensorType, List[TensorType]):
# True No-op forward method.
# TODO: Support image observation inputs
return input_dict["obs"], state
def get_prediction(
self,
model_out: TensorType,
input_dict: SampleBatch,
return_attentions: bool = False,
) -> Dict[str, TensorType]:
"""Computes the output of a forward pass of the decision transformer.
Args:
model_out: output observation tensor from the base model, [B, T, obs_dim].
input_dict: a SampleBatch containing
RETURNS_TO_GO: [B, T (or T + 1), 1] of returns to go values.
ACTIONS: [B, T, action_dim] of actions.
T: [B, T] of timesteps.
ATTENTION_MASKS: [B, T] of attention masks.
return_attentions: Whether to return the attention tensors from the
transformer or not.
Returns:
A dictionary with keys and values:
ACTIONS: [B, T, action_dim] of predicted actions.
if return_attentions:
"attentions": List of attention tensors from the transformer.
if model_config["use_obs_output"]:
OBS: [B, T, obs_dim] of predicted observations.
if model_config["use_return_output"]:
RETURNS_TO_GO: [B, T, 1] of predicted returns to go.
"""
B, T, *_ = model_out.shape
obs_embeds = self.obs_encoder(model_out)
actions_embeds = self.action_encoder(input_dict[SampleBatch.ACTIONS])
# Note: rtg might have an extra element at the end for targets
# During training rtg will have T + 1 for its time dimension to get the
# rtg regression target. During evaluation/inference rtg will have T for
# its time dimension as we don't need to call get_targets.
returns_embeds = self.return_encoder(
input_dict[SampleBatch.RETURNS_TO_GO][:, :T, :]
)
timestep_embeds = self.position_encoder(input_dict[SampleBatch.T])
obs_embeds = obs_embeds + timestep_embeds
actions_embeds = actions_embeds + timestep_embeds
returns_embeds = returns_embeds + timestep_embeds
# This makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
stacked_inputs = torch.stack(
(returns_embeds, obs_embeds, actions_embeds), dim=2
).reshape(B, 3 * T, self.embed_dim)
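# Illustrative: for T == 2 the time axis of stacked_inputs reads
# [R_1, s_1, a_1, R_2, s_2, a_2], so token 3*t is a return, token 3*t + 1 an
# observation, and token 3*t + 2 an action.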
attention_masks = input_dict[SampleBatch.ATTENTION_MASKS]
stacked_attention_masks = torch.stack(
(attention_masks, attention_masks, attention_masks), dim=2
).reshape(B, 3 * T)
# forward the transformer model
output_embeds = self.transformer(
stacked_inputs,
attention_masks=stacked_attention_masks,
return_attentions=return_attentions,
)
outputs = {}
if return_attentions:
output_embeds, attentions = output_embeds
outputs["attentions"] = attentions
# compute output heads
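# With the (return, obs, action) token order above, output_embeds[:, 0::3] are
# the return-token outputs, [:, 1::3] the observation-token outputs, and
# [:, 2::3] the action-token outputs; actions are read off the observation
# tokens, i.e. after conditioning on (..., R_t, s_t).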
outputs[SampleBatch.ACTIONS] = self.action_head(output_embeds[:, 1::3, :])
if self.model_config["use_obs_output"]:
outputs[SampleBatch.OBS] = self.obs_head(output_embeds[:, 0::3, :])
if self.model_config["use_return_output"]:
outputs[SampleBatch.RETURNS_TO_GO] = self.return_head(
output_embeds[:, 2::3, :]
)
return outputs
def get_targets(
self, model_out: TensorType, input_dict: SampleBatch
) -> Dict[str, TensorType]:
"""Compute the target predictions for a given input_dict.
Args:
model_out: output observation tensor from the base model, [B, T, obs_dim].
input_dict: a SampleBatch containing
RETURNS_TO_GO: [B, T + 1, 1] of returns to go values.
ACTIONS: [B, T, action_dim] of actions.
T: [B, T] of timesteps.
ATTENTION_MASKS: [B, T] of attention masks.
Returns:
A dictionary with keys and values:
ACTIONS: [B, T, action_dim] of target actions.
if model_config["use_obs_output"]:
OBS: [B, T, obs_dim] of target observations.
if model_config["use_return_output"]:
RETURNS_TO_GO: [B, T, 1] of target returns to go.
"""
targets = {SampleBatch.ACTIONS: input_dict[SampleBatch.ACTIONS].detach()}
if self.model_config["use_obs_output"]:
targets[SampleBatch.OBS] = model_out.detach()
if self.model_config["use_return_output"]:
targets[SampleBatch.RETURNS_TO_GO] = input_dict[SampleBatch.RETURNS_TO_GO][
:, 1:, :
].detach()
return targets

rllib/algorithms/dt/tests/test_dt_model.py

@@ -0,0 +1,301 @@
import unittest
import gym
import numpy as np
import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.algorithms.dt.dt_torch_model import DTTorchModel
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
def _assert_outputs_equal(outputs):
for i in range(1, len(outputs)):
for key in outputs[0].keys():
assert np.allclose(
outputs[0][key], outputs[i][key]
), "outputs are different but they shouldn't be."
def _assert_outputs_not_equal(outputs):
for i in range(1, len(outputs)):
for key in outputs[0].keys():
assert not np.allclose(
outputs[0][key], outputs[i][key]
), "some outputs are the same but they shouldn't be."
def _generate_input_dict(B, T, obs_space, action_space):
"""Generate input_dict that has completely fake values."""
# generate deterministic inputs
# obs
obs = np.arange(B * T * obs_space.shape[0], dtype=np.float32).reshape(
(B, T, obs_space.shape[0])
)
# actions
if isinstance(action_space, gym.spaces.Box):
act = np.arange(B * T * action_space.shape[0], dtype=np.float32).reshape(
(B, T, action_space.shape[0])
)
else:
act = np.mod(np.arange(B * T, dtype=np.int32).reshape((B, T)), action_space.n)
# returns to go
rtg = np.arange(B * (T + 1), dtype=np.float32).reshape((B, T + 1, 1))
# timesteps
timesteps = np.stack([np.arange(T, dtype=np.int32) for _ in range(B)], axis=0)
# attention mask
mask = np.ones((B, T), dtype=np.float32)
input_dict = SampleBatch(
{
SampleBatch.OBS: obs,
SampleBatch.ACTIONS: act,
SampleBatch.RETURNS_TO_GO: rtg,
SampleBatch.T: timesteps,
SampleBatch.ATTENTION_MASKS: mask,
}
)
input_dict = convert_to_torch_tensor(input_dict)
return input_dict
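# Illustrative shapes: _generate_input_dict(2, 3, Box(-1.0, 1.0, shape=(4,)),
# Box(-1.0, 1.0, shape=(2,))) returns torch tensors with OBS (2, 3, 4),
# ACTIONS (2, 3, 2), RETURNS_TO_GO (2, 4, 1), T (2, 3), and
# ATTENTION_MASKS (2, 3).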
class TestDTModel(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_torch_model_init(self):
"""Test models are initialized properly"""
model_config = {
"embed_dim": 32,
"num_layers": 2,
"max_seq_len": 4,
"max_ep_len": 10,
"num_heads": 2,
"embed_pdrop": 0.1,
"resid_pdrop": 0.1,
"attn_pdrop": 0.1,
"use_obs_output": False,
"use_return_output": False,
}
num_outputs = 2
observation_space = gym.spaces.Box(-1.0, 1.0, shape=(num_outputs,))
action_dim = 5
action_spaces = [
gym.spaces.Box(-1.0, 1.0, shape=(action_dim,)),
gym.spaces.Discrete(action_dim),
]
B, T = 3, 4
for action_space in action_spaces:
# Generate input dict.
input_dict = _generate_input_dict(B, T, observation_space, action_space)
# Do random initialization a few times and make sure outputs are different
outputs = []
for _ in range(10):
model = DTTorchModel(
observation_space,
action_space,
num_outputs,
model_config,
"model",
)
# so dropout is not in effect
model.eval()
model_out, _ = model(input_dict)
output = model.get_prediction(model_out, input_dict)
outputs.append(convert_to_numpy(output))
_assert_outputs_not_equal(outputs)
# Initialize once and make sure dropout is working
model = DTTorchModel(
observation_space,
action_space,
num_outputs,
model_config,
"model",
)
# Dropout should make outputs different in training mode
model.train()
outputs = []
for _ in range(10):
model_out, _ = model(input_dict)
output = model.get_prediction(model_out, input_dict)
outputs.append(convert_to_numpy(output))
_assert_outputs_not_equal(outputs)
# Dropout should make outputs the same in eval mode
model.eval()
outputs = []
for _ in range(10):
model_out, _ = model(input_dict)
output = model.get_prediction(model_out, input_dict)
outputs.append(convert_to_numpy(output))
_assert_outputs_equal(outputs)
def test_torch_model_prediction_target(self):
"""Test the get_prediction and get_targets function."""
model_config = {
"embed_dim": 16,
"num_layers": 3,
"max_seq_len": 3,
"max_ep_len": 9,
"num_heads": 1,
"embed_pdrop": 0.2,
"resid_pdrop": 0.2,
"attn_pdrop": 0.2,
"use_obs_output": True,
"use_return_output": True,
}
num_outputs = 5
observation_space = gym.spaces.Box(-1.0, 1.0, shape=(num_outputs,))
action_dim = 2
action_spaces = [
gym.spaces.Box(-1.0, 1.0, shape=(action_dim,)),
gym.spaces.Discrete(action_dim),
]
B, T = 2, 3
for action_space in action_spaces:
# Generate input dict.
input_dict = _generate_input_dict(B, T, observation_space, action_space)
# Make model and forward pass.
model = DTTorchModel(
observation_space,
action_space,
num_outputs,
model_config,
"model",
)
model_out, _ = model(input_dict)
preds = model.get_prediction(model_out, input_dict)
target = model.get_targets(model_out, input_dict)
preds = convert_to_numpy(preds)
target = convert_to_numpy(target)
# Test the content and shape of output and target
if isinstance(action_space, gym.spaces.Box):
# test preds shape
self.assertEqual(preds[SampleBatch.ACTIONS].shape, (B, T, action_dim))
# test target shape and content
self.assertEqual(target[SampleBatch.ACTIONS].shape, (B, T, action_dim))
assert np.allclose(
target[SampleBatch.ACTIONS],
input_dict[SampleBatch.ACTIONS],
)
else:
# test preds shape
self.assertEqual(preds[SampleBatch.ACTIONS].shape, (B, T, action_dim))
# test target shape and content
self.assertEqual(target[SampleBatch.ACTIONS].shape, (B, T))
assert np.allclose(
target[SampleBatch.ACTIONS],
input_dict[SampleBatch.ACTIONS],
)
# test preds shape
self.assertEqual(preds[SampleBatch.OBS].shape, (B, T, num_outputs))
# test target shape and content
self.assertEqual(target[SampleBatch.OBS].shape, (B, T, num_outputs))
assert np.allclose(
target[SampleBatch.OBS],
input_dict[SampleBatch.OBS],
)
# test preds shape
self.assertEqual(preds[SampleBatch.RETURNS_TO_GO].shape, (B, T, 1))
# test target shape and content
self.assertEqual(target[SampleBatch.RETURNS_TO_GO].shape, (B, T, 1))
assert np.allclose(
target[SampleBatch.RETURNS_TO_GO],
input_dict[SampleBatch.RETURNS_TO_GO][:, 1:, :],
)
def test_causal_masking(self):
"""Test that the transformer model' causal masking works."""
model_config = {
"embed_dim": 16,
"num_layers": 2,
"max_seq_len": 4,
"max_ep_len": 10,
"num_heads": 2,
"embed_pdrop": 0,
"resid_pdrop": 0,
"attn_pdrop": 0,
"use_obs_output": True,
"use_return_output": True,
}
observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
action_space = gym.spaces.Box(-1.0, 1.0, shape=(2,))
B = 2
T = model_config["max_seq_len"]
# Generate input dict.
input_dict = _generate_input_dict(B, T, observation_space, action_space)
# make model and forward with attention
model = DTTorchModel(
observation_space,
action_space,
4,
model_config,
"model",
)
model_out, _ = model(input_dict)
preds = model.get_prediction(model_out, input_dict, return_attentions=True)
preds = convert_to_numpy(preds)
# test properties of attentions
attentions = preds["attentions"]
self.assertEqual(
len(attentions),
model_config["num_layers"],
"there should as many attention tensors as layers.",
)
# used to select the causal padded element of each attention tensor
select_mask = np.triu(np.ones((3 * T, 3 * T), dtype=bool), k=1)
select_mask = np.tile(select_mask, (B, model_config["num_heads"], 1, 1))
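# Illustrative: with T == 4 this marks, in each 12x12 attention map, every
# (query, key) entry where the key lies strictly in the future, i.e. exactly
# the entries the causal mask should have driven to zero after the softmax.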
for attention in attentions:
# check shape
self.assertEqual(
attention.shape, (B, model_config["num_heads"], T * 3, T * 3)
)
# check the upper triangular masking
assert np.allclose(
attention[select_mask], 0.0
), "masked elements should be zero."
# check that the non-masked elements have non 0 scores
# Note: it is very unlikely that randomly initialized weights will make
# one of the scores be 0, as these scores are probabilities.
assert not np.any(
np.isclose(attention[np.logical_not(select_mask)], 0.0)
), "non masked elements should be nonzero."
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

rllib/models/torch/mingpt.py

@@ -0,0 +1,296 @@
# LICENSE: MIT
"""
Adapted from https://github.com/karpathy/minGPT
Full definition of a GPT Language Model, all of it in this single file.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers
/models/gpt2/modeling_gpt2.py
"""
import math
from dataclasses import dataclass
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
@dataclass
class GPTConfig:
# block size must be provided
block_size: int
# transformer config
n_layer: int = 12
n_head: int = 12
n_embed: int = 768
# dropout config
embed_pdrop: float = 0.1
resid_pdrop: float = 0.1
attn_pdrop: float = 0.1
class NewGELU(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT
repo (identical to OpenAI GPT).
Reference: Gaussian Error Linear Units (GELU) paper:
https://arxiv.org/abs/1606.08415
"""
def forward(self, x):
return (
0.5
* x
* (
1.0
+ torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
)
)
)
class CausalSelfAttention(nn.Module):
"""
Vanilla multi-head masked self-attention layer with a projection at the end.
It is possible to use torch.nn.MultiheadAttention here, but an explicit
implementation is included to show that there is nothing too scary about it.
"""
def __init__(self, config: GPTConfig):
super().__init__()
assert config.n_embed % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
# output projection
self.c_proj = nn.Linear(config.n_embed, config.n_embed)
# regularization
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
# causal mask to ensure that attention is only applied to the left
# in the input sequence
self.register_buffer(
"bias",
torch.tril(torch.ones(config.block_size, config.block_size)).view(
1, 1, config.block_size, config.block_size
),
)
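# Illustrative: for block_size == 3 this buffer is
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]
# (viewed as (1, 1, 3, 3)), so query position t can only attend to keys <= t.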
self.n_head = config.n_head
self.n_embed = config.n_embed
def forward(self, x, attention_masks=None):
# batch size, sequence length, embedding dimensionality (n_embed)
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head
# forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
# (B, nh, T, hs)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
# (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
# (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
# causal self-attention; Self-attend:
# (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
if attention_masks is not None:
att = att + attention_masks
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
# re-assemble all head outputs side by side
y = y.transpose(1, 2).contiguous().view(B, T, C)
# output projection
y = self.resid_dropout(self.c_proj(y))
return y, att
class Block(nn.Module):
"""an unassuming Transformer block"""
def __init__(self, config: GPTConfig):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embed)
self.attn = CausalSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.n_embed)
self.mlp = nn.ModuleDict(
dict(
c_fc=nn.Linear(config.n_embed, 4 * config.n_embed),
c_proj=nn.Linear(4 * config.n_embed, config.n_embed),
act=NewGELU(),
dropout=nn.Dropout(config.resid_pdrop),
)
)
m = self.mlp
# MLP forward
self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))
def forward(self, x, attention_masks=None):
x_att, att = self.attn(self.ln_1(x), attention_masks=attention_masks)
x = x + x_att
x = x + self.mlpf(self.ln_2(x))
return x, att
@DeveloperAPI
def configure_gpt_optimizer(
model: nn.Module,
learning_rate: float,
weight_decay: float,
betas: Tuple[float, float] = (0.9, 0.95),
**kwargs,
) -> torch.optim.Optimizer:
"""
This long function is unfortunately doing something very simple and is
being very defensive: We are separating out all parameters of the model
into two buckets: those that will experience weight decay for regularization
and those that won't (biases, and layernorm/embedding weights). We are then
returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience
# regularizing weight decay
decay = set()
no_decay = set()
whitelist_w_modules = (torch.nn.Linear,)
blacklist_w_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in model.named_modules():
for pn, p in m.named_parameters():
fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
# Note: because named_modules and named_parameters are recursive,
# we will see the same tensors p many times, but doing it this way
# lets us know which parent module any tensor p belongs to.
if pn.endswith("bias"):
# all biases will not be decayed
no_decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, whitelist_w_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, blacklist_w_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# validate that we considered every parameter
param_dict = {pn: p for pn, p in model.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert (
len(inter_params) == 0
), f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
assert len(param_dict.keys() - union_params) == 0, (
f"parameters {str(param_dict.keys() - union_params)} were not "
f"separated into either decay/no_decay set!"
)
# create the pytorch optimizer object
optim_groups = [
{
"params": [param_dict[pn] for pn in sorted(decay)],
"weight_decay": weight_decay,
},
{
"params": [param_dict[pn] for pn in sorted(no_decay)],
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **kwargs)
return optimizer
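# Example usage (illustrative; the hyperparameter values are placeholders):
#   gpt = GPT(GPTConfig(block_size=12, n_layer=2, n_head=2, n_embed=64))
#   optim = configure_gpt_optimizer(gpt, learning_rate=1e-4, weight_decay=1e-4)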
@DeveloperAPI
class GPT(nn.Module):
"""GPT Transformer Model"""
def __init__(self, config: GPTConfig):
super().__init__()
assert config.block_size is not None
self.block_size = config.block_size
self.transformer = nn.ModuleDict(
dict(
drop=nn.Dropout(config.embed_pdrop),
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f=nn.LayerNorm(config.n_embed),
)
)
# init all weights, and apply a special scaled init to the residual
# projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
def forward(self, input_embeds, attention_masks=None, return_attentions=False):
"""
input_embeds: [batch_size x seq_len x n_embed]
attention_masks: [batch_size x seq_len], 0 don't attend, 1 attend
"""
B, T, C = input_embeds.size()
assert T <= self.block_size, (
f"Cannot forward sequence of length {T}, "
f"block size is only {self.block_size}"
)
if attention_masks is not None:
_B, _T = attention_masks.size()
assert _B == B and _T == T
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_len]
# So we can broadcast to
# [batch_size, num_heads, from_seq_length, to_seq_length]
# This attention mask is simpler than the triangular causal
# masking used in OpenAI GPT; we just need to prepare the
# broadcast dimension here.
attention_masks = attention_masks[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend
# and 0.0 for masked positions, this operation will create a
# tensor which is 0.0 for positions we want to attend and -inf
# for masked positions. Since we are adding it to the raw scores
# before the softmax, this is effectively the same as removing
# these entirely.
attention_masks = attention_masks.to(dtype=input_embeds.dtype)
attention_masks = (1.0 - attention_masks) * -1e9
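# Illustrative: a mask row [1., 1., 0.] becomes additive logits [0., 0., -1e9],
# so after the softmax the padded key receives (effectively) zero weight.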
# forward the GPT model itself
x = self.transformer.drop(input_embeds)
atts = []
for block in self.transformer.h:
x, att = block(x, attention_masks=attention_masks)
atts.append(att)
x = self.transformer.ln_f(x)
if return_attentions:
return x, atts
else:
return x
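
A minimal, self-contained sketch (not part of this diff; the tiny hyperparameters are arbitrary placeholders, and the import path assumes this file lands at ray.rllib.models.torch.mingpt, as the DTTorchModel import above indicates) showing how the GPT pieces fit together:

import torch
from ray.rllib.models.torch.mingpt import GPT, GPTConfig, configure_gpt_optimizer

config = GPTConfig(block_size=12, n_layer=2, n_head=2, n_embed=32)
model = GPT(config)

B, T = 4, 12  # a batch of 4 sequences, each block_size tokens long
input_embeds = torch.randn(B, T, config.n_embed)
attention_masks = torch.ones(B, T)
attention_masks[:, -3:] = 0.0  # pretend the last 3 tokens of each sequence are padding

out, atts = model(input_embeds, attention_masks=attention_masks, return_attentions=True)
assert out.shape == (B, T, config.n_embed)
assert len(atts) == config.n_layer  # one attention tensor per Block

optim = configure_gpt_optimizer(model, learning_rate=1e-4, weight_decay=1e-4)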