import logging

from gym.spaces import Tuple, Discrete, Dict
import numpy as np
import torch as th
import torch.nn as nn
from torch.optim import RMSprop
from torch.distributions import Categorical

import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.model import _unpack_obs
from ray.rllib.env.constants import GROUP_REWARDS
from ray.rllib.utils.annotations import override
from ray.rllib.utils.tuple_actions import TupleActions

logger = logging.getLogger(__name__)

# if the obs space is Dict type, look for the global state under this key
ENV_STATE = "state"


class QMixLoss(nn.Module):
    """1-step Q-learning loss for QMIX/VDN, optionally with double Q.

    Holds references to the online and target agent models and mixers.
    Optimization is left to the caller.
    """

    def __init__(self,
                 model,
                 target_model,
                 mixer,
                 target_mixer,
                 n_agents,
                 n_actions,
                 double_q=True,
                 gamma=0.99):
        nn.Module.__init__(self)
        self.model = model
        self.target_model = target_model
        self.mixer = mixer
        self.target_mixer = target_mixer
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.double_q = double_q
        self.gamma = gamma

    def forward(self,
                rewards,
                actions,
                terminated,
                mask,
                obs,
                next_obs,
                action_mask,
                next_action_mask,
                state=None,
                next_state=None):
        """Forward pass of the loss.

        Arguments:
            rewards: Tensor of shape [B, T, n_agents]
            actions: Tensor of shape [B, T, n_agents]
            terminated: Tensor of shape [B, T, n_agents]
            mask: Tensor of shape [B, T, n_agents]
            obs: Tensor of shape [B, T, n_agents, obs_size]
            next_obs: Tensor of shape [B, T, n_agents, obs_size]
            action_mask: Tensor of shape [B, T, n_agents, n_actions]
            next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
            state: Tensor of shape [B, T, state_dim] (optional)
            next_state: Tensor of shape [B, T, state_dim] (optional)

        Returns:
            Tuple (loss, mask, masked_td_error, chosen_action_qvals,
            targets). `loss` is the squared TD error summed over unpadded
            entries and divided by the number of unpadded entries.

        Raises:
            ValueError: if exactly one of `state` / `next_state` is given.
        """
        # Either none or both of state and next_state may be given.
        if state is None and next_state is None:
            state = obs  # default to state being all agents' observations
            next_state = next_obs
        elif (state is None) != (next_state is None):
            raise ValueError("Expected either neither or both of `state` and "
                             "`next_state` to be given. Got: "
                             "\n`state` = {}\n`next_state` = {}".format(
                                 state, next_state))

        # Calculate estimated Q-Values -> [B, T, n_agents, n_actions]
        mac_out = _unroll_mac(self.model, obs)

        # Pick the Q-Values for the actions taken -> [B, T, n_agents]
        chosen_action_qvals = th.gather(
            mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)

        # Calculate the Q-Values necessary for the target
        target_mac_out = _unroll_mac(self.target_model, next_obs)

        # Mask out unavailable actions for the t+1 step, but only on
        # timesteps that hold real (unpadded) data.
        ignore_action_tp1 = (next_action_mask == 0) & (
            mask == 1).unsqueeze(-1)
        target_mac_out[ignore_action_tp1] = -np.inf

        # Max over target Q-Values
        if self.double_q:
            # Double Q learning computes the target Q values by selecting the
            # t+1 timestep action according to the "policy" neural network and
            # then estimating the Q-value of that action with the "target"
            # neural network

            # Compute the t+1 Q-values to be used in action selection
            # using next_obs
            mac_out_tp1 = _unroll_mac(self.model, next_obs)

            # mask out unallowed actions
            mac_out_tp1[ignore_action_tp1] = -np.inf

            # obtain best actions at t+1 according to policy NN
            cur_max_actions = mac_out_tp1.argmax(dim=3, keepdim=True)

            # use the target network to estimate the Q-values of policy
            # network's selected actions
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Implicit literal concatenation (not a backslash continuation inside
        # the string) so the message carries no embedded indentation.
        assert target_max_qvals.min().item() != -np.inf, (
            "target_max_qvals contains a masked action; "
            "there may be a state with no valid actions.")

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals, state)
            target_max_qvals = self.target_mixer(target_max_qvals, next_state)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.gamma * (1 - terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum()
        return loss, mask, masked_td_error, chosen_action_qvals, targets


# TODO(sven): Make this a TorchPolicy child.
class QMixTorchPolicy(Policy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        """Builds agent models, mixers, the loss module and the optimizer.

        Raises:
            ValueError: on unsupported obs/action spaces, a Dict obs space
                without an `obs` key, a malformed action mask, or an unknown
                `config["mixer"]` value.
        """
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.framework = "torch"
        super().__init__(obs_space, action_space, config)
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]
        self.has_env_global_state = False
        self.has_action_mask = False
        self.device = (th.device("cuda")
                       if th.cuda.is_available() else th.device("cpu"))

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if "obs" not in space_keys:
                raise ValueError(
                    "Dict obs space must have subspace labeled `obs`")
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            if "action_mask" in space_keys:
                mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
                if mask_shape != (self.n_actions, ):
                    raise ValueError(
                        "Action mask shape must be {}, got {}".format(
                            (self.n_actions, ), mask_shape))
                self.has_action_mask = True
            if ENV_STATE in space_keys:
                self.env_global_state_shape = _get_size(
                    agent_obs_space.spaces[ENV_STATE])
                self.has_env_global_state = True
            else:
                self.env_global_state_shape = (self.obs_size, self.n_agents)
            # The real agent obs space is nested inside the dict
            config["model"]["full_obs_space"] = agent_obs_space
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.obs_size = _get_size(agent_obs_space)
            # Must be set on this path too: the "qmix" mixer below reads it,
            # and without an env state the loss feeds the mixer all agents'
            # observations as the state.
            self.env_global_state_shape = (self.obs_size, self.n_agents)

        self.model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="model",
            default_model=RNNModel).to(self.device)

        self.target_model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="target_model",
            default_model=RNNModel).to(self.device)

        # Setup the mixer network.
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.env_global_state_shape,
                                config["mixing_embed_dim"]).to(self.device)
            self.target_mixer = QMixer(
                self.n_agents, self.env_global_state_shape,
                config["mixing_embed_dim"]).to(self.device)
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer().to(self.device)
            self.target_mixer = VDNMixer().to(self.device)
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer is not None:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        **kwargs):
        """Epsilon-greedy action selection over masked per-agent Q-values.

        If `state_batches` is None, the model's initial RNN state (as
        returned by `get_initial_state`) is repeated for every batch row.

        Returns:
            (TupleActions with one entry per agent, new RNN states as numpy
            arrays, empty extra-fetches dict).
        """
        explore = explore if explore is not None else self.config["explore"]
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # We need to ensure we do not use the env global state
        # to compute actions

        if state_batches is None:
            # obs_batch is [B, n_agents, obs_size] here; stack one copy of
            # each initial state per batch row.
            state_batches = [
                np.stack([s] * len(obs_batch))
                for s in self.get_initial_state()
            ]

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(
                self.model,
                th.as_tensor(obs_batch, dtype=th.float, device=self.device), [
                    th.as_tensor(
                        np.array(s), dtype=th.float, device=self.device)
                    for s in state_batches
                ])
            avail = th.as_tensor(
                action_mask, dtype=th.float, device=self.device)
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < (self.cur_epsilon
                                             if explore else 0.0)).long()
            # Random actions are drawn only among available actions.
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.argmax(dim=2))
            actions = actions.cpu().numpy()
            hiddens = [s.cpu().numpy() for s in hiddens]

        # [B, n_agents] -> one array of length B per agent.
        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

    @override(Policy)
    def learn_on_batch(self, samples):
        """Runs one optimizer step of the QMix loss on a sample batch.

        Samples are chopped into padded [B, T, ...] sequences; padded steps
        are excluded from the loss via a mask.

        Returns:
            Dict with LEARNER_STATS_KEY -> loss / gradient / Q statistics.
        """
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        (next_obs_batch, next_action_mask,
         next_env_global_state) = self._unpack_observation(
             samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        input_list = [
            group_rewards, action_mask, next_action_mask,
            samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
            obs_batch, next_obs_batch
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = \
            chop_into_sequences(
                samples[SampleBatch.EPS_ID],
                samples[SampleBatch.UNROLL_ID],
                samples[SampleBatch.AGENT_INDEX],
                input_list,
                [],  # RNN states not used here
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True)
        # These will be padded to shape [B * T, ...]
        if self.has_env_global_state:
            (rew, action_mask, next_action_mask, act, dones, obs, next_obs,
             env_global_state, next_env_global_state) = output_list
        else:
            (rew, action_mask, next_action_mask, act, dones, obs,
             next_obs) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            # [B * T, ...] -> [B, T, ...] tensor on the policy device.
            new_shape = [B, T] + list(arr.shape[1:])
            return th.as_tensor(
                np.reshape(arr, new_shape), dtype=dtype, device=self.device)

        rewards = to_batches(rew, th.float)
        actions = to_batches(act, th.long)
        obs = to_batches(obs, th.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, th.float)
        next_obs = to_batches(next_obs, th.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        next_action_mask = to_batches(next_action_mask, th.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, th.float)
            next_env_global_state = to_batches(next_env_global_state,
                                               th.float)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones, th.float).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length
        filled = np.reshape(
            np.tile(np.arange(T, dtype=np.float32), B),
            [B, T]) < np.expand_dims(seq_lens, 1)
        mask = th.as_tensor(
            filled, dtype=th.float, device=self.device).unsqueeze(2).expand(
                B, T, self.n_agents)

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = (
            self.loss(rewards, actions, terminated, mask, obs, next_obs,
                      action_mask, next_action_mask, env_global_state,
                      next_env_global_state))

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    def get_initial_state(self):  # initial RNN state
        """Returns the model's initial RNN state(s), expanded to n_agents rows,
        as numpy arrays."""
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(Policy)
    def get_weights(self):
        """Returns model/target/mixer state dicts as numpy arrays (mixer
        entries are None when no mixer is configured)."""
        return {
            "model": self._cpu_dict(self.model.state_dict()),
            "target_model": self._cpu_dict(self.target_model.state_dict()),
            "mixer": self._cpu_dict(self.mixer.state_dict())
            if self.mixer is not None else None,
            "target_mixer": self._cpu_dict(self.target_mixer.state_dict())
            if self.mixer is not None else None,
        }

    @override(Policy)
    def set_weights(self, weights):
        """Loads weights in the format produced by `get_weights`."""
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(
            self._device_dict(weights["target_model"]))
        if weights["mixer"] is not None:
            self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
            self.target_mixer.load_state_dict(
                self._device_dict(weights["target_mixer"]))

    @override(Policy)
    def get_state(self):
        """Weights plus the current exploration epsilon."""
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(Policy)
    def set_state(self, state):
        """Restores weights and epsilon saved by `get_state`."""
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])

    def update_target(self):
        """Copies online model (and mixer, if any) weights to the targets."""
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        """Extracts per-agent group rewards from infos (zeros if absent)."""
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _device_dict(self, state_dict):
        return {
            k: th.as_tensor(v, device=self.device)
            for k, v in state_dict.items()
        }

    @staticmethod
    def _cpu_dict(state_dict):
        return {k: v.cpu().detach().numpy() for k, v in state_dict.items()}

    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask, if any; otherwise all ones of
                shape [B, n_agents, n_actions]
            state (np.ndarray or None): state tensor of shape
                [B, state_size] or None if it is not in the batch
        """
        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            if isinstance(unpacked[0], dict):
                unpacked_obs = [u["obs"] for u in unpacked]
            else:
                unpacked_obs = unpacked
            obs = np.concatenate(
                unpacked_obs,
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions],
                dtype=np.float32)

        if self.has_env_global_state:
            # The global state is read from the first agent's entry only.
            state = unpacked[0][ENV_STATE]
        else:
            state = None
        return obs, action_mask, state


def _validate(obs_space, action_space):
    """Checks that obs/action spaces are homogeneous grouped Tuples with
    Discrete per-agent actions; raises ValueError otherwise."""
    if not hasattr(obs_space, "original_space") or \
            not isinstance(obs_space.original_space, Tuple):
        raise ValueError("Obs space must be a Tuple, got {}. Use ".format(
            obs_space) + "MultiAgentEnv.with_agent_groups() to group related "
                         "agents for QMix.")
    if not isinstance(action_space, Tuple):
        raise ValueError(
            "Action space must be a Tuple, got {}. ".format(action_space) +
            "Use MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix.")
    if not isinstance(action_space.spaces[0], Discrete):
        raise ValueError(
            "QMix requires a discrete action space, got {}".format(
                action_space.spaces[0]))
    if len({str(x) for x in obs_space.original_space.spaces}) > 1:
        raise ValueError(
            "Implementation limitation: observations of grouped agents "
            "must be homogeneous, got {}".format(
                obs_space.original_space.spaces))
    if len({str(x) for x in action_space.spaces}) > 1:
        raise ValueError(
            "Implementation limitation: action space of grouped agents "
            "must be homogeneous, got {}".format(action_space.spaces))


def _mac(model, obs, h):
    """Forward pass of the multi-agent controller.

    Arguments:
        model: TorchModelV2 class
        obs: Tensor of shape [B, n_agents, obs_size], or a dict of such
            tensors that includes an "obs" key
        h: List of tensors of shape [B, n_agents, h_size]

    Returns:
        q_vals: Tensor of shape [B, n_agents, n_actions]
        h: List of tensors of shape [B, n_agents, h_size]
    """
    # Normalize to a dict *before* reading sizes so dict inputs work too.
    if not isinstance(obs, dict):
        obs = {"obs": obs}
    B, n_agents = obs["obs"].size(0), obs["obs"].size(1)
    # Fold the agent dim into the batch dim so one model call serves all
    # agents (agents are assumed homogeneous).
    obs_agents_as_batches = {k: _drop_agent_dim(v) for k, v in obs.items()}
    h_flat = [s.reshape([B * n_agents, -1]) for s in h]
    q_flat, h_flat = model(obs_agents_as_batches, h_flat, None)
    return q_flat.reshape(
        [B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]


def _unroll_mac(model, obs_tensor):
    """Computes the estimated Q values for an entire trajectory batch.

    Starting from the model's initial RNN state, steps `_mac` over the time
    axis of `obs_tensor` ([B, T, n_agents, obs_size]) and returns Q-values
    stacked as [B, T, n_agents, n_actions].
    """
    B = obs_tensor.size(0)
    T = obs_tensor.size(1)
    n_agents = obs_tensor.size(2)

    mac_out = []
    h = [s.expand([B, n_agents, -1]) for s in model.get_initial_state()]
    for t in range(T):
        q, h = _mac(model, obs_tensor[:, t], h)
        mac_out.append(q)
    mac_out = th.stack(mac_out, dim=1)  # Concat over time

    return mac_out


def _drop_agent_dim(T):
    """[B, n_agents, ...] -> [B * n_agents, ...]."""
    shape = list(T.shape)
    B, n_agents = shape[0], shape[1]
    return T.reshape([B * n_agents] + shape[2:])


def _add_agent_dim(T, n_agents):
    """[B * n_agents, ...] -> [B, n_agents, ...]; inverse of _drop_agent_dim."""
    shape = list(T.shape)
    B = shape[0] // n_agents
    assert shape[0] % n_agents == 0
    return T.reshape([B, n_agents] + shape[1:])