[rllib] Qmix off by 1 in double Q calculation (#5731)
* Qmix fix: the current version of double Q learning is incorrect; it selects actions at timestep t instead of t+1 when computing the t+1 Q value.
* Allow extra obs dict keys
* Move Q-value-computing replay code to its own function
* Run the autoformatter
* Use better terms in comments ("policy" network instead of "live" network)
parent 8903bcd0c3
commit 3131e1742d
1 changed file with 39 additions and 30 deletions
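For the first bullet of the commit message, here is a minimal sketch of the corrected double Q target computation. The tensor names and sizes (policy_q_tp1, target_q_tp1, B, n_agents, n_actions) are illustrative stand-ins, not identifiers from the patch; the point is that the argmax is taken over the policy network's Q-values at timestep t+1, and only the evaluation of that chosen action uses the target network.

import torch as th

# Hypothetical Q-value tensors for timestep t+1, shaped [B, n_agents, n_actions].
B, n_agents, n_actions = 4, 2, 5
policy_q_tp1 = th.randn(B, n_agents, n_actions)  # "policy" network run on next_obs
target_q_tp1 = th.randn(B, n_agents, n_actions)  # "target" network run on next_obs

# Double Q learning: select a* with the policy network at t+1 ...
best_actions = policy_q_tp1.max(dim=-1, keepdim=True)[1]                  # [B, n_agents, 1]
# ... then estimate Q(s_{t+1}, a*) with the target network.
target_max_qvals = th.gather(target_q_tp1, -1, best_actions).squeeze(-1)  # [B, n_agents]

# The pre-fix code instead took the argmax over the timestep-t Q-values
# (the same ones used for chosen_action_qvals), which is the off-by-one
# this PR fixes.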
@@ -60,46 +60,39 @@ class QMixLoss(nn.Module):
             next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
         """
 
         B, T = obs.size(0), obs.size(1)
 
         # Calculate estimated Q-Values
-        mac_out = []
-        h = [
-            s.expand([B, self.n_agents, -1])
-            for s in self.model.get_initial_state()
-        ]
-        for t in range(T):
-            q, h = _mac(self.model, obs[:, t], h)
-            mac_out.append(q)
-        mac_out = th.stack(mac_out, dim=1)  # Concat over time
+        mac_out = _unroll_mac(self.model, obs)
 
         # Pick the Q-Values for the actions taken -> [B * n_agents, T]
         chosen_action_qvals = th.gather(
             mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)
 
         # Calculate the Q-Values necessary for the target
-        target_mac_out = []
-        target_h = [
-            s.expand([B, self.n_agents, -1])
-            for s in self.target_model.get_initial_state()
-        ]
-        for t in range(T):
-            target_q, target_h = _mac(self.target_model, next_obs[:, t],
-                                      target_h)
-            target_mac_out.append(target_q)
-        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time
+        target_mac_out = _unroll_mac(self.target_model, next_obs)
 
-        # Mask out unavailable actions
-        ignore_action = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
-        target_mac_out[ignore_action] = -np.inf
+        # Mask out unavailable actions for the t+1 step
+        ignore_action_tp1 = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
+        target_mac_out[ignore_action_tp1] = -np.inf
 
         # Max over target Q-Values
         if self.double_q:
-            # Get actions that maximise live Q (for double q-learning)
-            ignore_action = (action_mask == 0) & (mask == 1).unsqueeze(-1)
-            mac_out = mac_out.clone()  # issue 4742
-            mac_out[ignore_action] = -np.inf
-            cur_max_actions = mac_out.max(dim=3, keepdim=True)[1]
+            # Double Q learning computes the target Q values by selecting the
+            # t+1 timestep action according to the "policy" neural network and
+            # then estimating the Q-value of that action with the "target"
+            # neural network
+
+            # Compute the t+1 Q-values to be used in action selection
+            # using next_obs
+            mac_out_tp1 = _unroll_mac(self.model, next_obs)
+
+            # mask out unallowed actions
+            mac_out_tp1[ignore_action_tp1] = -np.inf
+
+            # obtain best actions at t+1 according to policy NN
+            cur_max_actions = mac_out_tp1.max(dim=3, keepdim=True)[1]
+
+            # use the target network to estimate the Q-values of policy
+            # network's selected actions
             target_max_qvals = th.gather(target_mac_out, 3,
                                          cur_max_actions).squeeze(3)
         else:
@@ -157,7 +150,7 @@ class QMixTorchPolicy(Policy):
         agent_obs_space = obs_space.original_space.spaces[0]
         if isinstance(agent_obs_space, Dict):
             space_keys = set(agent_obs_space.spaces.keys())
-            if space_keys != {"obs", "action_mask"}:
+            if not {"obs", "action_mask"}.issubset(space_keys):
                 raise ValueError(
                     "Dict obs space for agent must have keyset "
                     "['obs', 'action_mask'], got {}".format(space_keys))
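The hunk above implements the "Allow extra obs dict keys" bullet of the commit message: the keyset check is relaxed from strict equality to a subset test. A small sketch of what now passes, assuming the Gym spaces API of that era (the space contents below are made up):

from gym.spaces import Box, Dict, Discrete

# A per-agent Dict observation space with one extra key beside "obs" and
# "action_mask"; the extra key is what the old equality check rejected.
agent_obs_space = Dict({
    "obs": Box(low=-1.0, high=1.0, shape=(8,)),
    "action_mask": Box(low=0.0, high=1.0, shape=(5,)),
    "extra_state": Discrete(3),  # hypothetical extra key
})

space_keys = set(agent_obs_space.spaces.keys())
assert {"obs", "action_mask"}.issubset(space_keys)  # new check: passes
assert space_keys != {"obs", "action_mask"}         # old check: would have raised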
@@ -448,3 +441,19 @@ def _mac(model, obs, h):
     q_flat, h_flat = model({"obs": obs_flat}, h_flat, None)
     return q_flat.reshape(
         [B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]
+
+
+def _unroll_mac(model, obs_tensor):
+    """Computes the estimated Q values for an entire trajectory batch"""
+    B = obs_tensor.size(0)
+    T = obs_tensor.size(1)
+    n_agents = obs_tensor.size(2)
+
+    mac_out = []
+    h = [s.expand([B, n_agents, -1]) for s in model.get_initial_state()]
+    for t in range(T):
+        q, h = _mac(model, obs_tensor[:, t], h)
+        mac_out.append(q)
+    mac_out = th.stack(mac_out, dim=1)  # Concat over time
+
+    return mac_out
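As a usage note for the new helper added above, here is a toy sketch of calling _unroll_mac with a stub recurrent agent model. ToyAgentModel, its sizes, and the GRUCell choice are hypothetical; they only mimic the interface _mac expects (a model called as model({"obs": ...}, state_list, seq_lens) plus a get_initial_state() method), and the sketch assumes _mac and _unroll_mac from the patch are in scope.

import torch as th
import torch.nn as nn

class ToyAgentModel(nn.Module):
    """Hypothetical per-agent RNN model with the call signature _mac() expects."""

    def __init__(self, obs_size=8, n_actions=5, hidden=32):
        super().__init__()
        self.hidden = hidden
        self.rnn = nn.GRUCell(obs_size, hidden)
        self.q_head = nn.Linear(hidden, n_actions)

    def get_initial_state(self):
        # One state tensor; _unroll_mac expands it to [B, n_agents, hidden].
        return [th.zeros(1, self.hidden)]

    def forward(self, input_dict, hidden_state, seq_lens):
        h = self.rnn(input_dict["obs"], hidden_state[0])
        return self.q_head(h), [h]

model = ToyAgentModel()
obs_batch = th.randn(4, 7, 2, 8)          # [B=4, T=7, n_agents=2, obs_size=8]
q_values = _unroll_mac(model, obs_batch)  # -> [4, 7, 2, 5], one Q-value per action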