[rllib] Qmix off by 1 in double Q calculation (#5731)

* Qmix fix.

The current version of double Q-learning is incorrect; it selects actions
at timestep t instead of t+1 when computing the t+1 Q-value (see the sketch
after this list).

* Allow extra obs dict keys

* Move Q-value-computing replay code to own function

* Run the autoformatter

* Use better terms in comments ("policy" network instead of "live" network)
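
For context, a minimal, self-contained sketch of the double-Q rule this patch restores. It is not the patch itself: the names double_q_max, q_tp1_policy, and q_tp1_target are illustrative, and the QMIX mixing network and action masking are omitted.

    import torch as th

    def double_q_max(q_tp1_policy, q_tp1_target):
        # q_tp1_policy: policy-network Q-values computed from next_obs, [..., n_actions]
        # q_tp1_target: target-network Q-values computed from next_obs, [..., n_actions]
        # Pick the t+1 action with the policy network...
        best_actions_tp1 = q_tp1_policy.argmax(dim=-1, keepdim=True)
        # ...and evaluate that action with the target network. The pre-fix code
        # took the argmax over Q-values computed from obs (timestep t) instead,
        # which is the off-by-one this commit corrects.
        return th.gather(q_tp1_target, -1, best_actions_tp1).squeeze(-1)

    # Toy usage with [B, T, n_actions]-shaped Q-values:
    print(double_q_max(th.rand(2, 4, 5), th.rand(2, 4, 5)).shape)  # torch.Size([2, 4])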
Matthew A. Wright 2019-09-18 18:12:30 -07:00 committed by Eric Liang
parent 8903bcd0c3
commit 3131e1742d

@@ -60,46 +60,39 @@ class QMixLoss(nn.Module):
             next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
         """
         B, T = obs.size(0), obs.size(1)

         # Calculate estimated Q-Values
-        mac_out = []
-        h = [
-            s.expand([B, self.n_agents, -1])
-            for s in self.model.get_initial_state()
-        ]
-        for t in range(T):
-            q, h = _mac(self.model, obs[:, t], h)
-            mac_out.append(q)
-        mac_out = th.stack(mac_out, dim=1)  # Concat over time
+        mac_out = _unroll_mac(self.model, obs)

         # Pick the Q-Values for the actions taken -> [B * n_agents, T]
         chosen_action_qvals = th.gather(
             mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)

         # Calculate the Q-Values necessary for the target
-        target_mac_out = []
-        target_h = [
-            s.expand([B, self.n_agents, -1])
-            for s in self.target_model.get_initial_state()
-        ]
-        for t in range(T):
-            target_q, target_h = _mac(self.target_model, next_obs[:, t],
-                                      target_h)
-            target_mac_out.append(target_q)
-        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time
+        target_mac_out = _unroll_mac(self.target_model, next_obs)

-        # Mask out unavailable actions
-        ignore_action = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
-        target_mac_out[ignore_action] = -np.inf
+        # Mask out unavailable actions for the t+1 step
+        ignore_action_tp1 = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
+        target_mac_out[ignore_action_tp1] = -np.inf

         # Max over target Q-Values
         if self.double_q:
-            # Get actions that maximise live Q (for double q-learning)
-            ignore_action = (action_mask == 0) & (mask == 1).unsqueeze(-1)
-            mac_out = mac_out.clone()  # issue 4742
-            mac_out[ignore_action] = -np.inf
-            cur_max_actions = mac_out.max(dim=3, keepdim=True)[1]
+            # Double Q learning computes the target Q values by selecting the
+            # t+1 timestep action according to the "policy" neural network and
+            # then estimating the Q-value of that action with the "target"
+            # neural network
+
+            # Compute the t+1 Q-values to be used in action selection
+            # using next_obs
+            mac_out_tp1 = _unroll_mac(self.model, next_obs)
+
+            # mask out unallowed actions
+            mac_out_tp1[ignore_action_tp1] = -np.inf
+
+            # obtain best actions at t+1 according to policy NN
+            cur_max_actions = mac_out_tp1.max(dim=3, keepdim=True)[1]
+
+            # use the target network to estimate the Q-values of policy
+            # network's selected actions
             target_max_qvals = th.gather(target_mac_out, 3,
                                          cur_max_actions).squeeze(3)
         else:
@@ -157,7 +150,7 @@ class QMixTorchPolicy(Policy):
         agent_obs_space = obs_space.original_space.spaces[0]
         if isinstance(agent_obs_space, Dict):
             space_keys = set(agent_obs_space.spaces.keys())
-            if space_keys != {"obs", "action_mask"}:
+            if not {"obs", "action_mask"}.issubset(space_keys):
                 raise ValueError(
                     "Dict obs space for agent must have keyset "
                     "['obs', 'action_mask'], got {}".format(space_keys))
@@ -448,3 +441,19 @@ def _mac(model, obs, h):
     q_flat, h_flat = model({"obs": obs_flat}, h_flat, None)
     return q_flat.reshape(
         [B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]
+
+
+def _unroll_mac(model, obs_tensor):
+    """Computes the estimated Q values for an entire trajectory batch"""
+    B = obs_tensor.size(0)
+    T = obs_tensor.size(1)
+    n_agents = obs_tensor.size(2)
+
+    mac_out = []
+    h = [s.expand([B, n_agents, -1]) for s in model.get_initial_state()]
+    for t in range(T):
+        q, h = _mac(model, obs_tensor[:, t], h)
+        mac_out.append(q)
+    mac_out = th.stack(mac_out, dim=1)  # Concat over time
+
+    return mac_out
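
A rough usage sketch of _unroll_mac, assuming _mac and _unroll_mac from this file are in scope. The toy model is only a stand-in with the minimal interface the helper relies on (get_initial_state() plus an RLlib-style (input_dict, state, seq_lens) call); it is not the real QMIX RNN model.

    import torch as th
    import torch.nn as nn

    class ToyAgentModel(nn.Module):
        """Hypothetical stand-in: maps per-agent obs to per-agent Q-values."""

        def __init__(self, obs_size=4, n_actions=3, hidden=8):
            super().__init__()
            self.fc = nn.Linear(obs_size, hidden)
            self.out = nn.Linear(hidden, n_actions)
            self.hidden = hidden

        def get_initial_state(self):
            # One hidden-state tensor; _unroll_mac expands it to [B, n_agents, -1].
            return [th.zeros(self.hidden)]

        def forward(self, input_dict, state, seq_lens):
            # Called by _mac as model({"obs": obs_flat}, h_flat, None).
            x = th.relu(self.fc(input_dict["obs"]))
            return self.out(x), state

    obs = th.rand(2, 5, 3, 4)  # [B=2, T=5, n_agents=3, obs_size=4]
    q_vals = _unroll_mac(ToyAgentModel(), obs)
    print(q_vals.shape)  # torch.Size([2, 5, 3, 3]) = [B, T, n_agents, n_actions]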