# ray/rllib/examples/models/neural_computer.py

from collections import OrderedDict
import gym
from typing import Union, Dict, List, Tuple
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
try:
from dnc import DNC
except ModuleNotFoundError:
print("dnc module not found. Did you forget to 'pip install dnc'?")
raise
torch, nn = try_import_torch()
class DNCMemory(TorchModelV2, nn.Module):
    """Differentiable Neural Computer wrapper around ixaxaar's DNC implementation,
    see https://github.com/ixaxaar/pytorch-dnc"""

    # Default hyperparameters; any key can be overridden by passing it in
    # ``custom_model_kwargs`` to ``__init__``.
    DEFAULT_CONFIG = {
        # Class implementing the DNC itself (constructed lazily in build_dnc).
        "dnc_model": DNC,
        # Number of controller hidden layers
        "num_hidden_layers": 1,
        # Number of weights per controller hidden layer
        "hidden_size": 64,
        # Number of LSTM units
        "num_layers": 1,
        # Number of read heads, i.e. how many addrs are read at once
        "read_heads": 4,
        # Number of memory cells in the controller
        "nr_cells": 32,
        # Size of each cell
        "cell_size": 16,
        # LSTM activation function
        "nonlinearity": "tanh",
        # Observation goes through this torch.nn.Module before
        # feeding to the DNC
        # NOTE(review): this Sequential is instantiated once at class-definition
        # time, so every DNCMemory instance that does not override
        # "preprocessor" shares the SAME module (and parameters) — confirm
        # this is intended.
        "preprocessor": torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Tanh()),
        # Input size to the preprocessor
        "preprocessor_input_size": 64,
        # The output size of the preprocessor
        # and the input size of the dnc
        "preprocessor_output_size": 64,
    }

    # Canonical ordering of the DNC memory dict when it is flattened into the
    # RLlib state list; pack_state/unpack_state both rely on this order.
    MEMORY_KEYS = [
        "memory",
        "link_matrix",
        "precedence",
        "read_weights",
        "write_weights",
        "usage_vector",
    ]

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        **custom_model_kwargs,
    ):
        """Builds the preprocessor and the logit/value heads.

        The DNC itself is NOT built here: it is created lazily on the first
        forward() call (see build_dnc), once the device of the incoming
        observations is known.

        Args:
            obs_space: Observation space of the environment.
            action_space: Action space of the environment.
            num_outputs: Size of the policy-logits output.
            model_config: RLlib model config dict (passed to the base class).
            name: Name of this model.
            **custom_model_kwargs: Per-key overrides merged over
                DEFAULT_CONFIG.
        """
        nn.Module.__init__(self)
        super(DNCMemory, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )
        self.num_outputs = num_outputs
        # Flattened sizes of the (possibly nested) gym spaces.
        self.obs_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_dim = gym.spaces.utils.flatdim(action_space)
        # User-supplied kwargs take precedence over the defaults.
        self.cfg = dict(self.DEFAULT_CONFIG, **custom_model_kwargs)
        assert (
            self.cfg["num_layers"] == 1
        ), "num_layers != 1 has not been implemented yet"
        # Cached value-function output of the most recent forward() call.
        self.cur_val = None
        # Project the flat observation to the preprocessor's expected input
        # width, then apply the user-configured preprocessor module.
        self.preprocessor = torch.nn.Sequential(
            torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_input_size"]),
            self.cfg["preprocessor"],
        )
        # Policy-logits head on top of the DNC controller output.
        self.logit_branch = SlimFC(
            in_size=self.cfg["hidden_size"],
            out_size=self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        # Scalar value head on top of the DNC controller output.
        self.value_branch = SlimFC(
            in_size=self.cfg["hidden_size"],
            out_size=1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        # Built lazily in forward() so it lands on the right device.
        self.dnc: Union[None, DNC] = None

    def get_initial_state(self) -> List[TensorType]:
        """Returns the zero-filled initial recurrent state.

        The state is a flat list of 9 tensors (RLlib prepends the batch
        dimension itself): 2 controller hidden tensors (h, c), 1 read-vector
        tensor, and the 6 memory tensors in MEMORY_KEYS order.
        """
        # Controller (LSTM) hidden and cell states.
        ctrl_hidden = [
            torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
            torch.zeros(self.cfg["num_hidden_layers"], self.cfg["hidden_size"]),
        ]
        m = self.cfg["nr_cells"]
        r = self.cfg["read_heads"]
        w = self.cfg["cell_size"]
        # Memory tensors, in the exact order of MEMORY_KEYS.
        memory = [
            torch.zeros(m, w),  # memory
            torch.zeros(1, m, m),  # link_matrix
            torch.zeros(1, m),  # precedence
            torch.zeros(r, m),  # read_weights
            torch.zeros(1, m),  # write_weights
            torch.zeros(m),  # usage_vector
        ]
        # One read vector of size cell_size per read head, flattened.
        read_vecs = torch.zeros(w * r)
        state = [*ctrl_hidden, read_vecs, *memory]
        assert len(state) == 9
        return state

    def value_function(self) -> TensorType:
        """Returns the value-head output cached by the last forward() call."""
        assert self.cur_val is not None, "must call forward() first"
        return self.cur_val

    def unpack_state(
        self,
        state: List[TensorType],
    ) -> Tuple[List[Tuple[TensorType, TensorType]], Dict[str, TensorType], TensorType]:
        """Given a list of tensors, reformat for self.dnc input"""
        assert len(state) == 9, "Failed to verify unpacked state"
        # RLlib delivers state as [B, layers, hidden]; permute back to the
        # layer-major layout that pack_state flattened from.
        ctrl_hidden: List[Tuple[TensorType, TensorType]] = [
            (
                state[0].permute(1, 0, 2).contiguous(),
                state[1].permute(1, 0, 2).contiguous(),
            )
        ]
        read_vecs: TensorType = state[2]
        memory: List[TensorType] = state[3:]
        # Rebuild the named memory dict using the canonical key order.
        memory_dict: OrderedDict[str, TensorType] = OrderedDict(
            zip(self.MEMORY_KEYS, memory)
        )
        return ctrl_hidden, memory_dict, read_vecs

    def pack_state(
        self,
        ctrl_hidden: List[Tuple[TensorType, TensorType]],
        memory_dict: Dict[str, TensorType],
        read_vecs: TensorType,
    ) -> List[TensorType]:
        """Given the dnc output, pack it into a list of tensors
        for rllib state. Order is ctrl_hidden, read_vecs, memory_dict"""
        state = []
        # Inverse of unpack_state: move the batch dim back to the front.
        ctrl_hidden = [
            ctrl_hidden[0][0].permute(1, 0, 2),
            ctrl_hidden[0][1].permute(1, 0, 2),
        ]
        state += ctrl_hidden
        assert len(state) == 2, "Failed to verify packed state"
        state.append(read_vecs)
        assert len(state) == 3, "Failed to verify packed state"
        # Flatten the memory dict; relies on its insertion order matching
        # MEMORY_KEYS (guaranteed for the OrderedDict built in unpack_state).
        state += memory_dict.values()
        assert len(state) == 9, "Failed to verify packed state"
        return state

    def validate_unpack(self, dnc_output, unpacked_state):
        """Ensure the unpacked state shapes match the DNC output"""
        s_ctrl_hidden, s_memory_dict, s_read_vecs = unpacked_state
        ctrl_hidden, memory_dict, read_vecs = dnc_output
        # Compare controller hidden/cell shapes layer by layer.
        for i in range(len(ctrl_hidden)):
            for j in range(len(ctrl_hidden[i])):
                assert s_ctrl_hidden[i][j].shape == ctrl_hidden[i][j].shape, (
                    "Controller state mismatch: got "
                    f"{s_ctrl_hidden[i][j].shape} should be "
                    f"{ctrl_hidden[i][j].shape}"
                )
        # Compare each named memory tensor's shape.
        for k in memory_dict:
            assert s_memory_dict[k].shape == memory_dict[k].shape, (
                "Memory state mismatch at key "
                f"{k}: got {s_memory_dict[k].shape} should be "
                f"{memory_dict[k].shape}"
            )
        assert s_read_vecs.shape == read_vecs.shape, (
            "Read state mismatch: got "
            f"{s_read_vecs.shape} should be "
            f"{read_vecs.shape}"
        )

    def build_dnc(self, device_idx: Union[int, None]) -> None:
        """Instantiates the configured DNC model.

        Args:
            device_idx: CUDA device index to place the DNC on; forward()
                passes -1 when the input has no CUDA index (presumably
                meaning CPU per the dnc library's convention — TODO confirm).
        """
        self.dnc = self.cfg["dnc_model"](
            input_size=self.cfg["preprocessor_output_size"],
            hidden_size=self.cfg["hidden_size"],
            num_layers=self.cfg["num_layers"],
            num_hidden_layers=self.cfg["num_hidden_layers"],
            read_heads=self.cfg["read_heads"],
            cell_size=self.cfg["cell_size"],
            nr_cells=self.cfg["nr_cells"],
            nonlinearity=self.cfg["nonlinearity"],
            gpu_id=device_idx,
        )

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Runs one DNC step over a (possibly time-padded) batch.

        Args:
            input_dict: RLlib input dict; only "obs_flat" is read here.
            state: Flat 9-tensor recurrent state (see get_initial_state).
            seq_lens: Per-sequence lengths; its length defines the batch size.

        Returns:
            Tuple of (logits reshaped by the caller convention, packed
            9-tensor next state). Also caches the value output in
            self.cur_val for value_function().
        """
        flat = input_dict["obs_flat"]
        # Batch and Time
        # Forward expects outputs as [B, T, logits]
        B = len(seq_lens)
        # Max time steps per sequence, inferred from the padded batch size.
        T = flat.shape[0] // B
        # Deconstruct batch into batch and time dimensions: [B, T, feats]
        flat = torch.reshape(flat, [-1, T] + list(flat.shape[1:]))
        # First run
        if self.dnc is None:
            # Build the DNC on the same device as the observations;
            # -1 when there is no CUDA device index.
            gpu_id = flat.device.index if flat.device.index is not None else -1
            self.build_dnc(gpu_id)
            # (None, None, None) tells the DNC to initialize its own state.
            hidden = (None, None, None)
        else:
            hidden = self.unpack_state(state)  # type: ignore
        # Run thru preprocessor before DNC
        z = self.preprocessor(flat.reshape(B * T, self.obs_dim))
        z = z.reshape(B, T, self.cfg["preprocessor_output_size"])
        output, hidden = self.dnc(z, hidden)
        # Flatten the DNC's (ctrl_hidden, memory_dict, read_vecs) back into
        # the 9-tensor list RLlib expects.
        packed_state = self.pack_state(*hidden)
        # Compute action/value from output
        logits = self.logit_branch(output.view(B * T, -1))
        values = self.value_branch(output.view(B * T, -1))
        # Squeeze the trailing singleton dim: value head outputs [B*T, 1].
        self.cur_val = values.squeeze(1)
        return logits, packed_state