diff --git a/python/ray/rllib/agents/qmix/qmix_policy_graph.py b/python/ray/rllib/agents/qmix/qmix_policy_graph.py
index 5cd04d5ad..b7c9a7ad8 100644
--- a/python/ray/rllib/agents/qmix/qmix_policy_graph.py
+++ b/python/ray/rllib/agents/qmix/qmix_policy_graph.py
@@ -46,16 +46,19 @@ class QMixLoss(nn.Module):
         self.double_q = double_q
         self.gamma = gamma
 
-    def forward(self, rewards, actions, terminated, mask, obs, action_mask):
+    def forward(self, rewards, actions, terminated, mask, obs, next_obs,
+                action_mask, next_action_mask):
         """Forward pass of the loss.
 
         Arguments:
-            rewards: Tensor of shape [B, T-1, n_agents]
-            actions: Tensor of shape [B, T-1, n_agents]
-            terminated: Tensor of shape [B, T-1, n_agents]
-            mask: Tensor of shape [B, T-1, n_agents]
+            rewards: Tensor of shape [B, T, n_agents]
+            actions: Tensor of shape [B, T, n_agents]
+            terminated: Tensor of shape [B, T, n_agents]
+            mask: Tensor of shape [B, T, n_agents]
             obs: Tensor of shape [B, T, n_agents, obs_size]
+            next_obs: Tensor of shape [B, T, n_agents, obs_size]
             action_mask: Tensor of shape [B, T, n_agents, n_actions]
+            next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
         """
         B, T = obs.size(0), obs.size(1)
 
@@ -68,9 +71,9 @@ class QMixLoss(nn.Module):
             mac_out.append(q)
         mac_out = th.stack(mac_out, dim=1)  # Concat over time
 
-        # Pick the Q-Values for the actions taken -> [B * n_agents, T-1]
+        # Pick the Q-Values for the actions taken -> [B * n_agents, T]
         chosen_action_qvals = th.gather(
-            mac_out[:, :-1], dim=3, index=actions.unsqueeze(3)).squeeze(3)
+            mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)
 
         # Calculate the Q-Values necessary for the target
         target_mac_out = []
@@ -79,32 +82,37 @@ class QMixLoss(nn.Module):
             for s in self.target_model.state_init()
         ]
         for t in range(T):
-            target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
+            target_q, target_h = _mac(self.target_model, next_obs[:, t],
+                                      target_h)
             target_mac_out.append(target_q)
-
-        # We don't need the first timesteps Q-Value estimate for targets
-        target_mac_out = th.stack(
-            target_mac_out[1:], dim=1)  # Concat across time
+        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time
 
         # Mask out unavailable actions
-        target_mac_out[action_mask[:, 1:] == 0] = -9999999
+        ignore_action = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
+        target_mac_out[ignore_action] = -np.inf
 
         # Max over target Q-Values
         if self.double_q:
             # Get actions that maximise live Q (for double q-learning)
-            mac_out[action_mask == 0] = -9999999
-            cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
+            ignore_action = (action_mask == 0) & (mask == 1).unsqueeze(-1)
+            mac_out = mac_out.clone()  # issue 4742
+            mac_out[ignore_action] = -np.inf
+            cur_max_actions = mac_out.max(dim=3, keepdim=True)[1]
             target_max_qvals = th.gather(target_mac_out, 3,
                                          cur_max_actions).squeeze(3)
         else:
             target_max_qvals = target_mac_out.max(dim=3)[0]
 
+        assert target_max_qvals.min().item() != -np.inf, \
+            "target_max_qvals contains a masked action; \
+            there may be a state with no valid actions."
+
         # Mix
         if self.mixer is not None:
             # TODO(ekl) add support for handling global state? This is just
             # treating the stacked agent obs as the state.
-            chosen_action_qvals = self.mixer(chosen_action_qvals, obs[:, :-1])
-            target_max_qvals = self.target_mixer(target_max_qvals, obs[:, 1:])
+            chosen_action_qvals = self.mixer(chosen_action_qvals, obs)
+            target_max_qvals = self.target_mixer(target_max_qvals, next_obs)
 
         # Calculate 1-step Q-Learning targets
         targets = rewards + self.gamma * (1 - terminated) * target_max_qvals
@@ -239,48 +247,53 @@ class QMixPolicyGraph(PolicyGraph):
     def learn_on_batch(self, samples):
         obs_batch, action_mask = self._unpack_observation(
             samples[SampleBatch.CUR_OBS])
+        next_obs_batch, next_action_mask = self._unpack_observation(
+            samples[SampleBatch.NEXT_OBS])
         group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])
 
         # These will be padded to shape [B * T, ...]
-        [rew, action_mask, act, dones, obs], initial_states, seq_lens = \
+        [rew, action_mask, next_action_mask, act, dones, obs, next_obs], \
+            initial_states, seq_lens = \
             chop_into_sequences(
                 samples[SampleBatch.EPS_ID],
                 samples[SampleBatch.UNROLL_ID],
                 samples[SampleBatch.AGENT_INDEX], [
-                    group_rewards, action_mask, samples[SampleBatch.ACTIONS],
-                    samples[SampleBatch.DONES], obs_batch
+                    group_rewards, action_mask, next_action_mask,
+                    samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
+                    obs_batch, next_obs_batch
                 ],
                 [samples["state_in_{}".format(k)]
                  for k in range(len(self.get_initial_state()))],
                 max_seq_len=self.config["model"]["max_seq_len"],
-                dynamic_max=True,
-                _extra_padding=1)
-        # TODO(ekl) adding 1 extra unit of padding here, since otherwise we
-        # lose the terminating reward and the Q-values will be unanchored!
-        B, T = len(seq_lens), max(seq_lens) + 1
+                dynamic_max=True)
+        B, T = len(seq_lens), max(seq_lens)
 
         def to_batches(arr):
             new_shape = [B, T] + list(arr.shape[1:])
             return th.from_numpy(np.reshape(arr, new_shape))
 
-        rewards = to_batches(rew)[:, :-1].float()
-        actions = to_batches(act)[:, :-1].long()
+        rewards = to_batches(rew).float()
+        actions = to_batches(act).long()
         obs = to_batches(obs).reshape([B, T, self.n_agents,
                                        self.obs_size]).float()
         action_mask = to_batches(action_mask)
+        next_obs = to_batches(next_obs).reshape(
+            [B, T, self.n_agents, self.obs_size]).float()
+        next_action_mask = to_batches(next_action_mask)
 
         # TODO(ekl) this treats group termination as individual termination
         terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
-            B, T, self.n_agents)[:, :-1]
+            B, T, self.n_agents)
+
+        # Create mask for where index is < unpadded sequence length
         filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                   np.expand_dims(seq_lens, 1)).astype(np.float32)
-        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T,
-                                                         self.n_agents)[:, :-1]
-        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
+        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents)
 
         # Compute loss
         loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
-            self.loss(rewards, actions, terminated, mask, obs, action_mask)
+            self.loss(rewards, actions, terminated, mask, obs,
+                      next_obs, action_mask, next_action_mask)
 
         # Optimise
         self.optimiser.zero_grad()
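Not part of the patch, but a reviewer note on the `mac_out = mac_out.clone()  # issue 4742` line: a minimal, self-contained PyTorch sketch of the autograd pitfall it works around. The tensor names and toy shapes below are hypothetical; only the pattern (gather the chosen Q-values first, then mask a clone rather than the original tensor) mirrors the new code. Because `chosen_action_qvals` is gathered from `mac_out`, autograd saves `mac_out` for the backward pass, so writing `-np.inf` into it in place can trip PyTorch's in-place modification check at `backward()` time.

```python
import torch as th

# Hypothetical toy shapes: [B, T, n_agents, n_actions]
B, T, n_agents, n_actions = 2, 3, 4, 5
logits = th.randn(B, T, n_agents, n_actions, requires_grad=True)
mac_out = logits * 1.0  # stands in for the stacked model outputs

# The gather for the chosen actions saves mac_out for its backward pass.
actions = th.zeros(B, T, n_agents, 1, dtype=th.long)
chosen_action_qvals = th.gather(mac_out, dim=3, index=actions).squeeze(3)

# Writing -inf into mac_out in place at this point bumps its version counter,
# and the later backward() can then fail with PyTorch's "modified by an
# inplace operation" error. Cloning first, as the patch does, leaves the
# saved tensor untouched.
masked = mac_out.clone()
masked[masked < 0] = -float("inf")
cur_max_actions = masked.max(dim=3, keepdim=True)[1]

chosen_action_qvals.sum().backward()  # fine: mac_out itself was never modified
print(cur_max_actions.shape, logits.grad.shape)
```

The removed `mac_out[action_mask == 0] = -9999999` line wrote into `mac_out` directly, which is the in-place pattern this sketch reproduces.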