import logging

import ray
from ray.rllib.agents.ppo.ppo_tf_policy import (
    vf_preds_fetches,
    compute_and_clip_gradients,
    setup_config,
    ValueNetworkMixin,
)
from ray.rllib.evaluation.postprocessing import (
    compute_gae_for_sample_batch,
    Postprocessing,
)
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


def PPOLoss(
    dist_class,
    actions,
    curr_logits,
    behaviour_logits,
    advantages,
    value_fn,
    value_targets,
    vf_preds,
    cur_kl_coeff,
    entropy_coeff,
    clip_param,
    vf_clip_param,
    vf_loss_coeff,
    clip_loss=False,
):
    """Build the PPO loss and its component terms for one batch.

    Args:
        dist_class: Action distribution class, constructed as
            ``dist_class(logits, None)``.
        actions: Actions taken in the batch.
        curr_logits: Distribution inputs of the policy being optimised.
        behaviour_logits: Distribution inputs of the policy that sampled
            ``actions``.
        advantages: Advantage estimates, multiplied elementwise with the
            importance ratio.
        value_fn: Current value predictions.
        value_targets: Regression targets for the value function.
        vf_preds: Value predictions recorded at sampling time (used for
            value clipping).
        cur_kl_coeff: Weight of the KL(behaviour || current) term.
        entropy_coeff: Weight of the entropy bonus.
        clip_param: Ratio clipping range, used only when ``clip_loss``.
        vf_clip_param: Clipping range for the value-function update.
        vf_loss_coeff: Weight of the value-function loss.
        clip_loss: If True use the clipped PPO surrogate, otherwise the plain
            importance-weighted surrogate.

    Returns:
        ``(total_loss, surr_loss, kl_loss, vf_loss, entropy_loss)`` where each
        component is ``tf.reduce_mean`` over the batch.
    """
    new_dist = dist_class(curr_logits, None)
    old_dist = dist_class(behaviour_logits, None)

    # Importance ratio pi_new(a|s) / pi_old(a|s), formed in log space.
    ratio = tf.math.exp(new_dist.logp(actions) - old_dist.logp(actions))
    if clip_loss:
        bounded_ratio = tf.clip_by_value(ratio, 1 - clip_param, 1 + clip_param)
        per_sample_surrogate = tf.minimum(
            advantages * ratio, advantages * bounded_ratio
        )
    else:
        per_sample_surrogate = advantages * ratio
    mean_surrogate = tf.reduce_mean(per_sample_surrogate)

    mean_kl = tf.reduce_mean(old_dist.kl(new_dist))

    # Clipped value loss: pessimistic max of the raw and clipped errors.
    raw_vf_error = tf.math.square(value_fn - value_targets)
    clipped_values = vf_preds + tf.clip_by_value(
        value_fn - vf_preds, -vf_clip_param, vf_clip_param
    )
    clipped_vf_error = tf.math.square(clipped_values - value_targets)
    mean_vf_loss = tf.reduce_mean(tf.maximum(raw_vf_error, clipped_vf_error))

    mean_entropy = tf.reduce_mean(new_dist.entropy())

    # Grouping mirrors the original two-statement accumulation exactly.
    total_loss = (-mean_surrogate + cur_kl_coeff * mean_kl) + (
        vf_loss_coeff * mean_vf_loss - entropy_coeff * mean_entropy
    )
    return total_loss, mean_surrogate, mean_kl, mean_vf_loss, mean_entropy


# Computation graph used by the rollout workers (inner adaptation steps).
class WorkerLoss(object):
    """PPO loss used by a worker for its inner adaptation step.

    Only the total loss is kept, as ``self.loss``. It is wrapped in
    ``tf1.Print`` so that its value is logged whenever it is evaluated.
    """

    def __init__(
        self,
        dist_class,
        actions,
        curr_logits,
        behaviour_logits,
        advantages,
        value_fn,
        value_targets,
        vf_preds,
        cur_kl_coeff,
        entropy_coeff,
        clip_param,
        vf_clip_param,
        vf_loss_coeff,
        clip_loss=False,
    ):
        loss_terms = PPOLoss(
            dist_class=dist_class,
            actions=actions,
            curr_logits=curr_logits,
            behaviour_logits=behaviour_logits,
            advantages=advantages,
            value_fn=value_fn,
            value_targets=value_targets,
            vf_preds=vf_preds,
            cur_kl_coeff=cur_kl_coeff,
            entropy_coeff=entropy_coeff,
            clip_param=clip_param,
            vf_clip_param=vf_clip_param,
            vf_loss_coeff=vf_loss_coeff,
            clip_loss=clip_loss,
        )
        # Element 0 is the total loss; the component terms are not needed here.
        total = loss_terms[0]
        self.loss = tf1.Print(total, ["Worker Adapt Loss", total])


# This is the Meta-Update computation graph for main (meta-update step)
class MAMLLoss(object):
    """MAML meta-objective built on top of the PPO loss (driver side).

    The incoming batch tensors hold the samples of every task (worker) for
    every adaptation split, concatenated along axis 0. ``split`` gives the
    number of rows per ``[split_index, task]``. It is used to cut the tensors
    back into ``[inner_adaptation_steps + 1][num_tasks]`` sub-tensors.

    The inner adaptation steps are replayed in-graph with plain gradient
    descent (``config["inner_lr"]``), so the outer loss is differentiated
    through the adapted weights. The outer loss is the sum of two terms:

    * the clipped PPO loss on the last split, averaged over tasks;
    * the mean over inner steps of ``cur_kl_coeff`` times the task-averaged
      inner KL.

    Attributes:
        loss: Scalar meta-loss (wrapped in ``tf1.Print``).
        mean_policy_loss, mean_kl, mean_vf_loss, mean_entropy: Outer-step
            PPO statistics averaged over tasks.
        mean_inner_kl: Stacked task-averaged KL, one entry per inner step.
        inner_kl_loss: The KL penalty term added to ``loss``.
        kls: Per-task KL tensors of the last inner step.
    """

    def __init__(
        self,
        model,
        config,
        dist_class,
        value_targets,
        advantages,
        actions,
        behaviour_logits,
        vf_preds,
        cur_kl_coeff,
        policy_vars,
        obs,
        num_tasks,
        split,
        inner_adaptation_steps=1,
        entropy_coeff=0,
        clip_param=0.3,
        vf_clip_param=0.1,
        vf_loss_coeff=1.0,
        use_gae=True,
    ):
        # NOTE: `model` and `use_gae` are accepted for interface compatibility
        # but are not used by this graph.
        self.config = config
        self.num_tasks = num_tasks
        self.inner_adaptation_steps = inner_adaptation_steps
        self.clip_param = clip_param
        self.dist_class = dist_class
        self.cur_kl_coeff = cur_kl_coeff

        # Split episode tensors into [inner_adaptation_steps + 1][num_tasks].
        self.obs = self.split_placeholders(obs, split)
        self.actions = self.split_placeholders(actions, split)
        self.behaviour_logits = self.split_placeholders(behaviour_logits, split)
        self.advantages = self.split_placeholders(advantages, split)
        self.value_targets = self.split_placeholders(value_targets, split)
        self.vf_preds = self.split_placeholders(vf_preds, split)

        # Name -> variable dictionary for easier indexing.
        self.policy_vars = {}
        for var in policy_vars:
            self.policy_vars[var.name] = var

        # Pre-adaptation forward pass: every task on its first split.
        pi_new_logits, current_policy_vars, value_fns = [], [], []
        for i in range(self.num_tasks):
            pi_new, value_fn = self.feed_forward(
                self.obs[0][i], self.policy_vars, policy_config=config["model"]
            )
            pi_new_logits.append(pi_new)
            value_fns.append(value_fn)
            current_policy_vars.append(self.policy_vars)

        inner_kls = []

        # Replay the inner adaptation in-graph (same update rule as workers).
        for step in range(self.inner_adaptation_steps):
            kls = []
            for i in range(self.num_tasks):
                # Unclipped PPO loss (surrogate + value + entropy terms) with
                # the KL term disabled; the KL is only collected for stats.
                ppo_loss, _, kl_loss, _, _ = PPOLoss(
                    dist_class=dist_class,
                    actions=self.actions[step][i],
                    curr_logits=pi_new_logits[i],
                    behaviour_logits=self.behaviour_logits[step][i],
                    advantages=self.advantages[step][i],
                    value_fn=value_fns[i],
                    value_targets=self.value_targets[step][i],
                    vf_preds=self.vf_preds[step][i],
                    cur_kl_coeff=0.0,
                    entropy_coeff=entropy_coeff,
                    clip_param=clip_param,
                    vf_clip_param=vf_clip_param,
                    vf_loss_coeff=vf_loss_coeff,
                    clip_loss=False,
                )
                adapted_policy_vars = self.compute_updated_variables(
                    ppo_loss, current_policy_vars[i]
                )
                # Evaluate the adapted weights on the next split.
                pi_new_logits[i], value_fns[i] = self.feed_forward(
                    self.obs[step + 1][i],
                    adapted_policy_vars,
                    policy_config=config["model"],
                )
                current_policy_vars[i] = adapted_policy_vars
                kls.append(kl_loss)

            self.kls = kls
            inner_kls.append(kls)

        # One task-averaged KL value per inner step.
        mean_inner_kl = tf.stack(
            [tf.reduce_mean(tf.stack(inner_kl)) for inner_kl in inner_kls]
        )
        self.mean_inner_kl = mean_inner_kl

        # Outer (meta) objective on the last split with the adapted weights.
        last = self.inner_adaptation_steps
        ppo_obj = []
        surr_losses, kl_losses, vf_losses, entropies = [], [], [], []
        for i in range(self.num_tasks):
            ppo_loss, surr_loss, kl_loss, val_loss, entropy_loss = PPOLoss(
                dist_class=dist_class,
                actions=self.actions[last][i],
                curr_logits=pi_new_logits[i],
                behaviour_logits=self.behaviour_logits[last][i],
                advantages=self.advantages[last][i],
                value_fn=value_fns[i],
                value_targets=self.value_targets[last][i],
                vf_preds=self.vf_preds[last][i],
                cur_kl_coeff=0.0,
                entropy_coeff=entropy_coeff,
                clip_param=clip_param,
                vf_clip_param=vf_clip_param,
                vf_loss_coeff=vf_loss_coeff,
                clip_loss=True,
            )
            ppo_obj.append(ppo_loss)
            surr_losses.append(surr_loss)
            kl_losses.append(kl_loss)
            vf_losses.append(val_loss)
            entropies.append(entropy_loss)

        # Average over tasks so the "mean_*" stats describe every task, not
        # just whichever task happened to be processed last.
        self.mean_policy_loss = tf.reduce_mean(tf.stack(surr_losses))
        self.mean_kl = tf.reduce_mean(tf.stack(kl_losses))
        self.mean_vf_loss = tf.reduce_mean(tf.stack(vf_losses))
        self.mean_entropy = tf.reduce_mean(tf.stack(entropies))
        self.inner_kl_loss = tf.reduce_mean(
            tf.multiply(self.cur_kl_coeff, mean_inner_kl)
        )
        self.loss = tf.reduce_mean(tf.stack(ppo_obj, axis=0)) + self.inner_kl_loss
        self.loss = tf1.Print(
            self.loss, ["Meta-Loss", self.loss, "Inner KL", self.mean_inner_kl]
        )

    def feed_forward(self, obs, policy_vars, policy_config):
        """Re-run the policy and value MLPs as pure functions of ``policy_vars``.

        Variables are routed as follows:

        * names containing ``"value"`` form the value net;
        * a variable containing ``"log_std"`` is appended, broadcast over the
          batch, to the policy output;
        * all remaining variables form the policy net.

        Hidden layers use ``policy_config["fcnet_activation"]``; layers whose
        bias name contains ``"out"`` use the identity.

        NOTE(review): assumes ``policy_vars`` iterates in layer order, with
        each kernel before its bias (insertion order of the trainable
        collection). TODO confirm for other model layouts.

        Returns:
            ``(pi_logits, value)`` with ``value`` reshaped to 1-D.

        Raises:
            NameError: If a network variable is neither a kernel nor a bias.
        """
        # Hacky for now: only fully-connected networks are reconstructed.
        # TODO(@mluo): support arbitrary network architectures.
        def fc_network(inp, network_vars, hidden_nonlinearity, output_nonlinearity):
            x = inp
            for name, param in network_vars.items():
                if "kernel" in name:
                    x = tf.matmul(x, param)
                elif "bias" in name:
                    x = tf.add(x, param)
                    # A bias closes a layer: apply that layer's activation.
                    if "out" in name:
                        x = output_nonlinearity(x)
                    else:
                        x = hidden_nonlinearity(x)
                else:
                    raise NameError(
                        "Cannot rebuild FC network: variable {!r} is neither a "
                        "kernel nor a bias.".format(name)
                    )
            return x

        policyn_vars = {}
        valuen_vars = {}
        log_std = None
        for name, param in policy_vars.items():
            if "value" in name:
                valuen_vars[name] = param
            elif "log_std" in name:
                log_std = param
            else:
                policyn_vars[name] = param

        output_nonlinearity = tf.identity
        hidden_nonlinearity = get_activation_fn(policy_config["fcnet_activation"])

        pi_new_logits = fc_network(
            obs, policyn_vars, hidden_nonlinearity, output_nonlinearity
        )
        if log_std is not None:
            # `0.0 * logits + log_std` broadcasts log_std to the batch size.
            pi_new_logits = tf.concat([pi_new_logits, 0.0 * pi_new_logits + log_std], 1)
        value_fn = fc_network(
            obs, valuen_vars, hidden_nonlinearity, output_nonlinearity
        )

        return pi_new_logits, tf.reshape(value_fn, [-1])

    def compute_updated_variables(self, loss, network_vars):
        """Apply one SGD step ``var - config["inner_lr"] * d(loss)/d(var)``.

        Variables for which ``tf.gradients`` returns ``None`` (no dependency
        on ``loss``) are passed through unchanged.

        Returns:
            A new dict with the same keys, in the same order, as ``network_vars``.
        """
        grads = tf.gradients(loss, list(network_vars.values()))
        adapted_vars = {}
        for (name, var), grad in zip(network_vars.items(), grads):
            if grad is None:
                adapted_vars[name] = var
            else:
                adapted_vars[name] = var - self.config["inner_lr"] * grad
        return adapted_vars

    def split_placeholders(self, placeholder, split):
        """Cut ``placeholder`` along axis 0 using the 2-D size tensor ``split``.

        The split happens in two stages:

        1. Into one chunk per row of ``split``, each sized by that row's sum.
        2. Each chunk ``i`` into pieces sized by ``split[i]``.

        Returns:
            A nested list indexed ``[row_of_split][column_of_split]``.
        """
        inner_placeholder_list = tf.split(
            placeholder, tf.math.reduce_sum(split, axis=1), axis=0
        )
        placeholder_list = []
        for index, split_placeholder in enumerate(inner_placeholder_list):
            placeholder_list.append(tf.split(split_placeholder, split[index], axis=0))
        return placeholder_list


def maml_loss(policy, model, dist_class, train_batch):
    """Build the MAML loss graph for ``policy`` and return its loss tensor.

    Policies with a non-zero ``worker_index`` (rollout workers) get a
    ``WorkerLoss`` for inner adaptation. The policy with ``worker_index == 0``
    gets the ``MAMLLoss`` meta-objective over ``num_workers`` tasks. The loss
    object is stored on ``policy.loss_obj`` so ``maml_stats`` can read it.
    """
    logits, _ = model(train_batch)
    policy.cur_lr = policy.config["lr"]
    cfg = policy.config

    # Inputs consumed identically by both loss variants.
    common = dict(
        dist_class=dist_class,
        actions=train_batch[SampleBatch.ACTIONS],
        behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
        advantages=train_batch[Postprocessing.ADVANTAGES],
        value_targets=train_batch[Postprocessing.VALUE_TARGETS],
        vf_preds=train_batch[SampleBatch.VF_PREDS],
        entropy_coeff=cfg["entropy_coeff"],
        clip_param=cfg["clip_param"],
        vf_clip_param=cfg["vf_clip_param"],
        vf_loss_coeff=cfg["vf_loss_coeff"],
    )

    if not cfg["worker_index"]:
        # Meta-update: the loss is rebuilt functionally from all trainable
        # variables in the current variable scope.
        policy.var_list = tf1.get_collection(
            tf1.GraphKeys.TRAINABLE_VARIABLES, tf1.get_variable_scope().name
        )
        policy.loss_obj = MAMLLoss(
            model=model,
            config=cfg,
            cur_kl_coeff=policy.kl_coeff,
            policy_vars=policy.var_list,
            obs=train_batch[SampleBatch.CUR_OBS],
            num_tasks=cfg["num_workers"],
            split=train_batch["split"],
            inner_adaptation_steps=cfg["inner_adaptation_steps"],
            use_gae=cfg["use_gae"],
            **common,
        )
    else:
        policy.loss_obj = WorkerLoss(
            curr_logits=logits,
            value_fn=model.value_function(),
            cur_kl_coeff=0.0,
            clip_loss=False,
            **common,
        )

    return policy.loss_obj.loss


def maml_stats(policy, train_batch):
    """Return the tensors to report as learner statistics.

    Workers (non-zero ``worker_index``) report only their adaptation loss.
    The meta-learner reports the KL coefficient, learning rate, and the
    components of the meta-loss stored on ``policy.loss_obj``.
    """
    loss_obj = policy.loss_obj
    if policy.config["worker_index"]:
        return {"worker_loss": loss_obj.loss}

    return {
        "cur_kl_coeff": tf.cast(policy.kl_coeff, tf.float64),
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": loss_obj.loss,
        "policy_loss": loss_obj.mean_policy_loss,
        "vf_loss": loss_obj.mean_vf_loss,
        "kl": loss_obj.mean_kl,
        "inner_kl": loss_obj.mean_inner_kl,
        "entropy": loss_obj.mean_entropy,
    }


class KLCoeffMixin:
    """Adaptive KL coefficients, one per inner adaptation step.

    The current values are kept both as a Python list (``kl_coeff_val``) and
    as a non-trainable float32 variable ``kl_coeff`` of shape
    ``(inner_adaptation_steps,)`` that the loss graph reads.
    """

    def __init__(self, config):
        num_steps = config["inner_adaptation_steps"]
        self.kl_coeff_val = [config["kl_coeff"]] * num_steps
        # Read from the passed config (like every other key here) instead of
        # relying on `self.config` having been set beforehand.
        self.kl_target = config["kl_target"]
        self.kl_coeff = tf1.get_variable(
            initializer=tf.keras.initializers.Constant(self.kl_coeff_val),
            name="kl_coeff",
            shape=(num_steps,),
            trainable=False,
            dtype=tf.float32,
        )

    def update_kls(self, sampled_kls):
        """Adapt each step's KL coefficient towards ``kl_target``.

        A coefficient is halved when its sampled KL is below
        ``kl_target / 1.5`` and doubled when it is above ``1.5 * kl_target``.
        Otherwise it is left unchanged.

        Args:
            sampled_kls: Observed KL values, one per inner adaptation step.

        Returns:
            The updated coefficient list, which is also loaded into the
            ``kl_coeff`` variable via ``self.get_session()``.
        """
        for i, kl in enumerate(sampled_kls):
            if kl < self.kl_target / 1.5:
                self.kl_coeff_val[i] *= 0.5
            elif kl > 1.5 * self.kl_target:
                self.kl_coeff_val[i] *= 2.0
        # Use the module logger instead of unconditionally printing to stdout.
        logger.debug("Updated inner KL coefficients: %s", self.kl_coeff_val)
        self.kl_coeff.load(self.kl_coeff_val, session=self.get_session())
        return self.kl_coeff_val


def maml_optimizer_fn(policy, config):
    """Select the optimizer according to the policy's role.

    Rollout workers (non-zero ``worker_index``) use plain gradient descent
    with ``config["inner_lr"]`` for the inner adaptation. The meta-learner
    (``worker_index == 0``) uses Adam with ``config["lr"]`` for the
    meta-update.
    """
    if config["worker_index"]:
        return tf1.train.GradientDescentOptimizer(learning_rate=config["inner_lr"])
    return tf1.train.AdamOptimizer(learning_rate=config["lr"])


def setup_mixins(policy, obs_space, action_space, config):
    """Initialise the policy mixins and register the ``split`` input.

    ``split`` is an int32 placeholder with one row per adaptation split
    (``inner_adaptation_steps + 1``) and one column per worker. ``MAMLLoss``
    uses its entries as the row counts for cutting the batch along axis 0.
    """
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)

    num_splits = policy.config["inner_adaptation_steps"] + 1
    num_tasks = policy.config["num_workers"]
    policy._loss_input_dict["split"] = tf1.placeholder(
        tf.int32,
        shape=(num_splits, num_tasks),
        name="Meta-Update-Splitting",
    )


# Assemble the TF policy class from the MAML loss/stats/optimizer functions
# above plus helpers reused from the PPO TF policy and RLlib postprocessing.
MAMLTFPolicy = build_tf_policy(
    name="MAMLTFPolicy",
    # Deferred lookup, presumably to avoid a circular import with maml.py.
    get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
    mixins=[KLCoeffMixin],
    # Construction hooks.
    before_init=setup_config,
    before_loss_init=setup_mixins,
    # Loss, statistics and optimisation.
    loss_fn=maml_loss,
    stats_fn=maml_stats,
    optimizer_fn=maml_optimizer_fn,
    compute_gradients_fn=compute_and_clip_gradients,
    # Sampling-side hooks.
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=compute_gae_for_sample_batch,
)