[RLlib] Add DTTorchModel (#27872)

parent 87ce8480ff
commit 61880591e9
4 changed files with 846 additions and 0 deletions
@@ -914,6 +914,13 @@ py_test(
    srcs = ["algorithms/dt/tests/test_segmentation_buffer.py"]
)

py_test(
    name = "test_dt_model",
    tags = ["team:rllib", "algorithms_dir"],
    size = "medium",
    srcs = ["algorithms/dt/tests/test_dt_model.py"]
)

# ES
py_test(
    name = "test_es",
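The py_test target above registers the new test with RLlib's test suite. As a hedged aside (not part of the commit), the test can also be invoked directly through pytest, mirroring the __main__ block of the test file itself; the relative path below assumes a standard Ray checkout:

# Hypothetical direct invocation; mirrors the __main__ block of the new test file.
import sys
import pytest

sys.exit(pytest.main(["-v", "rllib/algorithms/dt/tests/test_dt_model.py"]))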
rllib/algorithms/dt/dt_torch_model.py (new file, 242 lines)
@@ -0,0 +1,242 @@
import gym
from gym.spaces import Discrete, Box
import numpy as np
from typing import Dict, List

from ray.rllib import SampleBatch
from ray.rllib.models import ModelV2
from ray.rllib.models.torch.mingpt import (
    GPTConfig,
    GPT,
)
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import (
    ModelConfigDict,
    TensorType,
)

torch, nn = try_import_torch()


class DTTorchModel(TorchModelV2, nn.Module):
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        self.obs_dim = num_outputs

        if isinstance(action_space, Discrete):
            self.action_dim = action_space.n
        elif isinstance(action_space, Box):
            self.action_dim = np.product(action_space.shape)
        else:
            raise NotImplementedError

        # Common model parameters
        self.embed_dim = self.model_config["embed_dim"]
        self.max_seq_len = self.model_config["max_seq_len"]
        self.max_ep_len = self.model_config["max_ep_len"]
        self.block_size = self.model_config["max_seq_len"] * 3

        # Build all the nn modules
        self.transformer = self.build_transformer()
        self.position_encoder = self.build_position_encoder()
        self.action_encoder = self.build_action_encoder()
        self.obs_encoder = self.build_obs_encoder()
        self.return_encoder = self.build_return_encoder()
        self.action_head = self.build_action_head()
        self.obs_head = self.build_obs_head()
        self.return_head = self.build_return_head()

        # Update view requirements.
        # NOTE: See DTTorchPolicy.action_distribution_fn for an explanation of
        # why the ViewRequirements are like this.
        self.view_requirements = {
            SampleBatch.OBS: ViewRequirement(
                space=obs_space, shift=f"-{self.max_seq_len-1}:0"
            ),
            SampleBatch.ACTIONS: ViewRequirement(
                space=action_space, shift=f"-{self.max_seq_len-1}:-1"
            ),
            SampleBatch.REWARDS: ViewRequirement(shift=-1),
            SampleBatch.T: ViewRequirement(shift=f"-{self.max_seq_len-2}:0"),
            SampleBatch.RETURNS_TO_GO: ViewRequirement(
                shift=f"-{self.max_seq_len-1}:-1"
            ),
        }

    def build_transformer(self):
        # Build the model.
        gpt_config = GPTConfig(
            block_size=self.block_size,
            n_layer=self.model_config["num_layers"],
            n_head=self.model_config["num_heads"],
            n_embed=self.embed_dim,
            embed_pdrop=self.model_config["embed_pdrop"],
            resid_pdrop=self.model_config["resid_pdrop"],
            attn_pdrop=self.model_config["attn_pdrop"],
        )
        gpt = GPT(gpt_config)
        return gpt

    def build_position_encoder(self):
        return nn.Embedding(self.max_ep_len, self.embed_dim)

    def build_action_encoder(self):
        if isinstance(self.action_space, Discrete):
            return nn.Embedding(self.action_dim, self.embed_dim)
        elif isinstance(self.action_space, Box):
            return nn.Linear(self.action_dim, self.embed_dim)
        else:
            raise NotImplementedError

    def build_obs_encoder(self):
        return nn.Linear(self.obs_dim, self.embed_dim)

    def build_return_encoder(self):
        return nn.Linear(1, self.embed_dim)

    def build_action_head(self):
        return nn.Linear(self.embed_dim, self.action_dim)

    def build_obs_head(self):
        if not self.model_config["use_obs_output"]:
            return None
        return nn.Linear(self.embed_dim, self.obs_dim)

    def build_return_head(self):
        if not self.model_config["use_return_output"]:
            return None
        return nn.Linear(self.embed_dim, 1)

    @override(ModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        # True no-op forward method.
        # TODO: Support image observation inputs.
        return input_dict["obs"], state

    def get_prediction(
        self,
        model_out: TensorType,
        input_dict: SampleBatch,
        return_attentions: bool = False,
    ) -> Dict[str, TensorType]:
        """Computes the output of a forward pass of the decision transformer.

        Args:
            model_out: output observation tensor from the base model, [B, T, obs_dim].
            input_dict: a SampleBatch containing
                RETURNS_TO_GO: [B, T (or T + 1), 1] of returns to go values.
                ACTIONS: [B, T, action_dim] of actions.
                T: [B, T] of timesteps.
                ATTENTION_MASKS: [B, T] of attention masks.
            return_attentions: Whether to return the attention tensors from the
                transformer or not.

        Returns:
            A dictionary with keys and values:
                ACTIONS: [B, T, action_dim] of predicted actions.
                if return_attentions:
                    "attentions": List of attention tensors from the transformer.
                if model_config["use_obs_output"]:
                    OBS: [B, T, obs_dim] of predicted observations.
                if model_config["use_return_output"]:
                    RETURNS_TO_GO: [B, T, 1] of predicted returns to go.
        """
        B, T, *_ = model_out.shape

        obs_embeds = self.obs_encoder(model_out)
        actions_embeds = self.action_encoder(input_dict[SampleBatch.ACTIONS])
        # Note: rtg might have an extra element at the end for targets.
        # During training rtg will have T + 1 for its time dimension to get the
        # rtg regression target. During evaluation/inference rtg will have T for
        # its time dimension as we don't need to call get_targets.
        returns_embeds = self.return_encoder(
            input_dict[SampleBatch.RETURNS_TO_GO][:, :T, :]
        )
        timestep_embeds = self.position_encoder(input_dict[SampleBatch.T])

        obs_embeds = obs_embeds + timestep_embeds
        actions_embeds = actions_embeds + timestep_embeds
        returns_embeds = returns_embeds + timestep_embeds

        # This makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        stacked_inputs = torch.stack(
            (returns_embeds, obs_embeds, actions_embeds), dim=2
        ).reshape(B, 3 * T, self.embed_dim)

        attention_masks = input_dict[SampleBatch.ATTENTION_MASKS]
        stacked_attention_masks = torch.stack(
            (attention_masks, attention_masks, attention_masks), dim=2
        ).reshape(B, 3 * T)

        # Forward the transformer model.
        output_embeds = self.transformer(
            stacked_inputs,
            attention_masks=stacked_attention_masks,
            return_attentions=return_attentions,
        )

        outputs = {}
        if return_attentions:
            output_embeds, attentions = output_embeds
            outputs["attentions"] = attentions

        # Compute output heads.
        outputs[SampleBatch.ACTIONS] = self.action_head(output_embeds[:, 1::3, :])
        if self.model_config["use_obs_output"]:
            outputs[SampleBatch.OBS] = self.obs_head(output_embeds[:, 0::3, :])
        if self.model_config["use_return_output"]:
            outputs[SampleBatch.RETURNS_TO_GO] = self.return_head(
                output_embeds[:, 2::3, :]
            )

        return outputs

    def get_targets(
        self, model_out: TensorType, input_dict: SampleBatch
    ) -> Dict[str, TensorType]:
        """Compute the target predictions for a given input_dict.

        Args:
            model_out: output observation tensor from the base model, [B, T, obs_dim].
            input_dict: a SampleBatch containing
                RETURNS_TO_GO: [B, T + 1, 1] of returns to go values.
                ACTIONS: [B, T, action_dim] of actions.
                T: [B, T] of timesteps.
                ATTENTION_MASKS: [B, T] of attention masks.

        Returns:
            A dictionary with keys and values:
                ACTIONS: [B, T, action_dim] of target actions.
                if model_config["use_obs_output"]:
                    OBS: [B, T, obs_dim] of target observations.
                if model_config["use_return_output"]:
                    RETURNS_TO_GO: [B, T, 1] of target returns to go.
        """
        targets = {SampleBatch.ACTIONS: input_dict[SampleBatch.ACTIONS].detach()}
        if self.model_config["use_obs_output"]:
            targets[SampleBatch.OBS] = model_out.detach()
        if self.model_config["use_return_output"]:
            targets[SampleBatch.RETURNS_TO_GO] = input_dict[SampleBatch.RETURNS_TO_GO][
                :, 1:, :
            ].detach()

        return targets
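As a minimal usage sketch (not part of the commit), mirroring how the unit test below exercises DTTorchModel: construct the model with a small config, switch to eval mode to disable dropout, and then run the forward pass plus the prediction/target helpers. The spaces and hyperparameters here are illustrative only, and the input SampleBatch construction is left to the _generate_input_dict helper in the test file.

import gym
from ray.rllib.algorithms.dt.dt_torch_model import DTTorchModel

model_config = {
    "embed_dim": 32,
    "num_layers": 2,
    "max_seq_len": 4,
    "max_ep_len": 10,
    "num_heads": 2,
    "embed_pdrop": 0.1,
    "resid_pdrop": 0.1,
    "attn_pdrop": 0.1,
    "use_obs_output": False,
    "use_return_output": False,
}
obs_space = gym.spaces.Box(-1.0, 1.0, shape=(2,))
action_space = gym.spaces.Box(-1.0, 1.0, shape=(5,))

# num_outputs matches the flattened observation size here.
model = DTTorchModel(obs_space, action_space, 2, model_config, "dt_model")
model.eval()  # disable dropout for deterministic outputs

# input_dict is a SampleBatch with OBS, ACTIONS, RETURNS_TO_GO, T and
# ATTENTION_MASKS keys, shaped as described in the get_prediction docstring
# (see _generate_input_dict in the test file below):
# model_out, _ = model(input_dict)
# preds = model.get_prediction(model_out, input_dict)   # predicted actions
# targets = model.get_targets(model_out, input_dict)    # regression targets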
rllib/algorithms/dt/tests/test_dt_model.py (new file, 301 lines)
@@ -0,0 +1,301 @@
import unittest

import gym
import numpy as np

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.algorithms.dt.dt_torch_model import DTTorchModel

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


def _assert_outputs_equal(outputs):
    for i in range(1, len(outputs)):
        for key in outputs[0].keys():
            assert np.allclose(
                outputs[0][key], outputs[i][key]
            ), "outputs are different but they shouldn't be."


def _assert_outputs_not_equal(outputs):
    for i in range(1, len(outputs)):
        for key in outputs[0].keys():
            assert not np.allclose(
                outputs[0][key], outputs[i][key]
            ), "some outputs are the same but they shouldn't be."


def _generate_input_dict(B, T, obs_space, action_space):
    """Generate an input_dict that has completely fake values."""
    # Generate deterministic inputs.
    # obs
    obs = np.arange(B * T * obs_space.shape[0], dtype=np.float32).reshape(
        (B, T, obs_space.shape[0])
    )
    # actions
    if isinstance(action_space, gym.spaces.Box):
        act = np.arange(B * T * action_space.shape[0], dtype=np.float32).reshape(
            (B, T, action_space.shape[0])
        )
    else:
        act = np.mod(np.arange(B * T, dtype=np.int32).reshape((B, T)), action_space.n)
    # returns to go
    rtg = np.arange(B * (T + 1), dtype=np.float32).reshape((B, T + 1, 1))
    # timesteps
    timesteps = np.stack([np.arange(T, dtype=np.int32) for _ in range(B)], axis=0)
    # attention mask
    mask = np.ones((B, T), dtype=np.float32)

    input_dict = SampleBatch(
        {
            SampleBatch.OBS: obs,
            SampleBatch.ACTIONS: act,
            SampleBatch.RETURNS_TO_GO: rtg,
            SampleBatch.T: timesteps,
            SampleBatch.ATTENTION_MASKS: mask,
        }
    )
    input_dict = convert_to_torch_tensor(input_dict)
    return input_dict


class TestDTModel(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_torch_model_init(self):
        """Test that models are initialized properly."""
        model_config = {
            "embed_dim": 32,
            "num_layers": 2,
            "max_seq_len": 4,
            "max_ep_len": 10,
            "num_heads": 2,
            "embed_pdrop": 0.1,
            "resid_pdrop": 0.1,
            "attn_pdrop": 0.1,
            "use_obs_output": False,
            "use_return_output": False,
        }

        num_outputs = 2
        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(num_outputs,))

        action_dim = 5
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(action_dim,)),
            gym.spaces.Discrete(action_dim),
        ]

        B, T = 3, 4

        for action_space in action_spaces:
            # Generate input dict.
            input_dict = _generate_input_dict(B, T, observation_space, action_space)

            # Do random initialization a few times and make sure outputs are different.
            outputs = []
            for _ in range(10):
                model = DTTorchModel(
                    observation_space,
                    action_space,
                    num_outputs,
                    model_config,
                    "model",
                )
                # Use eval mode so dropout is not in effect.
                model.eval()
                model_out, _ = model(input_dict)
                output = model.get_prediction(model_out, input_dict)
                outputs.append(convert_to_numpy(output))
            _assert_outputs_not_equal(outputs)

            # Initialize once and make sure dropout is working.
            model = DTTorchModel(
                observation_space,
                action_space,
                num_outputs,
                model_config,
                "model",
            )

            # Dropout should make outputs different in training mode.
            model.train()
            outputs = []
            for _ in range(10):
                model_out, _ = model(input_dict)
                output = model.get_prediction(model_out, input_dict)
                outputs.append(convert_to_numpy(output))
            _assert_outputs_not_equal(outputs)

            # Dropout should make outputs the same in eval mode.
            model.eval()
            outputs = []
            for _ in range(10):
                model_out, _ = model(input_dict)
                output = model.get_prediction(model_out, input_dict)
                outputs.append(convert_to_numpy(output))
            _assert_outputs_equal(outputs)

    def test_torch_model_prediction_target(self):
        """Test the get_prediction and get_targets functions."""
        model_config = {
            "embed_dim": 16,
            "num_layers": 3,
            "max_seq_len": 3,
            "max_ep_len": 9,
            "num_heads": 1,
            "embed_pdrop": 0.2,
            "resid_pdrop": 0.2,
            "attn_pdrop": 0.2,
            "use_obs_output": True,
            "use_return_output": True,
        }

        num_outputs = 5
        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(num_outputs,))

        action_dim = 2
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(action_dim,)),
            gym.spaces.Discrete(action_dim),
        ]

        B, T = 2, 3

        for action_space in action_spaces:
            # Generate input dict.
            input_dict = _generate_input_dict(B, T, observation_space, action_space)

            # Make model and do a forward pass.
            model = DTTorchModel(
                observation_space,
                action_space,
                num_outputs,
                model_config,
                "model",
            )
            model_out, _ = model(input_dict)
            preds = model.get_prediction(model_out, input_dict)
            target = model.get_targets(model_out, input_dict)

            preds = convert_to_numpy(preds)
            target = convert_to_numpy(target)

            # Test the content and shape of output and target.
            if isinstance(action_space, gym.spaces.Box):
                # Test preds shape.
                self.assertEqual(preds[SampleBatch.ACTIONS].shape, (B, T, action_dim))
                # Test target shape and content.
                self.assertEqual(target[SampleBatch.ACTIONS].shape, (B, T, action_dim))
                assert np.allclose(
                    target[SampleBatch.ACTIONS],
                    input_dict[SampleBatch.ACTIONS],
                )
            else:
                # Test preds shape.
                self.assertEqual(preds[SampleBatch.ACTIONS].shape, (B, T, action_dim))
                # Test target shape and content.
                self.assertEqual(target[SampleBatch.ACTIONS].shape, (B, T))
                assert np.allclose(
                    target[SampleBatch.ACTIONS],
                    input_dict[SampleBatch.ACTIONS],
                )

            # Test preds shape.
            self.assertEqual(preds[SampleBatch.OBS].shape, (B, T, num_outputs))
            # Test target shape and content.
            self.assertEqual(target[SampleBatch.OBS].shape, (B, T, num_outputs))
            assert np.allclose(
                target[SampleBatch.OBS],
                input_dict[SampleBatch.OBS],
            )

            # Test preds shape.
            self.assertEqual(preds[SampleBatch.RETURNS_TO_GO].shape, (B, T, 1))
            # Test target shape and content.
            self.assertEqual(target[SampleBatch.RETURNS_TO_GO].shape, (B, T, 1))
            assert np.allclose(
                target[SampleBatch.RETURNS_TO_GO],
                input_dict[SampleBatch.RETURNS_TO_GO][:, 1:, :],
            )

    def test_causal_masking(self):
        """Test that the transformer model's causal masking works."""
        model_config = {
            "embed_dim": 16,
            "num_layers": 2,
            "max_seq_len": 4,
            "max_ep_len": 10,
            "num_heads": 2,
            "embed_pdrop": 0,
            "resid_pdrop": 0,
            "attn_pdrop": 0,
            "use_obs_output": True,
            "use_return_output": True,
        }

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
        action_space = gym.spaces.Box(-1.0, 1.0, shape=(2,))
        B = 2
        T = model_config["max_seq_len"]

        # Generate input dict.
        input_dict = _generate_input_dict(B, T, observation_space, action_space)

        # Make model and forward with attention.
        model = DTTorchModel(
            observation_space,
            action_space,
            4,
            model_config,
            "model",
        )
        model_out, _ = model(input_dict)
        preds = model.get_prediction(model_out, input_dict, return_attentions=True)
        preds = convert_to_numpy(preds)

        # Test properties of the attentions.
        attentions = preds["attentions"]
        self.assertEqual(
            len(attentions),
            model_config["num_layers"],
            "there should be as many attention tensors as layers.",
        )

        # Used to select the causally masked elements of each attention tensor.
        select_mask = np.triu(np.ones((3 * T, 3 * T), dtype=np.bool), k=1)
        select_mask = np.tile(select_mask, (B, model_config["num_heads"], 1, 1))

        for attention in attentions:
            # Check shape.
            self.assertEqual(
                attention.shape, (B, model_config["num_heads"], T * 3, T * 3)
            )
            # Check the upper triangular masking.
            assert np.allclose(
                attention[select_mask], 0.0
            ), "masked elements should be zero."
            # Check that the non-masked elements have non-zero scores.
            # Note: it is very unlikely that randomly initialized weights will make
            # one of the scores be 0, as these scores are probabilities.
            assert not np.any(
                np.isclose(attention[np.logical_not(select_mask)], 0.0)
            ), "non-masked elements should be nonzero."


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
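To make the token layout that get_prediction builds concrete: returns, observations, and actions are interleaved as (R_1, s_1, a_1, R_2, s_2, a_2, ...), so the transformer sees 3 * T tokens and the action head reads the state positions (index 1 of every triple). The following is a small standalone sketch (not part of the commit) of that index arithmetic, with the embedding dimension collapsed to 1 for readability.

import torch

T = 3
returns = torch.tensor([10.0, 11.0, 12.0]).reshape(1, T, 1)
obs = torch.tensor([20.0, 21.0, 22.0]).reshape(1, T, 1)
acts = torch.tensor([30.0, 31.0, 32.0]).reshape(1, T, 1)

# Same stacking as DTTorchModel.get_prediction, with embed_dim == 1.
stacked = torch.stack((returns, obs, acts), dim=2).reshape(1, 3 * T, 1)
print(stacked.flatten().tolist())
# [10.0, 20.0, 30.0, 11.0, 21.0, 31.0, 12.0, 22.0, 32.0]

# The output heads then slice every third token:
#   output_embeds[:, 0::3, :] -> return positions (R_t)
#   output_embeds[:, 1::3, :] -> state positions (s_t), fed to the action head
#   output_embeds[:, 2::3, :] -> action positions (a_t)
print(stacked[:, 1::3, :].flatten().tolist())  # [20.0, 21.0, 22.0]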
rllib/models/torch/mingpt.py (new file, 296 lines)
@@ -0,0 +1,296 @@
# LICENSE: MIT
"""
Adapted from https://github.com/karpathy/minGPT

Full definition of a GPT Language Model, all of it in this single file.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers
/models/gpt2/modeling_gpt2.py
"""

import math
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

from ray.rllib.utils.annotations import DeveloperAPI


@DeveloperAPI
@dataclass
class GPTConfig:
    # block size must be provided
    block_size: int

    # transformer config
    n_layer: int = 12
    n_head: int = 12
    n_embed: int = 768

    # dropout config
    embed_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    attn_pdrop: float = 0.1


class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT
    repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper:
    https://arxiv.org/abs/1606.08415
    """

    def forward(self, x):
        return (
            0.5
            * x
            * (
                1.0
                + torch.tanh(
                    math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
                )
            )
        )


class CausalSelfAttention(nn.Module):
    """
    Vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embed % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        # output projection
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        # regularization
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # causal mask to ensure that attention is only applied to the left
        # in the input sequence
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )
        self.n_head = config.n_head
        self.n_embed = config.n_embed

    def forward(self, x, attention_masks=None):
        # batch size, sequence length, embedding dimensionality (n_embed)
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head
        # forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
        # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # causal self-attention; Self-attend:
        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        if attention_masks is not None:
            att = att + attention_masks
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y, att


class Block(nn.Module):
    """an unassuming Transformer block"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embed)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embed)
        self.mlp = nn.ModuleDict(
            dict(
                c_fc=nn.Linear(config.n_embed, 4 * config.n_embed),
                c_proj=nn.Linear(4 * config.n_embed, config.n_embed),
                act=NewGELU(),
                dropout=nn.Dropout(config.resid_pdrop),
            )
        )
        m = self.mlp
        # MLP forward
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))

    def forward(self, x, attention_masks=None):
        x_att, att = self.attn(self.ln_1(x), attention_masks=attention_masks)
        x = x + x_att
        x = x + self.mlpf(self.ln_2(x))
        return x, att


@DeveloperAPI
def configure_gpt_optimizer(
    model: nn.Module,
    learning_rate: float,
    weight_decay: float,
    betas: Tuple[float, float] = (0.9, 0.95),
    **kwargs,
) -> torch.optim.Optimizer:
    """
    This long function is unfortunately doing something very simple and is
    being very defensive: We are separating out all parameters of the model
    into two buckets: those that will experience weight decay for regularization
    and those that won't (biases, and layernorm/embedding weights). We are then
    returning the PyTorch optimizer object.
    """

    # separate out all parameters to those that will and won't experience
    # regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_w_modules = (torch.nn.Linear,)
    blacklist_w_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
            # random note: because named_modules and named_parameters are
            # recursive we will see the same tensors p many many times. but
            # doing it this way allows us to know which parent module any
            # tensor p belongs to...
            if pn.endswith("bias"):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, whitelist_w_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, blacklist_w_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)

    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in model.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert (
        len(inter_params) == 0
    ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
    assert len(param_dict.keys() - union_params) == 0, (
        f"parameters {str(param_dict.keys() - union_params)} were not "
        f"separated into either decay/no_decay set!"
    )

    # create the pytorch optimizer object
    optim_groups = [
        {
            "params": [param_dict[pn] for pn in sorted(decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [param_dict[pn] for pn in sorted(no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **kwargs)
    return optimizer


@DeveloperAPI
class GPT(nn.Module):
    """GPT Transformer Model"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.block_size is not None
        self.block_size = config.block_size

        self.transformer = nn.ModuleDict(
            dict(
                drop=nn.Dropout(config.embed_pdrop),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embed),
            )
        )

        # init all weights, and apply a special scaled init to the residual
        # projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_embeds, attention_masks=None, return_attentions=False):
        """
        input_embeds: [batch_size x seq_len x n_embed]
        attention_masks: [batch_size x seq_len], 0 don't attend, 1 attend
        """
        B, T, C = input_embeds.size()
        assert T <= self.block_size, (
            f"Cannot forward sequence of length {T}, "
            f"block size is only {self.block_size}"
        )

        if attention_masks is not None:
            _B, _T = attention_masks.size()
            assert _B == B and _T == T
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_len]
            # So we can broadcast to
            # [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular
            # masking of causal attention used in OpenAI GPT, we just need
            # to prepare the broadcast dimension here.
            attention_masks = attention_masks[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend
            # and 0.0 for masked positions, this operation will create a
            # tensor which is 0.0 for positions we want to attend and -inf
            # for masked positions. Since we are adding it to the raw scores
            # before the softmax, this is effectively the same as removing
            # these entirely.
            attention_masks = attention_masks.to(dtype=input_embeds.dtype)
            attention_masks = (1.0 - attention_masks) * -1e9

        # forward the GPT model itself
        x = self.transformer.drop(input_embeds)

        atts = []
        for block in self.transformer.h:
            x, att = block(x, attention_masks=attention_masks)
            atts.append(att)
        x = self.transformer.ln_f(x)

        if return_attentions:
            return x, atts
        else:
            return x
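A minimal sketch (not part of the commit) of how the pieces above fit together: build a GPTConfig, instantiate GPT, wire it to the decay/no_decay optimizer helper, and run a forward pass over embedded tokens. The shapes and hyperparameters are illustrative only.

import torch
from ray.rllib.models.torch.mingpt import GPT, GPTConfig, configure_gpt_optimizer

config = GPTConfig(block_size=12, n_layer=2, n_head=2, n_embed=32)
model = GPT(config)
optim = configure_gpt_optimizer(model, learning_rate=1e-4, weight_decay=1e-2)

# GPT consumes already-embedded tokens, [batch, seq_len, n_embed].
x = torch.randn(4, 12, 32)
mask = torch.ones(4, 12)  # 1 = attend, 0 = don't attend

out = model(x, attention_masks=mask)                        # [4, 12, 32]
out, atts = model(x, return_attentions=True)                # atts: n_layer tensors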