[rllib] Qmix padding patch (#4735)
* Qmix padding patch
* Update qmix_policy_graph.py
* lint errors
* more linting
* Update qmix_policy_graph.py
parent edb8465910
commit 28496c8b50
1 changed file with 45 additions and 32 deletions
qmix_policy_graph.py

@@ -46,16 +46,19 @@ class QMixLoss(nn.Module):
         self.double_q = double_q
         self.gamma = gamma
 
-    def forward(self, rewards, actions, terminated, mask, obs, action_mask):
+    def forward(self, rewards, actions, terminated, mask, obs, next_obs,
+                action_mask, next_action_mask):
         """Forward pass of the loss.
 
         Arguments:
-            rewards: Tensor of shape [B, T-1, n_agents]
-            actions: Tensor of shape [B, T-1, n_agents]
-            terminated: Tensor of shape [B, T-1, n_agents]
-            mask: Tensor of shape [B, T-1, n_agents]
+            rewards: Tensor of shape [B, T, n_agents]
+            actions: Tensor of shape [B, T, n_agents]
+            terminated: Tensor of shape [B, T, n_agents]
+            mask: Tensor of shape [B, T, n_agents]
             obs: Tensor of shape [B, T, n_agents, obs_size]
+            next_obs: Tensor of shape [B, T, n_agents, obs_size]
             action_mask: Tensor of shape [B, T, n_agents, n_actions]
+            next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
         """
 
         B, T = obs.size(0), obs.size(1)

@@ -68,9 +71,9 @@ class QMixLoss(nn.Module):
             mac_out.append(q)
         mac_out = th.stack(mac_out, dim=1)  # Concat over time
 
-        # Pick the Q-Values for the actions taken -> [B * n_agents, T-1]
+        # Pick the Q-Values for the actions taken -> [B * n_agents, T]
         chosen_action_qvals = th.gather(
-            mac_out[:, :-1], dim=3, index=actions.unsqueeze(3)).squeeze(3)
+            mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)
 
         # Calculate the Q-Values necessary for the target
         target_mac_out = []

@@ -79,32 +82,37 @@ class QMixLoss(nn.Module):
             for s in self.target_model.state_init()
         ]
         for t in range(T):
-            target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
+            target_q, target_h = _mac(self.target_model, next_obs[:, t],
+                                      target_h)
             target_mac_out.append(target_q)
-
-        # We don't need the first timesteps Q-Value estimate for targets
-        target_mac_out = th.stack(
-            target_mac_out[1:], dim=1)  # Concat across time
+        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time
 
         # Mask out unavailable actions
-        target_mac_out[action_mask[:, 1:] == 0] = -9999999
+        ignore_action = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
+        target_mac_out[ignore_action] = -np.inf
 
         # Max over target Q-Values
         if self.double_q:
             # Get actions that maximise live Q (for double q-learning)
-            mac_out[action_mask == 0] = -9999999
-            cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
+            ignore_action = (action_mask == 0) & (mask == 1).unsqueeze(-1)
+            mac_out = mac_out.clone()  # issue 4742
+            mac_out[ignore_action] = -np.inf
+            cur_max_actions = mac_out.max(dim=3, keepdim=True)[1]
             target_max_qvals = th.gather(target_mac_out, 3,
                                          cur_max_actions).squeeze(3)
         else:
             target_max_qvals = target_mac_out.max(dim=3)[0]
 
+        assert target_max_qvals.min().item() != -np.inf, \
+            "target_max_qvals contains a masked action; \
+            there may be a state with no valid actions."
+
         # Mix
         if self.mixer is not None:
             # TODO(ekl) add support for handling global state? This is just
             # treating the stacked agent obs as the state.
-            chosen_action_qvals = self.mixer(chosen_action_qvals, obs[:, :-1])
-            target_max_qvals = self.target_mixer(target_max_qvals, obs[:, 1:])
+            chosen_action_qvals = self.mixer(chosen_action_qvals, obs)
+            target_max_qvals = self.target_mixer(target_max_qvals, next_obs)
 
         # Calculate 1-step Q-Learning targets
         targets = rewards + self.gamma * (1 - terminated) * target_max_qvals

@@ -239,48 +247,53 @@ class QMixPolicyGraph(PolicyGraph):
     def learn_on_batch(self, samples):
         obs_batch, action_mask = self._unpack_observation(
             samples[SampleBatch.CUR_OBS])
+        next_obs_batch, next_action_mask = self._unpack_observation(
+            samples[SampleBatch.NEXT_OBS])
         group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])
 
         # These will be padded to shape [B * T, ...]
-        [rew, action_mask, act, dones, obs], initial_states, seq_lens = \
+        [rew, action_mask, next_action_mask, act, dones, obs, next_obs], \
+            initial_states, seq_lens = \
             chop_into_sequences(
                 samples[SampleBatch.EPS_ID],
                 samples[SampleBatch.UNROLL_ID],
                 samples[SampleBatch.AGENT_INDEX], [
-                    group_rewards, action_mask, samples[SampleBatch.ACTIONS],
-                    samples[SampleBatch.DONES], obs_batch
+                    group_rewards, action_mask, next_action_mask,
+                    samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
+                    obs_batch, next_obs_batch
                 ],
                 [samples["state_in_{}".format(k)]
                  for k in range(len(self.get_initial_state()))],
                 max_seq_len=self.config["model"]["max_seq_len"],
-                dynamic_max=True,
-                _extra_padding=1)
-        # TODO(ekl) adding 1 extra unit of padding here, since otherwise we
-        # lose the terminating reward and the Q-values will be unanchored!
-        B, T = len(seq_lens), max(seq_lens) + 1
+                dynamic_max=True)
+        B, T = len(seq_lens), max(seq_lens)
 
         def to_batches(arr):
             new_shape = [B, T] + list(arr.shape[1:])
             return th.from_numpy(np.reshape(arr, new_shape))
 
-        rewards = to_batches(rew)[:, :-1].float()
-        actions = to_batches(act)[:, :-1].long()
+        rewards = to_batches(rew).float()
+        actions = to_batches(act).long()
         obs = to_batches(obs).reshape([B, T, self.n_agents,
                                        self.obs_size]).float()
         action_mask = to_batches(action_mask)
+        next_obs = to_batches(next_obs).reshape(
+            [B, T, self.n_agents, self.obs_size]).float()
+        next_action_mask = to_batches(next_action_mask)
 
         # TODO(ekl) this treats group termination as individual termination
         terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
-            B, T, self.n_agents)[:, :-1]
+            B, T, self.n_agents)
 
         # Create mask for where index is < unpadded sequence length
         filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                   np.expand_dims(seq_lens, 1)).astype(np.float32)
-        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T,
-                                                         self.n_agents)[:, :-1]
-        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
+        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents)
 
         # Compute loss
         loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
-            self.loss(rewards, actions, terminated, mask, obs, action_mask)
+            self.loss(rewards, actions, terminated, mask, obs,
+                      next_obs, action_mask, next_action_mask)
 
         # Optimise
         self.optimiser.zero_grad()
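Note: the sketch below is not part of the patch; it is a minimal, standalone illustration of how the [B, T] padding mask used above can be built from per-sequence lengths, mirroring the `filled` / `mask` computation in `learn_on_batch`. The sequence lengths and agent count are made-up example values.

    import numpy as np
    import torch as th

    # Made-up example: two sequences with true lengths 3 and 5, padded to T = 5.
    seq_lens = [3, 5]
    B, T = len(seq_lens), max(seq_lens)
    n_agents = 2  # hypothetical agent count

    # 1.0 where the timestep index is < the unpadded sequence length, else 0.0.
    filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
              np.expand_dims(seq_lens, 1)).astype(np.float32)

    # Broadcast the per-timestep mask over agents -> shape [B, T, n_agents].
    mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, n_agents)

    print(filled)
    # [[1. 1. 1. 0. 0.]
    #  [1. 1. 1. 1. 1.]]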