[rllib] Qmix padding patch (#4735)

* Qmix padding patch

* Update qmix_policy_graph.py

* lint errors

* more linting

* Update qmix_policy_graph.py
Jacob Beck 2019-05-08 22:07:29 +01:00 committed by Eric Liang
parent edb8465910
commit 28496c8b50

qmix_policy_graph.py

@@ -46,16 +46,19 @@ class QMixLoss(nn.Module):
         self.double_q = double_q
         self.gamma = gamma

-    def forward(self, rewards, actions, terminated, mask, obs, action_mask):
+    def forward(self, rewards, actions, terminated, mask, obs, next_obs,
+                action_mask, next_action_mask):
         """Forward pass of the loss.

         Arguments:
-            rewards: Tensor of shape [B, T-1, n_agents]
-            actions: Tensor of shape [B, T-1, n_agents]
-            terminated: Tensor of shape [B, T-1, n_agents]
-            mask: Tensor of shape [B, T-1, n_agents]
+            rewards: Tensor of shape [B, T, n_agents]
+            actions: Tensor of shape [B, T, n_agents]
+            terminated: Tensor of shape [B, T, n_agents]
+            mask: Tensor of shape [B, T, n_agents]
             obs: Tensor of shape [B, T, n_agents, obs_size]
+            next_obs: Tensor of shape [B, T, n_agents, obs_size]
             action_mask: Tensor of shape [B, T, n_agents, n_actions]
+            next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
         """
         B, T = obs.size(0), obs.size(1)
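The reshaped signature above is the core of the patch: with padded sequences, every index t in [0, T) carries its own (obs, action, reward, next_obs) tuple, instead of recovering the next step by slicing obs[:, 1:], which no longer lines up once episodes are padded to a common length. A minimal shape sketch of a dummy batch matching the docstring (illustrative values only, not the RLlib code):

    import torch as th

    B, T, n_agents, obs_size, n_actions = 2, 5, 3, 8, 4

    # Dummy inputs with the shapes documented above.
    rewards = th.zeros(B, T, n_agents)
    actions = th.zeros(B, T, n_agents, dtype=th.long)
    terminated = th.zeros(B, T, n_agents)
    mask = th.ones(B, T, n_agents)
    obs = th.randn(B, T, n_agents, obs_size)
    next_obs = th.randn(B, T, n_agents, obs_size)      # s_{t+1}, aligned with index t
    action_mask = th.ones(B, T, n_agents, n_actions)
    next_action_mask = th.ones(B, T, n_agents, n_actions)

    # Every tensor shares the leading [B, T, n_agents] layout, so the loss no
    # longer needs any [:, :-1] / [:, 1:] slicing.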
@@ -68,9 +71,9 @@ class QMixLoss(nn.Module):
             mac_out.append(q)
         mac_out = th.stack(mac_out, dim=1)  # Concat over time

-        # Pick the Q-Values for the actions taken -> [B * n_agents, T-1]
+        # Pick the Q-Values for the actions taken -> [B * n_agents, T]
         chosen_action_qvals = th.gather(
-            mac_out[:, :-1], dim=3, index=actions.unsqueeze(3)).squeeze(3)
+            mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)

         # Calculate the Q-Values necessary for the target
         target_mac_out = []
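For readers less familiar with th.gather, a standalone sketch of the selection performed above, with made-up shapes and values (not the RLlib code itself):

    import torch as th

    B, T, n_agents, n_actions = 2, 4, 3, 5
    mac_out = th.randn(B, T, n_agents, n_actions)      # Q-values for every action
    actions = th.randint(n_actions, (B, T, n_agents))  # actions actually taken

    # Gather along the action dimension, then drop it -> [B, T, n_agents].
    chosen = th.gather(mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)

    # Equivalent (slow) indexing, to show what gather computes:
    for b in range(B):
        for t in range(T):
            for a in range(n_agents):
                assert chosen[b, t, a] == mac_out[b, t, a, actions[b, t, a]]

With the patch, the gather runs over all T padded steps; padding is excluded later by the mask rather than by slicing off the last timestep here.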
@@ -79,32 +82,37 @@ class QMixLoss(nn.Module):
             for s in self.target_model.state_init()
         ]
         for t in range(T):
-            target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
+            target_q, target_h = _mac(self.target_model, next_obs[:, t],
+                                      target_h)
             target_mac_out.append(target_q)

-        # We don't need the first timesteps Q-Value estimate for targets
-        target_mac_out = th.stack(
-            target_mac_out[1:], dim=1)  # Concat across time
+        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time

         # Mask out unavailable actions
-        target_mac_out[action_mask[:, 1:] == 0] = -9999999
+        ignore_action = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
+        target_mac_out[ignore_action] = -np.inf

         # Max over target Q-Values
         if self.double_q:
             # Get actions that maximise live Q (for double q-learning)
-            mac_out[action_mask == 0] = -9999999
-            cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
+            ignore_action = (action_mask == 0) & (mask == 1).unsqueeze(-1)
+            mac_out = mac_out.clone()  # issue 4742
+            mac_out[ignore_action] = -np.inf
+            cur_max_actions = mac_out.max(dim=3, keepdim=True)[1]
             target_max_qvals = th.gather(target_mac_out, 3,
                                          cur_max_actions).squeeze(3)
         else:
             target_max_qvals = target_mac_out.max(dim=3)[0]
+        assert target_max_qvals.min().item() != -np.inf, \
+            "target_max_qvals contains a masked action; \
+            there may be a state with no valid actions."

         # Mix
         if self.mixer is not None:
             # TODO(ekl) add support for handling global state? This is just
             # treating the stacked agent obs as the state.
-            chosen_action_qvals = self.mixer(chosen_action_qvals, obs[:, :-1])
-            target_max_qvals = self.target_mixer(target_max_qvals, obs[:, 1:])
+            chosen_action_qvals = self.mixer(chosen_action_qvals, obs)
+            target_max_qvals = self.target_mixer(target_max_qvals, next_obs)

         # Calculate 1-step Q-Learning targets
         targets = rewards + self.gamma * (1 - terminated) * target_max_qvals
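Putting the masking, double-Q selection, and TD target together, a condensed standalone sketch that mirrors the lines above (same shape assumptions as before; dummy values, not the RLlib implementation):

    import numpy as np
    import torch as th

    B, T, n_agents, n_actions, gamma = 2, 4, 3, 5, 0.99
    mac_out = th.randn(B, T, n_agents, n_actions)         # live-net Q-values on obs
    target_mac_out = th.randn(B, T, n_agents, n_actions)  # target-net Q-values on next_obs
    action_mask = th.ones(B, T, n_agents, n_actions)
    next_action_mask = th.ones(B, T, n_agents, n_actions)
    mask = th.ones(B, T, n_agents)                        # 1 = real step, 0 = padding
    rewards = th.zeros(B, T, n_agents)
    terminated = th.zeros(B, T, n_agents)

    # Mask invalid next-state actions only on real (unpadded) steps, so rows of
    # pure padding never become all -inf.
    target_mac_out = target_mac_out.clone()
    target_mac_out[(next_action_mask == 0) & (mask == 1).unsqueeze(-1)] = -np.inf

    # Double-Q style selection as in the diff: argmax under the live network,
    # evaluation under the target network.
    live = mac_out.clone()
    live[(action_mask == 0) & (mask == 1).unsqueeze(-1)] = -np.inf
    cur_max_actions = live.max(dim=3, keepdim=True)[1]
    target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)

    # 1-step TD target: r_t + gamma * (1 - done_t) * Q_target
    targets = rewards + gamma * (1 - terminated) * target_max_qvals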
@@ -239,48 +247,53 @@ class QMixPolicyGraph(PolicyGraph):
     def learn_on_batch(self, samples):
         obs_batch, action_mask = self._unpack_observation(
             samples[SampleBatch.CUR_OBS])
+        next_obs_batch, next_action_mask = self._unpack_observation(
+            samples[SampleBatch.NEXT_OBS])
         group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

         # These will be padded to shape [B * T, ...]
-        [rew, action_mask, act, dones, obs], initial_states, seq_lens = \
+        [rew, action_mask, next_action_mask, act, dones, obs, next_obs], \
+            initial_states, seq_lens = \
             chop_into_sequences(
                 samples[SampleBatch.EPS_ID],
                 samples[SampleBatch.UNROLL_ID],
                 samples[SampleBatch.AGENT_INDEX], [
-                    group_rewards, action_mask, samples[SampleBatch.ACTIONS],
-                    samples[SampleBatch.DONES], obs_batch
+                    group_rewards, action_mask, next_action_mask,
+                    samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
+                    obs_batch, next_obs_batch
                 ],
                 [samples["state_in_{}".format(k)]
                  for k in range(len(self.get_initial_state()))],
                 max_seq_len=self.config["model"]["max_seq_len"],
-                dynamic_max=True,
-                _extra_padding=1)
-        # TODO(ekl) adding 1 extra unit of padding here, since otherwise we
-        # lose the terminating reward and the Q-values will be unanchored!
-        B, T = len(seq_lens), max(seq_lens) + 1
+                dynamic_max=True)
+        B, T = len(seq_lens), max(seq_lens)

         def to_batches(arr):
             new_shape = [B, T] + list(arr.shape[1:])
             return th.from_numpy(np.reshape(arr, new_shape))

-        rewards = to_batches(rew)[:, :-1].float()
-        actions = to_batches(act)[:, :-1].long()
+        rewards = to_batches(rew).float()
+        actions = to_batches(act).long()
         obs = to_batches(obs).reshape([B, T, self.n_agents,
                                        self.obs_size]).float()
         action_mask = to_batches(action_mask)
+        next_obs = to_batches(next_obs).reshape(
+            [B, T, self.n_agents, self.obs_size]).float()
+        next_action_mask = to_batches(next_action_mask)

         # TODO(ekl) this treats group termination as individual termination
         terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
-            B, T, self.n_agents)[:, :-1]
+            B, T, self.n_agents)

         # Create mask for where index is < unpadded sequence length
         filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                   np.expand_dims(seq_lens, 1)).astype(np.float32)
-        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T,
-                                                         self.n_agents)[:, :-1]
-        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
+        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents)

         # Compute loss
         loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
-            self.loss(rewards, actions, terminated, mask, obs, action_mask)
+            self.loss(rewards, actions, terminated, mask, obs,
+                      next_obs, action_mask, next_action_mask)

         # Optimise
         self.optimiser.zero_grad()
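To make the padding semantics in learn_on_batch concrete: sequences are padded to the longest length T, and `filled` marks which [B, T] slots contain real data, so the loss can weight out the padding instead of relying on an extra padded timestep. A small standalone sketch of that mask construction with made-up lengths (mirroring the lines above):

    import numpy as np
    import torch as th

    seq_lens = np.array([3, 5, 2])             # unpadded length of each sequence
    n_agents = 2
    B, T = len(seq_lens), int(max(seq_lens))   # everything is padded up to T = 5

    # 1.0 wherever index t falls inside the real (unpadded) part of a sequence.
    filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
              np.expand_dims(seq_lens, 1)).astype(np.float32)
    # filled == [[1, 1, 1, 0, 0],
    #            [1, 1, 1, 1, 1],
    #            [1, 1, 0, 0, 0]]

    # Broadcast per agent, exactly as the mask passed into self.loss above.
    mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, n_agents)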