import logging
import gym
from typing import Dict, Tuple, List, Optional, Any, Type

import ray
from ray.rllib.algorithms.dqn.dqn_tf_policy import (
    postprocess_nstep_and_prio,
    PRIO_WEIGHTS,
)
from ray.rllib.evaluation import Episode
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import (
    TorchDeterministic,
    TorchDirichlet,
    TorchDistributionWrapper,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.torch_utils import (
    apply_grad_clipping,
    concat_multi_gpu_td_errors,
    huber_loss,
    l2_loss,
)
from ray.rllib.utils.typing import (
    ModelGradients,
    TensorType,
    AlgorithmConfigDict,
)
from ray.rllib.algorithms.ddpg.utils import make_ddpg_models, validate_spaces

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class ComputeTDErrorMixin:
    """Mixin that equips a policy with a `compute_td_error` callable.

    The callable is built as a closure in `__init__` and stored on the
    instance. It expects the policy's `loss` to write a `"td_error"` entry
    into `self.model.tower_stats`.
    """

    def __init__(self: TorchPolicyV2):
        def compute_td_error(
            obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights
        ):
            """Run the loss on the given transitions; return its TD-errors."""
            transitions = SampleBatch(
                {
                    SampleBatch.CUR_OBS: obs_t,
                    SampleBatch.ACTIONS: act_t,
                    SampleBatch.REWARDS: rew_t,
                    SampleBatch.NEXT_OBS: obs_tp1,
                    SampleBatch.DONES: done_mask,
                    PRIO_WEIGHTS: importance_weights,
                }
            )
            tensor_batch = self._lazy_tensor_dict(transitions)

            # The loss is evaluated purely for its side effect of refreshing
            # the per-item TD-errors (used to update PR weights); its return
            # value is discarded.
            self.loss(self.model, None, tensor_batch)

            # The loss call above has stored the TD-errors in the tower stats.
            return self.model.tower_stats["td_error"]

        self.compute_td_error = compute_td_error


class TargetNetworkMixin:
    """Mixin class adding a method for (soft) target net(s) synchronizations.

    - Adds the `update_target` method to the policy.
      Calling `update_target` updates all target Q-networks' weights from their
      respective "main" Q-networks, based on tau (smooth, partial updating).
    """

    def __init__(self):
        # Hard initial update from Q-net(s) to target Q-net(s).
        self.update_target(tau=1.0)

    def update_target(self: TorchPolicyV2, tau=None):
        """Blend the main model's weights into all target models.

        Each target entry becomes `tau * main + (1 - tau) * target`, where
        `target` is taken from the first model in `self.target_models`; the
        resulting state dict is then loaded into every target model.

        Args:
            tau: Interpolation factor. `1.0` copies the main model fully,
                `0.0` leaves the (first) target's values in place. If None,
                the value of `self.config["tau"]` is used.
        """
        # Fall back to the configured tau only when none was passed. A plain
        # `tau or ...` would silently replace an explicit `tau=0.0` with the
        # configured value.
        if tau is None:
            tau = self.config.get("tau")

        model_state_dict = self.model.state_dict()
        # Support partial (soft) synching.
        # If tau == 1.0: Full sync from Q-model to target Q-model.
        target_state_dict = next(iter(self.target_models.values())).state_dict()
        blended_state_dict = {
            k: tau * model_state_dict[k] + (1 - tau) * v
            for k, v in target_state_dict.items()
        }

        for target in self.target_models.values():
            target.load_state_dict(blended_state_dict)

    @override(TorchPolicyV2)
    def set_weights(self: TorchPolicyV2, weights):
        """Set the policy's weights, then run a target-network update."""
        # Whenever weights are restored for this policy's model, the target
        # network(s) are updated from the main model right afterwards.
        # NOTE(review): this uses the configured `tau` (a soft update), not a
        # hard `tau=1.0` copy - confirm that this is intended.
        TorchPolicyV2.set_weights(self, weights)
        self.update_target()


class DDPGTorchPolicy(TargetNetworkMixin, ComputeTDErrorMixin, TorchPolicyV2):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
    ):
        """Set up the DDPG torch policy.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Config overrides; merged on top of the `DDPGConfig`
                defaults (keys given here take precedence).
        """
        # Start from the DDPG default config and overlay the passed-in values.
        config = dict(ray.rllib.algorithms.ddpg.ddpg.DDPGConfig().to_dict(), **config)

        # Create global step for counting the number of update operations.
        # `apply_gradients` uses it to decide when the actor is stepped.
        self.global_step = 0

        # Validate action space for DDPG
        validate_spaces(self, observation_space, action_space)

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        # Attaches `self.compute_td_error`.
        ComputeTDErrorMixin.__init__(self)

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

        # Performs the initial hard (tau=1.0) target network update.
        TargetNetworkMixin.__init__(self)

    @override(TorchPolicyV2)
    def make_model_and_action_dist(
        self,
    ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
        """Build the DDPG model and pick the matching action distribution.

        Returns:
            The model created by `make_ddpg_models` and the distribution
            class: `TorchDirichlet` for `Simplex` action spaces,
            `TorchDeterministic` for everything else.
        """
        ddpg_model = make_ddpg_models(self)
        uses_simplex = isinstance(self.action_space, Simplex)
        dist_cls = TorchDirichlet if uses_simplex else TorchDeterministic
        return ddpg_model, dist_cls

    @override(TorchPolicyV2)
    def optimizer(
        self,
    ) -> List["torch.optim.Optimizer"]:
        """Build one Adam optimizer each for the actor and the critic.

        The optimizers are also kept on the policy as `_actor_optimizer` and
        `_critic_optimizer`.

        Returns:
            `[actor_optimizer, critic_optimizer]`, i.e. in the same order in
            which the corresponding loss terms are returned.
        """
        # Epsilon chosen to match tf.keras.optimizers.Adam's default.
        adam_eps = 1e-7

        actor_params = self.model.policy_variables()
        self._actor_optimizer = torch.optim.Adam(
            params=actor_params, lr=self.config["actor_lr"], eps=adam_eps
        )

        critic_params = self.model.q_variables()
        self._critic_optimizer = torch.optim.Adam(
            params=critic_params, lr=self.config["critic_lr"], eps=adam_eps
        )

        return [self._actor_optimizer, self._critic_optimizer]

    @override(TorchPolicyV2)
    def apply_gradients(self, gradients: ModelGradients) -> None:
        """Step the critic optimizer and, every `policy_delay` calls, the actor.

        Args:
            gradients: Not read here; both optimizers are stepped directly.
        """
        # The critic is updated on every call, the policy net only on every
        # `policy_delay`-th one.
        actor_is_due = self.global_step % self.config["policy_delay"] == 0
        if actor_is_due:
            self._actor_optimizer.step()

        self._critic_optimizer.step()

        # Count this update operation.
        self.global_step += 1

    @override(TorchPolicyV2)
    def action_distribution_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        is_training: bool = False,
        **kwargs
    ) -> Tuple[TensorType, type, List[TensorType]]:
        """Compute the action-distribution inputs from the policy head.

        Args:
            model: Model to run the forward pass with.
            obs_batch: Indexed with `SampleBatch.CUR_OBS` to get the
                observations fed to the model.
            state_batches: Not used.
            is_training: Forwarded to the model input as `_is_training`.
            **kwargs: Ignored.

        Returns:
            The policy output, the distribution class (`TorchDirichlet` for
            `Simplex` action spaces, else `TorchDeterministic`) and an empty
            list of state outputs.
        """
        model_input = SampleBatch(
            obs=obs_batch[SampleBatch.CUR_OBS], _is_training=is_training
        )
        features, _ = model(model_input)
        policy_out = model.get_policy_output(features)

        # The trailing empty list is the (non-existent) state out.
        if isinstance(self.action_space, Simplex):
            return policy_out, TorchDirichlet, []
        return policy_out, TorchDeterministic, []

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional[Episode] = None,
    ) -> SampleBatch:
        """Postprocess a trajectory by delegating to the shared DQN helper.

        Args:
            sample_batch: The batch to postprocess.
            other_agent_batches: Passed through to the helper unchanged.
            episode: Passed through to the helper unchanged.

        Returns:
            Whatever `postprocess_nstep_and_prio` returns for these inputs
            (presumably the batch with n-step and priority handling applied,
            as the helper's name suggests - see its implementation).
        """
        return postprocess_nstep_and_prio(
            self, sample_batch, other_agent_batches, episode
        )

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> List[TensorType]:
        """Compute the DDPG actor and critic losses for one train batch.

        Besides returning the losses, this stores `q_t`, `actor_loss`,
        `critic_loss` and `td_error` in `model.tower_stats`.

        Args:
            model: The model to compute the losses for; also the key used to
                look up its target model in `self.target_models`.
            dist_class: Not used by this loss.
            train_batch: Batch providing the CUR_OBS, NEXT_OBS, ACTIONS,
                REWARDS, DONES and PRIO_WEIGHTS columns read below.

        Returns:
            `[actor_loss, critic_loss]`, matching the two optimizers.
        """
        target_model = self.target_models[model]

        twin_q = self.config["twin_q"]
        gamma = self.config["gamma"]
        n_step = self.config["n_step"]
        use_huber = self.config["use_huber"]
        huber_threshold = self.config["huber_threshold"]
        l2_reg = self.config["l2_reg"]

        input_dict = SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS], _is_training=True
        )
        input_dict_next = SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True
        )

        model_out_t, _ = model(input_dict, [], None)
        # NOTE(review): `model_out_tp1` is not read again in this method.
        model_out_tp1, _ = model(input_dict_next, [], None)
        target_model_out_tp1, _ = target_model(input_dict_next, [], None)

        # Policy network evaluation on the current observations.
        policy_t = model.get_policy_output(model_out_t)

        # Target policy evaluation on the next observations.
        policy_tp1 = target_model.get_policy_output(target_model_out_tp1)

        # Action outputs.
        if self.config["smooth_target_policy"]:
            # Add clipped Gaussian noise to the target action, then clip the
            # result to the action space's bounds.
            # NOTE(review): the noise is sampled from a `torch.zeros(...)` mean
            # created without an explicit device and only afterwards moved to
            # `policy_tp1.device`.
            target_noise_clip = self.config["target_noise_clip"]
            clipped_normal_sample = torch.clamp(
                torch.normal(
                    mean=torch.zeros(policy_tp1.size()), std=self.config["target_noise"]
                ).to(policy_tp1.device),
                -target_noise_clip,
                target_noise_clip,
            )

            policy_tp1_smoothed = torch.min(
                torch.max(
                    policy_tp1 + clipped_normal_sample,
                    torch.tensor(
                        self.action_space.low,
                        dtype=torch.float32,
                        device=policy_tp1.device,
                    ),
                ),
                torch.tensor(
                    self.action_space.high,
                    dtype=torch.float32,
                    device=policy_tp1.device,
                ),
            )
        else:
            # No smoothing, just use deterministic actions.
            policy_tp1_smoothed = policy_tp1

        # Q-net(s) evaluation.
        # Q-values for the given actions & observations in the current state.
        q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])

        # Q-values for current policy (no noise) in given current state
        q_t_det_policy = model.get_q_values(model_out_t, policy_t)

        # The actor is trained to maximize the Q-value of its own actions.
        actor_loss = -torch.mean(q_t_det_policy)

        if twin_q:
            twin_q_t = model.get_twin_q_values(
                model_out_t, train_batch[SampleBatch.ACTIONS]
            )

        # Target q-net(s) evaluation.
        q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1_smoothed)

        if twin_q:
            twin_q_tp1 = target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1_smoothed
            )

        # Drop the last axis of the Q-value tensors.
        q_t_selected = torch.squeeze(q_t, axis=len(q_t.shape) - 1)
        if twin_q:
            # NOTE(review): the axis is derived from `q_t.shape`, not from
            # `twin_q_t.shape` - this relies on both having the same rank.
            twin_q_t_selected = torch.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
            # Use the element-wise smaller of the two target Q estimates.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_tp1_best = torch.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        # Zero out the next-state value wherever DONES is set.
        q_tp1_best_masked = (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

        # Compute RHS of bellman equation.
        # Detached, so no gradients flow into the target.
        q_t_selected_target = (
            train_batch[SampleBatch.REWARDS] + gamma ** n_step * q_tp1_best_masked
        ).detach()

        # Compute the error (potentially clipped).
        if twin_q:
            td_error = q_t_selected - q_t_selected_target
            twin_td_error = twin_q_t_selected - q_t_selected_target
            if use_huber:
                errors = huber_loss(td_error, huber_threshold) + huber_loss(
                    twin_td_error, huber_threshold
                )
            else:
                errors = 0.5 * (
                    torch.pow(td_error, 2.0) + torch.pow(twin_td_error, 2.0)
                )
        else:
            td_error = q_t_selected - q_t_selected_target
            if use_huber:
                errors = huber_loss(td_error, huber_threshold)
            else:
                errors = 0.5 * torch.pow(td_error, 2.0)

        # Weight each item's error by its PRIO_WEIGHTS entry before averaging.
        critic_loss = torch.mean(train_batch[PRIO_WEIGHTS] * errors)

        # Add l2-regularization if required.
        # Variables whose name contains "bias" are excluded.
        if l2_reg is not None:
            for name, var in model.policy_variables(as_dict=True).items():
                if "bias" not in name:
                    actor_loss += l2_reg * l2_loss(var)
            for name, var in model.q_variables(as_dict=True).items():
                if "bias" not in name:
                    critic_loss += l2_reg * l2_loss(var)

        # Model self-supervised losses.
        if self.config["use_state_preprocessor"]:
            # Expand input_dict in case custom_loss' need them.
            input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
            input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
            input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
            input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
            [actor_loss, critic_loss] = model.custom_loss(
                [actor_loss, critic_loss], input_dict
            )

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["q_t"] = q_t
        model.tower_stats["actor_loss"] = actor_loss
        model.tower_stats["critic_loss"] = critic_loss
        # TD-error tensor in final stats
        # will be concatenated and retrieved for each individual batch item.
        # With `twin_q`, this is the first Q-net's TD-error only.
        model.tower_stats["td_error"] = td_error

        # Return two loss terms (corresponding to the two optimizers, we create).
        return [actor_loss, critic_loss]

    @override(TorchPolicyV2)
    def extra_grad_process(
        self, optimizer: torch.optim.Optimizer, loss: TensorType
    ) -> Dict[str, TensorType]:
        """Delegate gradient processing to `apply_grad_clipping`.

        Args:
            optimizer: The optimizer passed through to the helper.
            loss: The loss passed through to the helper.

        Returns:
            The dict returned by `apply_grad_clipping`.
        """
        # Clip grads if configured.
        return apply_grad_clipping(self, optimizer, loss)

    @override(TorchPolicyV2)
    def extra_compute_grad_fetches(self) -> Dict[str, Any]:
        """Return the TD-error fetches plus an empty learner-stats entry.

        Returns:
            A dict with `LEARNER_STATS_KEY` mapped to an empty dict, overlaid
            with the numpy-converted output of `concat_multi_gpu_td_errors`.
        """
        td_error_fetches = concat_multi_gpu_td_errors(self)
        numpy_fetches = convert_to_numpy(td_error_fetches)
        base_fetches = {LEARNER_STATS_KEY: {}}
        return dict(base_fetches, **numpy_fetches)

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Aggregate the tower stats written by the loss into summary stats.

        Args:
            train_batch: Not read here; all values come from the tower stats.

        Returns:
            Numpy-converted dict with the mean actor/critic loss over all
            towers and the mean, max and min of the stacked `q_t` values.
        """
        q_values = torch.stack(self.get_tower_stats("q_t"))
        actor_losses = torch.stack(self.get_tower_stats("actor_loss"))
        critic_losses = torch.stack(self.get_tower_stats("critic_loss"))

        return convert_to_numpy(
            {
                "actor_loss": torch.mean(actor_losses),
                "critic_loss": torch.mean(critic_losses),
                "mean_q": torch.mean(q_values),
                "max_q": torch.max(q_values),
                "min_q": torch.min(q_values),
            }
        )