import gym
import logging
import numpy as np
import tree  # pip install dm_tree
from typing import Dict, List, Optional, Tuple

import ray
from ray.rllib.algorithms.qmix.mixers import VDNMixer, QMixer
from ray.rllib.algorithms.qmix.model import RNNModel, _get_size
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.env.wrappers.group_agents_wrapper import GROUP_REWARDS
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import _unpack_obs
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.torch_utils import apply_grad_clipping

# Torch must be installed: `error=True` is presumably what makes a missing
# torch fail right here at import time — confirm against `try_import_torch`.
torch, nn = try_import_torch(error=True)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class QMixLoss(nn.Module):
    """One-step TD loss over per-agent (optionally mixed) Q-values.

    Holds references to the online/target agent models and the optional
    online/target mixers. `forward` unrolls the models over a padded
    `[B, T]` trajectory batch and returns the masked mean squared TD error.
    """

    def __init__(
        self,
        model,
        target_model,
        mixer,
        target_mixer,
        n_agents,
        n_actions,
        double_q=True,
        gamma=0.99,
    ):
        """Store the networks and hyper-parameters used by `forward`.

        Args:
            model: Online agent network, unrolled to get the Q-values of the
                actions taken (and, with double-Q, to pick the t+1 actions).
            target_model: Target agent network, unrolled on `next_obs`.
            mixer: Module combining the per-agent Q-values given the state,
                or None to skip mixing.
            target_mixer: Mixer applied to the target Q-values; only called
                when `mixer` is not None.
            n_agents: Number of agents (stored only).
            n_actions: Number of actions per agent (stored only).
            double_q: If True, select the t+1 action with `model` and
                evaluate it with `target_model`.
            gamma: Discount factor applied to the target Q-values.
        """
        nn.Module.__init__(self)
        self.model = model
        self.target_model = target_model
        self.mixer = mixer
        self.target_mixer = target_mixer
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.double_q = double_q
        self.gamma = gamma

    def forward(
        self,
        rewards,
        actions,
        terminated,
        mask,
        obs,
        next_obs,
        action_mask,
        next_action_mask,
        state=None,
        next_state=None,
    ):
        """Forward pass of the loss.

        Args:
            rewards: Tensor of shape [B, T, n_agents]
            actions: Tensor of shape [B, T, n_agents]
            terminated: Tensor of shape [B, T, n_agents]
            mask: Tensor of shape [B, T, n_agents]
            obs: Tensor of shape [B, T, n_agents, obs_size]
            next_obs: Tensor of shape [B, T, n_agents, obs_size]
            action_mask: Tensor of shape [B, T, n_agents, n_actions]
            next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
            state: Tensor of shape [B, T, state_dim] (optional)
            next_state: Tensor of shape [B, T, state_dim] (optional)

        Returns:
            Tuple `(loss, mask, masked_td_error, chosen_action_qvals,
            targets)`, where `mask` has been expanded to the shape of the
            TD error and `loss` is the squared masked TD error summed and
            divided by `mask.sum()`.

        Raises:
            ValueError: If exactly one of `state` / `next_state` is given.
            AssertionError: If a selected target Q-value is `-inf`, i.e. an
                unmasked step has no valid action.
        """

        # Assert either none or both of state and next_state are given
        if state is None and next_state is None:
            state = obs  # default to state being all agents' observations
            next_state = next_obs
        elif (state is None) != (next_state is None):
            raise ValueError(
                "Expected either neither or both of `state` and "
                "`next_state` to be given. Got: "
                "\n`state` = {}\n`next_state` = {}".format(state, next_state)
            )

        # Calculate estimated Q-Values -> [B, T, n_agents, n_actions]
        mac_out = _unroll_mac(self.model, obs)

        # Pick the Q-Values for the actions taken -> [B, T, n_agents]
        chosen_action_qvals = torch.gather(
            mac_out, dim=3, index=actions.unsqueeze(3)
        ).squeeze(3)

        # Calculate the Q-Values necessary for the target
        target_mac_out = _unroll_mac(self.target_model, next_obs)

        # Mask out unavailable actions for the t+1 step. Only unpadded steps
        # (mask == 1) are touched, so padded steps never produce -inf.
        ignore_action_tp1 = (next_action_mask == 0) & (mask == 1).unsqueeze(-1)
        target_mac_out[ignore_action_tp1] = -np.inf

        # Max over target Q-Values
        if self.double_q:
            # Double Q learning computes the target Q values by selecting the
            # t+1 timestep action according to the "policy" neural network and
            # then estimating the Q-value of that action with the "target"
            # neural network

            # Compute the t+1 Q-values to be used in action selection
            # using next_obs
            mac_out_tp1 = _unroll_mac(self.model, next_obs)

            # mask out unallowed actions
            mac_out_tp1[ignore_action_tp1] = -np.inf

            # obtain best actions at t+1 according to policy NN
            cur_max_actions = mac_out_tp1.argmax(dim=3, keepdim=True)

            # use the target network to estimate the Q-values of policy
            # network's selected actions
            target_max_qvals = torch.gather(target_mac_out, 3, cur_max_actions).squeeze(
                3
            )
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Implicit string concatenation keeps the message on one clean line
        # (a backslash continuation would embed the source indentation).
        assert target_max_qvals.min().item() != -np.inf, (
            "target_max_qvals contains a masked action; "
            "there may be a state with no valid actions."
        )

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals, state)
            target_max_qvals = self.target_mixer(target_max_qvals, next_state)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.gamma * (1 - terminated) * target_max_qvals

        # Td-error; targets are detached so no gradient reaches the target nets
        td_error = chosen_action_qvals - targets.detach()

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error ** 2).sum() / mask.sum()
        return loss, mask, masked_td_error, chosen_action_qvals, targets


class QMixTorchPolicy(TorchPolicy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        """Build the agent/target models, the optional mixers and the optimizer.

        Args:
            obs_space: Grouped observation space; its `original_space` must
                be a Tuple with one (identical) entry per agent.
            action_space: Tuple of per-agent Discrete action spaces.
            config: Policy config; its entries override the QMix
                DEFAULT_CONFIG.

        Raises:
            ValueError: If the spaces fail `_validate`, a Dict agent obs
                space has no "obs" key, the action mask shape is not
                `(n_actions,)`, or `config["mixer"]` is not one of
                None, "qmix" or "vdn".
        """
        _validate(obs_space, action_space)
        config = dict(ray.rllib.algorithms.qmix.qmix.DEFAULT_CONFIG, **config)
        self.framework = "torch"

        self.n_agents = len(obs_space.original_space.spaces)
        config["model"]["n_agents"] = self.n_agents
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]
        self.has_env_global_state = False
        self.has_action_mask = False

        # Agents are homogeneous (checked in `_validate`), so the first
        # agent's space describes all of them.
        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, gym.spaces.Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if "obs" not in space_keys:
                raise ValueError("Dict obs space must have subspace labeled `obs`")
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            if "action_mask" in space_keys:
                mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
                if mask_shape != (self.n_actions,):
                    raise ValueError(
                        "Action mask shape must be {}, got {}".format(
                            (self.n_actions,), mask_shape
                        )
                    )
                self.has_action_mask = True
            if ENV_STATE in space_keys:
                self.env_global_state_shape = _get_size(
                    agent_obs_space.spaces[ENV_STATE]
                )
                self.has_env_global_state = True
            else:
                self.env_global_state_shape = (self.obs_size, self.n_agents)
            # The real agent obs space is nested inside the dict
            config["model"]["full_obs_space"] = agent_obs_space
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.obs_size = _get_size(agent_obs_space)
            self.env_global_state_shape = (self.obs_size, self.n_agents)

        self.model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="model",
            default_model=RNNModel,
        )

        super().__init__(obs_space, action_space, config, model=self.model)

        self.target_model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="target_model",
            default_model=RNNModel,
        ).to(self.device)

        self.exploration = self._create_exploration()

        # Setup the mixer network.
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(
                self.n_agents, self.env_global_state_shape, config["mixing_embed_dim"]
            ).to(self.device)
            self.target_mixer = QMixer(
                self.n_agents, self.env_global_state_shape, config["mixing_embed_dim"]
            ).to(self.device)
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer().to(self.device)
            self.target_mixer = VDNMixer().to(self.device)
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer. Only the online networks are optimized; the target
        # networks are refreshed through `update_target`.
        self.params = list(self.model.parameters())
        # Explicit None check: do not depend on the truthiness of a module.
        if self.mixer is not None:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(
            self.model,
            self.target_model,
            self.mixer,
            self.target_mixer,
            self.n_agents,
            self.n_actions,
            self.config["double_q"],
            self.config["gamma"],
        )
        from torch.optim import RMSprop

        self.rmsprop_optimizer = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"],
        )

    @override(TorchPolicy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Compute one action per agent for a batch of grouped observations.

        Args:
            input_dict: Must contain `SampleBatch.OBS`; RNN states are read
                from consecutive `state_in_{i}` keys starting at i=0.
            explore: Whether to explore; falls back to `config["explore"]`
                when None.
            timestep: Passed through to the exploration component.
            **kwargs: Ignored.

        Returns:
            Tuple of (per-agent action arrays, list of new RNN state arrays,
            empty extra-info dict).
        """

        obs_batch = input_dict[SampleBatch.OBS]
        state_batches = []
        i = 0
        while f"state_in_{i}" in input_dict:
            state_batches.append(input_dict[f"state_in_{i}"])
            i += 1

        explore = explore if explore is not None else self.config["explore"]
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # We need to ensure we do not use the env global state
        # to compute actions

        # Compute actions
        with torch.no_grad():
            q_values, hiddens = _mac(
                self.model,
                torch.as_tensor(obs_batch, dtype=torch.float, device=self.device),
                [
                    torch.as_tensor(np.array(s), dtype=torch.float, device=self.device)
                    for s in state_batches
                ],
            )
            avail = torch.as_tensor(action_mask, dtype=torch.float, device=self.device)
            # Unavailable actions get -inf logits so they are never selected.
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # Fold [B, n_agents, n_actions] -> [B * n_agents, n_actions] for
            # the action distribution.
            masked_q_values_folded = torch.reshape(
                masked_q_values, [-1] + list(masked_q_values.shape)[2:]
            )
            actions, _ = self.exploration.get_exploration_action(
                action_distribution=TorchCategorical(masked_q_values_folded),
                timestep=timestep,
                explore=explore,
            )
            # Unfold back to [B, n_agents].
            actions = (
                torch.reshape(actions, list(masked_q_values.shape)[:-1]).cpu().numpy()
            )
            hiddens = [s.cpu().numpy() for s in hiddens]

        # Transpose to agent-major so the tuple has one array per agent.
        return tuple(actions.transpose([1, 0])), hiddens, {}

    @override(TorchPolicy)
    def compute_actions(self, *args, **kwargs):
        """Forward all arguments to `compute_actions_from_input_dict`.

        NOTE(review): arguments are passed through unchanged, so the first
        positional argument must be an input dict — confirm against callers.
        """
        return self.compute_actions_from_input_dict(*args, **kwargs)

    @override(TorchPolicy)
    def compute_log_likelihoods(
        self,
        actions,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
    ):
        """Return a zero log-likelihood for every row of `obs_batch`.

        All arguments other than `obs_batch` are ignored.
        """
        obs_batch, _, _ = self._unpack_observation(obs_batch)
        # `_unpack_observation` returns a NumPy array, whose `size` is an int
        # attribute (not a method), so the batch size is read from `shape`.
        return np.zeros(obs_batch.shape[0])

    @override(TorchPolicy)
    def learn_on_batch(self, samples):
        """Run one optimizer step on `samples` and return learner stats.

        The batch is unpacked, chopped into padded sequences of shape
        [B, T, ...], passed through `self.loss`, and the result is
        back-propagated through `self.rmsprop_optimizer`.

        Returns:
            Dict mapping LEARNER_STATS_KEY to the loss, mean absolute TD
            error, mean chosen Q-value, mean target and gradient-clipping
            info.
        """
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS]
        )
        (
            next_obs_batch,
            next_action_mask,
            next_env_global_state,
        ) = self._unpack_observation(samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        input_list = [
            group_rewards,
            action_mask,
            next_action_mask,
            samples[SampleBatch.ACTIONS],
            samples[SampleBatch.DONES],
            obs_batch,
            next_obs_batch,
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = chop_into_sequences(
            episode_ids=samples[SampleBatch.EPS_ID],
            unroll_ids=samples[SampleBatch.UNROLL_ID],
            agent_indices=samples[SampleBatch.AGENT_INDEX],
            feature_columns=input_list,
            state_columns=[],  # RNN states not used here
            max_seq_len=self.config["model"]["max_seq_len"],
            dynamic_max=True,
        )
        # These will be padded to shape [B * T, ...]
        if self.has_env_global_state:
            (
                rew,
                action_mask,
                next_action_mask,
                act,
                dones,
                obs,
                next_obs,
                env_global_state,
                next_env_global_state,
            ) = output_list
        else:
            # Without a global state `env_global_state` and
            # `next_env_global_state` stay None, and the loss falls back to
            # the observations.
            (
                rew,
                action_mask,
                next_action_mask,
                act,
                dones,
                obs,
                next_obs,
            ) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            # Split the flat leading [B * T] axis into [B, T].
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(
                np.reshape(arr, new_shape), dtype=dtype, device=self.device
            )

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape([B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size]
        )
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state, torch.float)

        # TODO(ekl) this treats group termination as individual termination
        terminated = (
            to_batches(dones, torch.float).unsqueeze(2).expand(B, T, self.n_agents)
        )

        # Create mask for where index is < unpadded sequence length
        filled = np.reshape(
            np.tile(np.arange(T, dtype=np.float32), B), [B, T]
        ) < np.expand_dims(seq_lens, 1)
        mask = (
            torch.as_tensor(filled, dtype=torch.float, device=self.device)
            .unsqueeze(2)
            .expand(B, T, self.n_agents)
        )

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = self.loss(
            rewards,
            actions,
            terminated,
            mask,
            obs,
            next_obs,
            action_mask,
            next_action_mask,
            env_global_state,
            next_env_global_state,
        )

        # Optimise
        self.rmsprop_optimizer.zero_grad()
        loss_out.backward()
        grad_norm_info = apply_grad_clipping(self, self.rmsprop_optimizer, loss_out)
        self.rmsprop_optimizer.step()

        # Average the stats over unpadded entries only.
        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() / mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        stats.update(grad_norm_info)

        return {LEARNER_STATS_KEY: stats}

    @override(TorchPolicy)
    def get_initial_state(self):  # initial RNN state
        """Return the model's initial RNN state expanded to one row per agent."""
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(TorchPolicy)
    def get_weights(self):
        """Return the model/mixer state dicts as NumPy arrays.

        The "mixer" and "target_mixer" entries are None when no mixer is used.
        """
        has_mixer = self.mixer is not None
        return {
            "model": self._cpu_dict(self.model.state_dict()),
            "target_model": self._cpu_dict(self.target_model.state_dict()),
            "mixer": self._cpu_dict(self.mixer.state_dict()) if has_mixer else None,
            "target_mixer": self._cpu_dict(self.target_mixer.state_dict())
            if has_mixer
            else None,
        }

    @override(TorchPolicy)
    def set_weights(self, weights):
        """Load weights in the format produced by `get_weights`."""
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(self._device_dict(weights["target_model"]))
        if weights["mixer"] is not None:
            self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
            self.target_mixer.load_state_dict(
                self._device_dict(weights["target_mixer"])
            )

    @override(TorchPolicy)
    def get_state(self):
        """Return the weights plus the current epsilon under "cur_epsilon"."""
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(TorchPolicy)
    def set_state(self, state):
        """Restore weights and epsilon from a dict produced by `get_state`."""
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])

    def update_target(self):
        """Copy the online model (and mixer, if any) into the target networks."""
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        """Store `epsilon` as the current epsilon value."""
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        """Collect the per-agent rewards stored under GROUP_REWARDS.

        Infos without that key contribute a row of `n_agents` zeros.
        """
        group_rewards = np.array(
            [info.get(GROUP_REWARDS, [0.0] * self.n_agents) for info in info_batch]
        )
        return group_rewards

    def _device_dict(self, state_dict):
        """Convert every value of `state_dict` to a tensor on `self.device`."""
        return {
            k: torch.as_tensor(v, device=self.device) for k, v in state_dict.items()
        }

    @staticmethod
    def _cpu_dict(state_dict):
        """Convert every tensor of `state_dict` to a NumPy array on the CPU."""
        return {k: v.cpu().detach().numpy() for k, v in state_dict.items()}

    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask of shape [B, n_agents, n_actions];
                all ones if the obs space has no action mask
            state (np.ndarray or None): state tensor of shape [B, state_size]
                or None if it is not in the batch
        """

        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np,
        )

        if isinstance(unpacked[0], dict):
            assert "obs" in unpacked[0]
            unpacked_obs = [np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked]
        else:
            unpacked_obs = unpacked

        obs = np.concatenate(unpacked_obs, axis=1).reshape(
            [len(obs_batch), self.n_agents, self.obs_size]
        )

        if self.has_action_mask:
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1
            ).reshape([len(obs_batch), self.n_agents, self.n_actions])
        else:
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions], dtype=np.float32
            )

        if self.has_env_global_state:
            # The global state is read from the first agent's observation only.
            state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
        else:
            state = None
        return obs, action_mask, state


def _validate(obs_space, action_space):
    """Check that the grouped spaces are usable by QMix.

    Args:
        obs_space: Observation space; must expose an `original_space` that is
            a gym Tuple of identical per-agent spaces.
        action_space: gym Tuple of identical per-agent Discrete spaces.

    Raises:
        ValueError: If any of the above requirements is violated.
    """
    obs_is_grouped = hasattr(obs_space, "original_space") and isinstance(
        obs_space.original_space, gym.spaces.Tuple
    )
    if not obs_is_grouped:
        raise ValueError(
            f"Obs space must be a Tuple, got {obs_space}. Use "
            "MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix."
        )

    if not isinstance(action_space, gym.spaces.Tuple):
        raise ValueError(
            f"Action space must be a Tuple, got {action_space}. "
            "Use MultiAgentEnv.with_agent_groups() to group related "
            "agents for QMix."
        )

    first_action_space = action_space.spaces[0]
    if not isinstance(first_action_space, gym.spaces.Discrete):
        raise ValueError(
            f"QMix requires a discrete action space, got {first_action_space}"
        )

    # Homogeneity is judged by comparing the string form of each sub-space.
    agent_obs_spaces = obs_space.original_space.spaces
    if len(set(map(str, agent_obs_spaces))) > 1:
        raise ValueError(
            "Implementation limitation: observations of grouped agents "
            f"must be homogeneous, got {agent_obs_spaces}"
        )

    agent_action_spaces = action_space.spaces
    if len(set(map(str, agent_action_spaces))) > 1:
        raise ValueError(
            "Implementation limitation: action space of grouped agents "
            f"must be homogeneous, got {agent_action_spaces}"
        )


def _mac(model, obs, h):
    """Forward pass of the multi-agent controller.

    Agents are folded into the batch dimension so that a single model call
    evaluates all of them.

    Args:
        model: TorchModelV2 class
        obs: Tensor of shape [B, n_agents, obs_size], or a dict of such
            tensors (all sharing the leading [B, n_agents] dims)
        h: List of tensors of shape [B, n_agents, h_size]

    Returns:
        q_vals: Tensor of shape [B, n_agents, n_actions]
        h: List of tensors of shape [B, n_agents, h_size]

    Raises:
        ValueError: If `obs` is an empty dict.
    """
    if not isinstance(obs, dict):
        obs = {"obs": obs}
    if not obs:
        raise ValueError("`obs` dict must contain at least one tensor")

    # Read the leading dims only after normalising to a dict; a dict has no
    # `.size()`, so reading them earlier made dict inputs unusable.
    reference = next(iter(obs.values()))
    B, n_agents = reference.size(0), reference.size(1)

    obs_agents_as_batches = {k: _drop_agent_dim(v) for k, v in obs.items()}
    h_flat = [s.reshape([B * n_agents, -1]) for s in h]
    q_flat, h_flat = model(obs_agents_as_batches, h_flat, None)
    return q_flat.reshape([B, n_agents, -1]), [
        s.reshape([B, n_agents, -1]) for s in h_flat
    ]


def _unroll_mac(model, obs_tensor):
    """Run the model over every timestep of a trajectory batch.

    Starting from the model's initial state (expanded to [B, n_agents, -1]),
    `_mac` is called once per timestep along dim 1 of `obs_tensor`, carrying
    the hidden state forward. The per-step Q-values are stacked along a new
    time dimension at dim 1.
    """
    B, T, n_agents = (obs_tensor.size(axis) for axis in range(3))

    h = [s.expand([B, n_agents, -1]) for s in model.get_initial_state()]
    q_per_step = []
    for t in range(T):
        q, h = _mac(model, obs_tensor[:, t], h)
        q_per_step.append(q)

    # Concat over time
    return torch.stack(q_per_step, dim=1)


def _drop_agent_dim(T):
    """Merge the two leading dims of `T` ([B, n_agents, ...]) into one."""
    dims = list(T.shape)
    merged = dims[0] * dims[1]
    return T.reshape([merged, *dims[2:]])


def _add_agent_dim(T, n_agents):
    """Split the leading dim of `T` ([B * n_agents, ...]) into [B, n_agents].

    Asserts that the leading dim is a multiple of `n_agents`.
    """
    dims = list(T.shape)
    batch, leftover = divmod(dims[0], n_agents)
    assert leftover == 0
    return T.reshape([batch, n_agents, *dims[1:]])