import numpy as np

from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class VDNMixer(nn.Module):
    """Mixer that combines per-agent Q-values by summing them."""

    def __init__(self):
        super().__init__()

    def forward(self, agent_qs, batch):
        """Sum the agent Q-values along dimension 2.

        Args:
            agent_qs: Tensor of per-agent Q-values; the agent axis is
                dimension 2 (it is the axis that gets summed).
            batch: Unused here; accepted so the call signature matches
                ``QMixer.forward(agent_qs, states)``.

        Returns:
            Tensor with dimension 2 reduced to size 1 (``keepdim=True``).
        """
        return torch.sum(agent_qs, dim=2, keepdim=True)


class QMixer(nn.Module):
    """State-conditioned mixing network for per-agent Q-values.

    Linear "hypernetwork" layers map the global state to the weights and
    biases of a small two-layer mixing network. The generated weights are
    passed through ``abs`` so they are non-negative, which (together with
    the monotonic ELU activation) makes the mixed value non-decreasing in
    every agent's Q-value.
    """

    def __init__(self, n_agents, state_shape, mixing_embed_dim):
        """Build the hypernetworks.

        Args:
            n_agents: Number of agents whose Q-values are mixed.
            state_shape: Shape of a single state; it is flattened to
                ``prod(state_shape)`` features.
            mixing_embed_dim: Width of the mixing network's hidden layer.
        """
        super().__init__()

        self.n_agents = n_agents
        self.embed_dim = mixing_embed_dim
        self.state_dim = int(np.prod(state_shape))

        # Hypernetworks producing the mixing weights of both layers.
        self.hyper_w_1 = nn.Linear(
            self.state_dim, self.embed_dim * self.n_agents
        )
        self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(
            nn.Linear(self.state_dim, self.embed_dim),
            nn.ReLU(),
            nn.Linear(self.embed_dim, 1),
        )

    def forward(self, agent_qs, states):
        """Forward pass for the mixer.

        Args:
            agent_qs: Tensor of shape [B, T, n_agents] holding one Q-value
                per agent (e.g. the Q-value of each agent's chosen action).
            states: Tensor of shape [B, T, state_dim] (any trailing layout
                that flattens to ``state_dim`` features per step works).

        Returns:
            Tensor of shape [B, T, 1] with the mixed total Q-value.
        """
        bs = agent_qs.size(0)
        states = states.reshape(-1, self.state_dim)
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # permute or slice, are accepted just like `states` above.
        agent_qs = agent_qs.reshape(-1, 1, self.n_agents)

        # First layer: abs() keeps the generated weights non-negative.
        w1 = torch.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = nn.functional.elu(torch.bmm(agent_qs, w1) + b1)

        # Second layer
        w_final = torch.abs(self.hyper_w_final(states))
        w_final = w_final.view(-1, self.embed_dim, 1)

        # State-dependent bias
        v = self.V(states).view(-1, 1, 1)

        # Compute final output
        y = torch.bmm(hidden, w_final) + v

        # Restore the leading batch dimension: [B * T, 1, 1] -> [B, T, 1].
        q_tot = y.view(bs, -1, 1)
        return q_tot