import numpy as np

from ray.rllib.utils.framework import try_import_torch

# try_import_torch() returns the (torch, torch.nn) modules, or harmless
# stand-ins when PyTorch is not installed, so importing this module never
# hard-fails on a missing torch dependency.
torch, nn = try_import_torch()


class VDNMixer(nn.Module):
    """VDN mixer: the joint Q-value is the sum of the per-agent Q-values."""

    def __init__(self):
        super(VDNMixer, self).__init__()

    def forward(self, agent_qs, batch):
        # Sum the per-agent Q-values over the agent dimension (dim 2).
        # The `batch` argument is unused by this mixer.
        return torch.sum(agent_qs, dim=2, keepdim=True)
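
# A doctest-style sketch of what VDNMixer computes (an illustration added
# here, not part of the upstream file; it assumes PyTorch is installed and
# the shapes are made up):
#
#   >>> qs = torch.ones(4, 8, 3)         # [B=4, T=8, n_agents=3]
#   >>> VDNMixer()(qs, None).shape       # Q_tot sums over the agent dim
#   torch.Size([4, 8, 1])
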
class QMixer(nn.Module):
    """QMIX mixer: a state-conditioned, monotonic mixing network."""

    def __init__(self, n_agents, state_shape, mixing_embed_dim):
        super(QMixer, self).__init__()

        self.n_agents = n_agents
        self.embed_dim = mixing_embed_dim
        self.state_dim = int(np.prod(state_shape))

        # Hypernetworks that produce the mixing weights from the global
        # state; forward() takes torch.abs() of their outputs so every
        # mixing weight is non-negative, which enforces the QMIX
        # monotonicity constraint dQ_tot/dQ_i >= 0.
        self.hyper_w_1 = nn.Linear(self.state_dim,
                                   self.embed_dim * self.n_agents)
        self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)

        # State-dependent bias for the hidden layer.
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layer.
        self.V = nn.Sequential(
            nn.Linear(self.state_dim, self.embed_dim), nn.ReLU(),
            nn.Linear(self.embed_dim, 1))

    def forward(self, agent_qs, states):
        """Forward pass for the mixer.

        Args:
            agent_qs: Tensor of shape [B, T, n_agents] with each agent's
                Q-value for its chosen action.
            states: Tensor of shape [B, T, state_dim].

        Returns:
            Tensor of shape [B, T, 1] holding the mixed Q_tot values.
        """
        bs = agent_qs.size(0)
        # Flatten batch and time dims so all mixing is one batched matmul.
        states = states.reshape(-1, self.state_dim)
        agent_qs = agent_qs.view(-1, 1, self.n_agents)
        # First layer: abs() keeps the state-generated weights non-negative.
        w1 = torch.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = nn.functional.elu(torch.bmm(agent_qs, w1) + b1)
        # Second layer.
        w_final = torch.abs(self.hyper_w_final(states))
        w_final = w_final.view(-1, self.embed_dim, 1)
        # State-dependent bias.
        v = self.V(states).view(-1, 1, 1)
        # Compute final output.
        y = torch.bmm(hidden, w_final) + v
        # Reshape back to [B, T, 1] and return.
        q_tot = y.view(bs, -1, 1)
        return q_tot
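

if __name__ == "__main__":
    # Hedged smoke test, added for illustration (not part of the upstream
    # RLlib file). Assumes PyTorch is installed; all shapes below are made
    # up: 2 episodes (B), 5 timesteps (T), 3 agents, a 10-dim global state,
    # and a 32-dim mixing embedding.
    B, T, n_agents, state_dim = 2, 5, 3, 10
    agent_qs = torch.rand(B, T, n_agents)
    states = torch.rand(B, T, state_dim)

    # VDN: Q_tot is just the sum of the per-agent Q-values.
    vdn = VDNMixer()
    print(vdn(agent_qs, None).shape)  # torch.Size([2, 5, 1])

    # QMIX: Q_tot is a monotonic, state-conditioned mixture.
    qmix = QMixer(n_agents, state_shape=(state_dim, ), mixing_embed_dim=32)
    print(qmix(agent_qs, states).shape)  # torch.Size([2, 5, 1])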