import copy
from collections import OrderedDict
import gym
from typing import Union, Dict, List, Tuple

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

try:
    from dnc import DNC
except ModuleNotFoundError:
    print("dnc module not found. Did you forget to 'pip install dnc'?")
    raise

torch, nn = try_import_torch()


class DNCMemory(TorchModelV2, nn.Module):
    """Differentiable Neural Computer wrapper around ixaxaar's DNC
    implementation, see https://github.com/ixaxaar/pytorch-dnc

    Observations are passed through a linear layer plus the configured
    preprocessor, then through the DNC; two linear heads map the DNC
    output to action logits and to a scalar value estimate.

    The DNC itself is built lazily on the first ``forward()`` call, using
    the device of the incoming observations.
    """

    DEFAULT_CONFIG = {
        "dnc_model": DNC,
        # Number of controller hidden layers
        "num_hidden_layers": 1,
        # Number of weights per controller hidden layer
        "hidden_size": 64,
        # Number of LSTM units
        "num_layers": 1,
        # Number of read heads, i.e. how many addrs are read at once
        "read_heads": 4,
        # Number of memory cells in the controller
        "nr_cells": 32,
        # Size of each cell
        "cell_size": 16,
        # LSTM activation function
        "nonlinearity": "tanh",
        # Observation goes through this torch.nn.Module before
        # feeding to the DNC.
        # NOTE: this module object is created once, at class-definition
        # time; __init__ deep-copies it so instances do not share weights.
        "preprocessor": torch.nn.Sequential(
            torch.nn.Linear(64, 64), torch.nn.Tanh()),
        # Input size to the preprocessor
        "preprocessor_input_size": 64,
        # The output size of the preprocessor
        # and the input size of the dnc
        "preprocessor_output_size": 64,
    }

    # Keys of the DNC memory dict, in the order they are stored in the
    # flat RLlib state list (after ctrl_hidden[0], ctrl_hidden[1], read_vecs).
    MEMORY_KEYS = [
        "memory",
        "link_matrix",
        "precedence",
        "read_weights",
        "write_weights",
        "usage_vector",
    ]

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
            **custom_model_kwargs,
    ):
        """Build the preprocessor and the logit/value heads.

        Args:
            obs_space: Observation space; its flattened size is the input
                size of the first linear layer.
            action_space: Action space; only its flattened size is stored.
            num_outputs: Output size of the logit head.
            model_config: Forwarded unchanged to ``TorchModelV2``.
            name: Forwarded unchanged to ``TorchModelV2``.
            **custom_model_kwargs: Overrides for ``DEFAULT_CONFIG`` entries.

        Raises:
            AssertionError: If the resulting config has ``num_layers != 1``.
        """
        nn.Module.__init__(self)
        super(DNCMemory, self).__init__(obs_space, action_space, num_outputs,
                                        model_config, name)
        self.num_outputs = num_outputs
        self.obs_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_dim = gym.spaces.utils.flatdim(action_space)

        self.cfg = dict(self.DEFAULT_CONFIG, **custom_model_kwargs)
        assert (self.cfg["num_layers"] == 1
                ), "num_layers != 1 has not been implemented yet"
        # Value estimate of the most recent forward() call.
        self.cur_val = None

        preprocessor = self.cfg["preprocessor"]
        if preprocessor is self.DEFAULT_CONFIG["preprocessor"]:
            # The default preprocessor is a single class-level module.
            # Embedding it directly would make every model built with the
            # default config share (and jointly train) the same weights,
            # so each instance gets its own copy. A caller-supplied module
            # is used as-is, leaving any sharing decision to the caller.
            preprocessor = copy.deepcopy(preprocessor)

        self.preprocessor = torch.nn.Sequential(
            torch.nn.Linear(self.obs_dim,
                            self.cfg["preprocessor_input_size"]),
            preprocessor,
        )

        self.logit_branch = SlimFC(
            in_size=self.cfg["hidden_size"],
            out_size=self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )

        self.value_branch = SlimFC(
            in_size=self.cfg["hidden_size"],
            out_size=1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )

        # Created lazily by build_dnc() on the first forward() call.
        self.dnc: Union[None, DNC] = None

    def get_initial_state(self) -> List[TensorType]:
        """Return the zero-filled initial recurrent state.

        Returns:
            A list of 9 tensors: two controller hidden tensors, the read
            vectors, then one tensor per entry of ``MEMORY_KEYS`` (in that
            order).
        """
        ctrl_hidden = [
            torch.zeros(self.cfg["num_hidden_layers"],
                        self.cfg["hidden_size"]),
            torch.zeros(self.cfg["num_hidden_layers"],
                        self.cfg["hidden_size"]),
        ]
        m = self.cfg["nr_cells"]
        r = self.cfg["read_heads"]
        w = self.cfg["cell_size"]
        memory = [
            torch.zeros(m, w),  # memory
            torch.zeros(1, m, m),  # link_matrix
            torch.zeros(1, m),  # precedence
            torch.zeros(r, m),  # read_weights
            torch.zeros(1, m),  # write_weights
            torch.zeros(m),  # usage_vector
        ]

        read_vecs = torch.zeros(w * r)

        state = [*ctrl_hidden, read_vecs, *memory]
        assert len(state) == 9
        return state

    def value_function(self) -> TensorType:
        """Return the value estimates cached by the last ``forward()`` call.

        Raises:
            AssertionError: If ``forward()`` has not been called yet.
        """
        assert self.cur_val is not None, "must call forward() first"
        return self.cur_val

    def unpack_state(
            self,
            state: List[TensorType],
    ) -> Tuple[List[Tuple[TensorType, TensorType]], Dict[str, TensorType],
               TensorType]:
        """Given a list of tensors, reformat for self.dnc input

        Args:
            state: Flat list of 9 tensors in the order produced by
                ``pack_state``.

        Returns:
            ``(ctrl_hidden, memory_dict, read_vecs)`` where ``ctrl_hidden``
            is a one-element list holding a pair of tensors, and
            ``memory_dict`` maps ``MEMORY_KEYS`` to ``state[3:]``.

        Raises:
            AssertionError: If ``state`` does not hold exactly 9 tensors.
        """
        assert len(state) == 9, "Failed to verify unpacked state"
        # Swap the first two axes of the controller tensors (presumably
        # batch-first -> layer-first for the DNC -- TODO confirm).
        ctrl_hidden: List[Tuple[TensorType, TensorType]] = [(
            state[0].permute(1, 0, 2).contiguous(),
            state[1].permute(1, 0, 2).contiguous(),
        )]
        read_vecs: TensorType = state[2]
        memory: List[TensorType] = state[3:]
        memory_dict: OrderedDict[str, TensorType] = OrderedDict(
            zip(self.MEMORY_KEYS, memory))

        return ctrl_hidden, memory_dict, read_vecs

    def pack_state(
            self,
            ctrl_hidden: List[Tuple[TensorType, TensorType]],
            memory_dict: Dict[str, TensorType],
            read_vecs: TensorType,
    ) -> List[TensorType]:
        """Given the dnc output, pack it into a list of tensors
        for rllib state. Order is ctrl_hidden, read_vecs, memory_dict

        Args:
            ctrl_hidden: Controller state; only ``ctrl_hidden[0]`` (a pair
                of tensors) is used.
            memory_dict: Memory tensors; values are appended in the dict's
                iteration order.
            read_vecs: Read vectors tensor.

        Returns:
            Flat list of 9 tensors.

        Raises:
            AssertionError: If the packed list does not have the expected
                length at any stage.
        """
        state = []
        # Swap the first two axes back (inverse of unpack_state).
        ctrl_hidden = [
            ctrl_hidden[0][0].permute(1, 0, 2),
            ctrl_hidden[0][1].permute(1, 0, 2),
        ]
        state += ctrl_hidden
        assert len(state) == 2, "Failed to verify packed state"
        state.append(read_vecs)
        assert len(state) == 3, "Failed to verify packed state"
        state += memory_dict.values()
        assert len(state) == 9, "Failed to verify packed state"
        return state

    def validate_unpack(self, dnc_output, unpacked_state):
        """Ensure the unpacked state shapes match the DNC output

        Args:
            dnc_output: ``(ctrl_hidden, memory_dict, read_vecs)`` as
                returned by the DNC; used as the reference shapes.
            unpacked_state: ``(ctrl_hidden, memory_dict, read_vecs)`` as
                returned by ``unpack_state``.

        Raises:
            AssertionError: On the first shape mismatch found.
        """
        s_ctrl_hidden, s_memory_dict, s_read_vecs = unpacked_state
        ctrl_hidden, memory_dict, read_vecs = dnc_output

        for i, layer in enumerate(ctrl_hidden):
            for j, expected in enumerate(layer):
                assert s_ctrl_hidden[i][j].shape == expected.shape, (
                    "Controller state mismatch: got "
                    f"{s_ctrl_hidden[i][j].shape} should be "
                    f"{ctrl_hidden[i][j].shape}")

        for k, expected in memory_dict.items():
            assert s_memory_dict[k].shape == expected.shape, (
                "Memory state mismatch at key "
                f"{k}: got {s_memory_dict[k].shape} should be "
                f"{memory_dict[k].shape}")

        assert s_read_vecs.shape == read_vecs.shape, (
            "Read state mismatch: got "
            f"{s_read_vecs.shape} should be "
            f"{read_vecs.shape}")

    def build_dnc(self, device_idx: Union[int, None]) -> None:
        """Instantiate ``cfg["dnc_model"]`` and store it as ``self.dnc``.

        Args:
            device_idx: Passed to the DNC constructor as ``gpu_id``.
        """
        self.dnc = self.cfg["dnc_model"](
            input_size=self.cfg["preprocessor_output_size"],
            hidden_size=self.cfg["hidden_size"],
            num_layers=self.cfg["num_layers"],
            num_hidden_layers=self.cfg["num_hidden_layers"],
            read_heads=self.cfg["read_heads"],
            cell_size=self.cfg["cell_size"],
            nr_cells=self.cfg["nr_cells"],
            nonlinearity=self.cfg["nonlinearity"],
            gpu_id=device_idx,
        )

    def forward(
            self,
            input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Run preprocessor, DNC and output heads on a batch of sequences.

        Args:
            input_dict: Must contain ``"obs_flat"``, whose first dimension
                is treated as ``B * T``.
            state: Flat recurrent state list; ignored on the very first
                call, where the DNC is built and fed ``(None, None, None)``.
            seq_lens: Its length is taken as the batch size ``B``.

        Returns:
            ``(logits, packed_state)``: logits with first dimension
            ``B * T`` and the new recurrent state as a flat list.
        """
        flat = input_dict["obs_flat"]
        # Batch and Time
        # Forward expects outputs as [B, T, logits]
        B = len(seq_lens)
        T = flat.shape[0] // B

        # Deconstruct batch into batch and time dimensions: [B, T, feats]
        flat = torch.reshape(flat, [-1, T] + list(flat.shape[1:]))

        # First run
        if self.dnc is None:
            # -1 is passed when the input tensor has no device index.
            gpu_id = flat.device.index if flat.device.index is not None else -1
            self.build_dnc(gpu_id)
            hidden = (None, None, None)
        else:
            hidden = self.unpack_state(state)  # type: ignore

        # Run thru preprocessor before DNC
        z = self.preprocessor(flat.reshape(B * T, self.obs_dim))
        z = z.reshape(B, T, self.cfg["preprocessor_output_size"])
        output, hidden = self.dnc(z, hidden)
        packed_state = self.pack_state(*hidden)

        # Compute action/value from output
        logits = self.logit_branch(output.view(B * T, -1))
        values = self.value_branch(output.view(B * T, -1))

        # Cache for value_function().
        self.cur_val = values.squeeze(1)

        return logits, packed_state