diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py
index 1b5c75c5d..af8f08b53 100644
--- a/rllib/agents/qmix/qmix_policy.py
+++ b/rllib/agents/qmix/qmix_policy.py
@@ -60,46 +60,39 @@ class QMixLoss(nn.Module):
             next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
         """
-        B, T = obs.size(0), obs.size(1)
-
         # Calculate estimated Q-Values
-        mac_out = []
-        h = [
-            s.expand([B, self.n_agents, -1])
-            for s in self.model.get_initial_state()
-        ]
-        for t in range(T):
-            q, h = _mac(self.model, obs[:, t], h)
-            mac_out.append(q)
-        mac_out = th.stack(mac_out, dim=1)  # Concat over time
+        mac_out = _unroll_mac(self.model, obs)

         # Pick the Q-Values for the actions taken -> [B * n_agents, T]
         chosen_action_qvals = th.gather(
             mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)

         # Calculate the Q-Values necessary for the target
-        target_mac_out = []
-        target_h = [
-            s.expand([B, self.n_agents, -1])
-            for s in self.target_model.get_initial_state()
-        ]
-        for t in range(T):
-            target_q, target_h = _mac(self.target_model, next_obs[:, t],
-                                      target_h)
-            target_mac_out.append(target_q)
-        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time
+        target_mac_out = _unroll_mac(self.target_model, next_obs)

-        # Mask out unavailable actions
-        ignore_action = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
-        target_mac_out[ignore_action] = -np.inf
+        # Mask out unavailable actions for the t+1 step
+        ignore_action_tp1 = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
+        target_mac_out[ignore_action_tp1] = -np.inf

         # Max over target Q-Values
         if self.double_q:
-            # Get actions that maximise live Q (for double q-learning)
-            ignore_action = (action_mask == 0) & (mask == 1).unsqueeze(-1)
-            mac_out = mac_out.clone()  # issue 4742
-            mac_out[ignore_action] = -np.inf
-            cur_max_actions = mac_out.max(dim=3, keepdim=True)[1]
+            # Double Q learning computes the target Q values by selecting the
+            # t+1 timestep action according to the "policy" neural network and
+            # then estimating the Q-value of that action with the "target"
+            # neural network
+
+            # Compute the t+1 Q-values to be used in action selection
+            # using next_obs
+            mac_out_tp1 = _unroll_mac(self.model, next_obs)
+
+            # mask out unallowed actions
+            mac_out_tp1[ignore_action_tp1] = -np.inf
+
+            # obtain best actions at t+1 according to policy NN
+            cur_max_actions = mac_out_tp1.max(dim=3, keepdim=True)[1]
+
+            # use the target network to estimate the Q-values of policy
+            # network's selected actions
             target_max_qvals = th.gather(target_mac_out, 3,
                                          cur_max_actions).squeeze(3)
         else:
             target_max_qvals = target_mac_out.max(dim=3)[0]
@@ -157,7 +150,7 @@ class QMixTorchPolicy(Policy):
         agent_obs_space = obs_space.original_space.spaces[0]
         if isinstance(agent_obs_space, Dict):
             space_keys = set(agent_obs_space.spaces.keys())
-            if space_keys != {"obs", "action_mask"}:
+            if not {"obs", "action_mask"}.issubset(space_keys):
                 raise ValueError(
                     "Dict obs space for agent must have keyset "
                     "['obs', 'action_mask'], got {}".format(space_keys))
@@ -448,3 +441,19 @@ def _mac(model, obs, h):
     q_flat, h_flat = model({"obs": obs_flat}, h_flat, None)
     return q_flat.reshape(
         [B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]
+
+
+def _unroll_mac(model, obs_tensor):
+    """Computes the estimated Q values for an entire trajectory batch"""
+    B = obs_tensor.size(0)
+    T = obs_tensor.size(1)
+    n_agents = obs_tensor.size(2)
+
+    mac_out = []
+    h = [s.expand([B, n_agents, -1]) for s in model.get_initial_state()]
+    for t in range(T):
+        q, h = _mac(model, obs_tensor[:, t], h)
+        mac_out.append(q)
+    mac_out = th.stack(mac_out, dim=1)  # Concat over time
+
+    return mac_out
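
The Double-Q target computation described by the new comments can be reproduced in isolation. Below is a minimal sketch with toy tensors; the shapes, the random masks, and the local variable names are illustrative assumptions, while in the patch the two unrolled Q tensors come from _unroll_mac(self.model, next_obs) and _unroll_mac(self.target_model, next_obs).

import numpy as np
import torch as th

# Toy dimensions following the [B, T, n_agents, n_actions] layout of QMixLoss.
B, T, n_agents, n_actions = 4, 5, 2, 3

# Stand-ins for the unrolled Q-values; in the patch these are produced by
# _unroll_mac on the policy model and the target model over next_obs.
mac_out_tp1 = th.randn(B, T, n_agents, n_actions)     # "policy" network, t+1
target_mac_out = th.randn(B, T, n_agents, n_actions)  # "target" network, t+1

# Assumed masks: 1 = action available / timestep valid. Action 0 is always
# kept available so every row retains at least one finite Q-value.
next_action_mask = th.randint(0, 2, (B, T, n_agents, n_actions)).float()
next_action_mask[..., 0] = 1.0
mask = th.ones(B, T, n_agents)

# Mask out unavailable t+1 actions, as the patch does.
ignore_action_tp1 = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
mac_out_tp1[ignore_action_tp1] = -np.inf
target_mac_out[ignore_action_tp1] = -np.inf

# Double Q: the policy network selects the best t+1 action ...
cur_max_actions = mac_out_tp1.max(dim=3, keepdim=True)[1]

# ... and the target network evaluates that action.
target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)

print(target_max_qvals.shape)  # torch.Size([4, 5, 2])

Selecting the action with the online ("policy") network and evaluating it with the target network is what decouples action selection from value estimation, which is the point of the Double-Q correction this patch makes.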