import gym
import logging
import numpy as np
import tree  # pip install dm_tree
from typing import Dict, List, Optional, Tuple

import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import _unpack_obs
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TensorType

# Torch must be installed.
torch, nn = try_import_torch(error=True)

logger = logging.getLogger(__name__)


class QMixLoss(nn.Module):
    def __init__(
        self,
        model,
        target_model,
        mixer,
        target_mixer,
        n_agents,
        n_actions,
        double_q=True,
        gamma=0.99,
    ):
        nn.Module.__init__(self)
        self.model = model
        self.target_model = target_model
        self.mixer = mixer
        self.target_mixer = target_mixer
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.double_q = double_q
        self.gamma = gamma

    def forward(
        self,
        rewards,
        actions,
        terminated,
        mask,
        obs,
        next_obs,
        action_mask,
        next_action_mask,
        state=None,
        next_state=None,
    ):
        """Forward pass of the loss.

        Args:
            rewards: Tensor of shape [B, T, n_agents]
            actions: Tensor of shape [B, T, n_agents]
            terminated: Tensor of shape [B, T, n_agents]
            mask: Tensor of shape [B, T, n_agents]
            obs: Tensor of shape [B, T, n_agents, obs_size]
            next_obs: Tensor of shape [B, T, n_agents, obs_size]
            action_mask: Tensor of shape [B, T, n_agents, n_actions]
            next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
            state: Tensor of shape [B, T, state_dim] (optional)
            next_state: Tensor of shape [B, T, state_dim] (optional)
        """
        # Either both or neither of `state` and `next_state` must be given.
        # If neither is given, fall back to using all agents' observations
        # as the state.
        if state is None and next_state is None:
            state = obs
            next_state = next_obs
        elif (state is None) != (next_state is None):
            raise ValueError(
                "Expected either both or neither of `state` and "
                "`next_state` to be given. Got: "
                "\n`state` = {}\n`next_state` = {}".format(state, next_state)
            )
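
        # What follows computes a masked, mixed 1-step TD loss:
        #   y_t  = r_t + gamma * (1 - terminated_t) * Q_tot'(s_{t+1}, a'_{t+1})
        #   loss = sum over non-padded steps of (Q_tot(s_t, a_t) - y_t)^2,
        #          divided by the number of non-padded steps
        # where Q_tot is the (optional) mixer applied to the per-agent
        # Q-values and, with double_q=True, a'_{t+1} is argmax'ed by the
        # online network but evaluated by the target network (Q_tot').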
Got: " "\n`state` = {}\n`next_state` = {}".format(state, next_state) ) # Calculate estimated Q-Values mac_out = _unroll_mac(self.model, obs) # Pick the Q-Values for the actions taken -> [B * n_agents, T] chosen_action_qvals = torch.gather( mac_out, dim=3, index=actions.unsqueeze(3) ).squeeze(3) # Calculate the Q-Values necessary for the target target_mac_out = _unroll_mac(self.target_model, next_obs) # Mask out unavailable actions for the t+1 step ignore_action_tp1 = (next_action_mask == 0) & (mask == 1).unsqueeze(-1) target_mac_out[ignore_action_tp1] = -np.inf # Max over target Q-Values if self.double_q: # Double Q learning computes the target Q values by selecting the # t+1 timestep action according to the "policy" neural network and # then estimating the Q-value of that action with the "target" # neural network # Compute the t+1 Q-values to be used in action selection # using next_obs mac_out_tp1 = _unroll_mac(self.model, next_obs) # mask out unallowed actions mac_out_tp1[ignore_action_tp1] = -np.inf # obtain best actions at t+1 according to policy NN cur_max_actions = mac_out_tp1.argmax(dim=3, keepdim=True) # use the target network to estimate the Q-values of policy # network's selected actions target_max_qvals = torch.gather(target_mac_out, 3, cur_max_actions).squeeze( 3 ) else: target_max_qvals = target_mac_out.max(dim=3)[0] assert ( target_max_qvals.min().item() != -np.inf ), "target_max_qvals contains a masked action; \ there may be a state with no valid actions." # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_qvals, state) target_max_qvals = self.target_mixer(target_max_qvals, next_state) # Calculate 1-step Q-Learning targets targets = rewards + self.gamma * (1 - terminated) * target_max_qvals # Td-error td_error = chosen_action_qvals - targets.detach() mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error ** 2).sum() / mask.sum() return loss, mask, masked_td_error, chosen_action_qvals, targets # TODO(sven): Make this a TorchPolicy child via `build_policy_class`. class QMixTorchPolicy(TorchPolicy): """QMix impl. Assumes homogeneous agents for now. You must use MultiAgentEnv.with_agent_groups() to group agents together for QMix. This creates the proper Tuple obs/action spaces and populates the '_group_rewards' info field. Action masking: to specify an action mask for individual agents, use a dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}. The mask space must be `Box(0, 1, (n_actions,))`. 
""" def __init__(self, obs_space, action_space, config): _validate(obs_space, action_space) config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config) self.framework = "torch" self.n_agents = len(obs_space.original_space.spaces) config["model"]["n_agents"] = self.n_agents self.n_actions = action_space.spaces[0].n self.h_size = config["model"]["lstm_cell_size"] self.has_env_global_state = False self.has_action_mask = False self.device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) agent_obs_space = obs_space.original_space.spaces[0] if isinstance(agent_obs_space, gym.spaces.Dict): space_keys = set(agent_obs_space.spaces.keys()) if "obs" not in space_keys: raise ValueError("Dict obs space must have subspace labeled `obs`") self.obs_size = _get_size(agent_obs_space.spaces["obs"]) if "action_mask" in space_keys: mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape) if mask_shape != (self.n_actions,): raise ValueError( "Action mask shape must be {}, got {}".format( (self.n_actions,), mask_shape ) ) self.has_action_mask = True if ENV_STATE in space_keys: self.env_global_state_shape = _get_size( agent_obs_space.spaces[ENV_STATE] ) self.has_env_global_state = True else: self.env_global_state_shape = (self.obs_size, self.n_agents) # The real agent obs space is nested inside the dict config["model"]["full_obs_space"] = agent_obs_space agent_obs_space = agent_obs_space.spaces["obs"] else: self.obs_size = _get_size(agent_obs_space) self.env_global_state_shape = (self.obs_size, self.n_agents) self.model = ModelCatalog.get_model_v2( agent_obs_space, action_space.spaces[0], self.n_actions, config["model"], framework="torch", name="model", default_model=RNNModel, ).to(self.device) self.target_model = ModelCatalog.get_model_v2( agent_obs_space, action_space.spaces[0], self.n_actions, config["model"], framework="torch", name="target_model", default_model=RNNModel, ).to(self.device) super().__init__(obs_space, action_space, config, model=self.model) self.exploration = self._create_exploration() # Setup the mixer network. 
if config["mixer"] is None: self.mixer = None self.target_mixer = None elif config["mixer"] == "qmix": self.mixer = QMixer( self.n_agents, self.env_global_state_shape, config["mixing_embed_dim"] ).to(self.device) self.target_mixer = QMixer( self.n_agents, self.env_global_state_shape, config["mixing_embed_dim"] ).to(self.device) elif config["mixer"] == "vdn": self.mixer = VDNMixer().to(self.device) self.target_mixer = VDNMixer().to(self.device) else: raise ValueError("Unknown mixer type {}".format(config["mixer"])) self.cur_epsilon = 1.0 self.update_target() # initial sync # Setup optimizer self.params = list(self.model.parameters()) if self.mixer: self.params += list(self.mixer.parameters()) self.loss = QMixLoss( self.model, self.target_model, self.mixer, self.target_mixer, self.n_agents, self.n_actions, self.config["double_q"], self.config["gamma"], ) from torch.optim import RMSprop self.optimiser = RMSprop( params=self.params, lr=config["lr"], alpha=config["optim_alpha"], eps=config["optim_eps"], ) @override(TorchPolicy) def compute_actions_from_input_dict( self, input_dict: Dict[str, TensorType], explore: bool = None, timestep: Optional[int] = None, **kwargs, ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: obs_batch = input_dict[SampleBatch.OBS] state_batches = [] i = 0 while f"state_in_{i}" in input_dict: state_batches.append(input_dict[f"state_in_{i}"]) i += 1 explore = explore if explore is not None else self.config["explore"] obs_batch, action_mask, _ = self._unpack_observation(obs_batch) # We need to ensure we do not use the env global state # to compute actions # Compute actions with torch.no_grad(): q_values, hiddens = _mac( self.model, torch.as_tensor(obs_batch, dtype=torch.float, device=self.device), [ torch.as_tensor(np.array(s), dtype=torch.float, device=self.device) for s in state_batches ], ) avail = torch.as_tensor(action_mask, dtype=torch.float, device=self.device) masked_q_values = q_values.clone() masked_q_values[avail == 0.0] = -float("inf") masked_q_values_folded = torch.reshape( masked_q_values, [-1] + list(masked_q_values.shape)[2:] ) actions, _ = self.exploration.get_exploration_action( action_distribution=TorchCategorical(masked_q_values_folded), timestep=timestep, explore=explore, ) actions = ( torch.reshape(actions, list(masked_q_values.shape)[:-1]).cpu().numpy() ) hiddens = [s.cpu().numpy() for s in hiddens] return tuple(actions.transpose([1, 0])), hiddens, {} @override(TorchPolicy) def compute_actions(self, *args, **kwargs): return self.compute_actions_from_input_dict(*args, **kwargs) @override(TorchPolicy) def compute_log_likelihoods( self, actions, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, ): obs_batch, action_mask, _ = self._unpack_observation(obs_batch) return np.zeros(obs_batch.size()[0]) @override(TorchPolicy) def learn_on_batch(self, samples): obs_batch, action_mask, env_global_state = self._unpack_observation( samples[SampleBatch.CUR_OBS] ) ( next_obs_batch, next_action_mask, next_env_global_state, ) = self._unpack_observation(samples[SampleBatch.NEXT_OBS]) group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS]) input_list = [ group_rewards, action_mask, next_action_mask, samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES], obs_batch, next_obs_batch, ] if self.has_env_global_state: input_list.extend([env_global_state, next_env_global_state]) output_list, _, seq_lens = chop_into_sequences( episode_ids=samples[SampleBatch.EPS_ID], unroll_ids=samples[SampleBatch.UNROLL_ID], 
        output_list, _, seq_lens = chop_into_sequences(
            episode_ids=samples[SampleBatch.EPS_ID],
            unroll_ids=samples[SampleBatch.UNROLL_ID],
            agent_indices=samples[SampleBatch.AGENT_INDEX],
            feature_columns=input_list,
            state_columns=[],  # RNN states not used here
            max_seq_len=self.config["model"]["max_seq_len"],
            dynamic_max=True,
        )
        # These will be padded to shape [B * T, ...].
        if self.has_env_global_state:
            (
                rew,
                action_mask,
                next_action_mask,
                act,
                dones,
                obs,
                next_obs,
                env_global_state,
                next_env_global_state,
            ) = output_list
        else:
            (
                rew,
                action_mask,
                next_action_mask,
                act,
                dones,
                obs,
                next_obs,
            ) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(
                np.reshape(arr, new_shape), dtype=dtype, device=self.device
            )

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size]
        )
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size]
        )
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state, torch.float)

        # TODO(ekl) this treats group termination as individual termination.
        terminated = (
            to_batches(dones, torch.float).unsqueeze(2).expand(B, T, self.n_agents)
        )

        # Create mask for where index is < unpadded sequence length.
        filled = np.reshape(
            np.tile(np.arange(T, dtype=np.float32), B), [B, T]
        ) < np.expand_dims(seq_lens, 1)
        mask = (
            torch.as_tensor(filled, dtype=torch.float, device=self.device)
            .unsqueeze(2)
            .expand(B, T, self.n_agents)
        )

        # Compute loss.
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = self.loss(
            rewards,
            actions,
            terminated,
            mask,
            obs,
            next_obs,
            action_mask,
            next_action_mask,
            env_global_state,
            next_env_global_state,
        )

        # Optimise.
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"]
        )
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float)
            else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() / mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(TorchPolicy)
    def get_initial_state(self):  # initial RNN state
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(TorchPolicy)
    def get_weights(self):
        return {
            "model": self._cpu_dict(self.model.state_dict()),
            "target_model": self._cpu_dict(self.target_model.state_dict()),
            "mixer": self._cpu_dict(self.mixer.state_dict()) if self.mixer else None,
            "target_mixer": self._cpu_dict(self.target_mixer.state_dict())
            if self.mixer
            else None,
        }

    @override(TorchPolicy)
    def set_weights(self, weights):
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(self._device_dict(weights["target_model"]))
        if weights["mixer"] is not None:
            self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
            self.target_mixer.load_state_dict(
                self._device_dict(weights["target_mixer"])
            )

    @override(TorchPolicy)
    def get_state(self):
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(TorchPolicy)
    def set_state(self, state):
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])
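
    # Hard-copy sync of the target networks. The QMix trainer is expected to
    # call this periodically (typically every `target_network_update_freq`
    # timesteps, see the trainer config); __init__ above also calls it once
    # for the initial sync.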
    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array(
            [info.get(GROUP_REWARDS, [0.0] * self.n_agents) for info in info_batch]
        )
        return group_rewards

    def _device_dict(self, state_dict):
        return {
            k: torch.as_tensor(v, device=self.device) for k, v in state_dict.items()
        }

    @staticmethod
    def _cpu_dict(state_dict):
        return {k: v.cpu().detach().numpy() for k, v in state_dict.items()}

    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask, if any
            state (np.ndarray or None): state tensor of shape [B, state_size]
                or None if it is not in the batch
        """
        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np,
        )

        if isinstance(unpacked[0], dict):
            assert "obs" in unpacked[0]
            unpacked_obs = [np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked]
        else:
            unpacked_obs = unpacked

        obs = np.concatenate(unpacked_obs, axis=1).reshape(
            [len(obs_batch), self.n_agents, self.obs_size]
        )

        if self.has_action_mask:
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1
            ).reshape([len(obs_batch), self.n_agents, self.n_actions])
        else:
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions], dtype=np.float32
            )

        if self.has_env_global_state:
            state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
        else:
            state = None
        return obs, action_mask, state


def _validate(obs_space, action_space):
    if not hasattr(obs_space, "original_space") or not isinstance(
        obs_space.original_space, gym.spaces.Tuple
    ):
        raise ValueError(
            "Obs space must be a Tuple, got {}. Use ".format(obs_space)
            + "MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix."
        )
    if not isinstance(action_space, gym.spaces.Tuple):
        raise ValueError(
            "Action space must be a Tuple, got {}. ".format(action_space)
            + "Use MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix."
        )
    if not isinstance(action_space.spaces[0], gym.spaces.Discrete):
        raise ValueError(
            "QMix requires a discrete action space, got {}".format(
                action_space.spaces[0]
            )
        )
    if len({str(x) for x in obs_space.original_space.spaces}) > 1:
        raise ValueError(
            "Implementation limitation: observations of grouped agents "
            "must be homogeneous, got {}".format(obs_space.original_space.spaces)
        )
    if len({str(x) for x in action_space.spaces}) > 1:
        raise ValueError(
            "Implementation limitation: action space of grouped agents "
            "must be homogeneous, got {}".format(action_space.spaces)
        )
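

# Shape conventions for the helper functions below (B = batch/sequence count,
# T = max sequence length, h_size = lstm_cell_size):
#   _mac:        obs [B, n_agents, obs_size]    -> Q [B, n_agents, n_actions]
#   _unroll_mac: obs [B, T, n_agents, obs_size] -> Q [B, T, n_agents, n_actions]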
def _mac(model, obs, h):
    """Forward pass of the multi-agent controller.

    Args:
        model: TorchModelV2 instance
        obs: Tensor of shape [B, n_agents, obs_size]
        h: List of tensors of shape [B, n_agents, h_size]

    Returns:
        q_vals: Tensor of shape [B, n_agents, n_actions]
        h: List of tensors of shape [B, n_agents, h_size]
    """
    B, n_agents = obs.size(0), obs.size(1)
    if not isinstance(obs, dict):
        obs = {"obs": obs}
    obs_agents_as_batches = {k: _drop_agent_dim(v) for k, v in obs.items()}
    h_flat = [s.reshape([B * n_agents, -1]) for s in h]
    q_flat, h_flat = model(obs_agents_as_batches, h_flat, None)
    return q_flat.reshape([B, n_agents, -1]), [
        s.reshape([B, n_agents, -1]) for s in h_flat
    ]


def _unroll_mac(model, obs_tensor):
    """Computes the estimated Q-values for an entire trajectory batch."""
    B = obs_tensor.size(0)
    T = obs_tensor.size(1)
    n_agents = obs_tensor.size(2)

    mac_out = []
    h = [s.expand([B, n_agents, -1]) for s in model.get_initial_state()]
    for t in range(T):
        q, h = _mac(model, obs_tensor[:, t], h)
        mac_out.append(q)
    mac_out = torch.stack(mac_out, dim=1)  # Concat over time.

    return mac_out


def _drop_agent_dim(T):
    shape = list(T.shape)
    B, n_agents = shape[0], shape[1]
    return T.reshape([B * n_agents] + shape[2:])


def _add_agent_dim(T, n_agents):
    shape = list(T.shape)
    B = shape[0] // n_agents
    assert shape[0] % n_agents == 0
    return T.reshape([B, n_agents] + shape[1:])
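

# Illustrative round trip for the two reshape helpers above (the tensor sizes
# here are hypothetical):
#   x = torch.zeros(4, 3, 7)        # [B=4, n_agents=3, feat=7]
#   flat = _drop_agent_dim(x)       # -> [12, 7]
#   back = _add_agent_dim(flat, 3)  # -> [4, 3, 7]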