"""TensorFlow model for SlateQ."""

from typing import List

import gym
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

# RLlib's lazy TensorFlow import helper. It returns a 3-tuple; presumably
# (tf1.x-compatible module, tf module, installed major version) per RLlib's
# `ray.rllib.utils.framework.try_import_tf` docs — confirm.
tf1, tf, tfv = try_import_tf()


class SlateQTFModel(TFModelV2):
    """Per-candidate Q-value network for SlateQ (TensorFlow).

    For every document candidate in the observation, a separate dense stack
    maps `[user] concat [document[i]]` to one scalar Q-value. The stacks are
    combined into a single Keras model, `self.q_value_head`, which is
    evaluated via `get_q_values`.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        fcnet_hiddens_per_candidate=(256, 32),
    ):
        """Initializes a SlateQTFModel instance.

        Each document candidate receives one full Q-value stack, defined by
        `fcnet_hiddens_per_candidate`. The input to each of these Q-value stacks
        is always {[user] concat [document[i]] for i in document_candidates}.

        Args:
            obs_space: Observation space. Must support `obs_space["user"]` and
                `obs_space["doc"]`; the latter holds one sub-space per
                candidate, and its "0" entry determines the document
                embedding size.
            action_space: Action space (passed through to the base class).
            num_outputs: Stored as `self.num_outputs`. It does not size the
                Q-head, which always produces one value per candidate.
            model_config: Model config dict (passed through to the base class).
            name: Model name (passed through to the base class).
            fcnet_hiddens_per_candidate: Sequence of hidden layer sizes (ReLU
                activations) used for each(!) of the candidate documents'
                Q-value stacks.

        Raises:
            ValueError: If the user embedding size differs from the document
                embedding size.
        """
        # NOTE: None is handed to the base class as `num_outputs`; the actual
        # value is assigned to `self.num_outputs` further below.
        super().__init__(obs_space, action_space, None, model_config, name)

        # All candidate docs are assumed to share the shape of doc "0": the
        # flattened `docs_in` input below is sized embedding_size * num_candidates.
        self.embedding_size = self.obs_space["doc"]["0"].shape[0]
        self.num_candidates = len(self.obs_space["doc"])

        # `user_in` is sized by the document embedding size, so the user
        # embedding must match it. Use an explicit check rather than `assert`,
        # which is stripped when Python runs with -O.
        user_size = self.obs_space["user"].shape[0]
        if user_size != self.embedding_size:
            raise ValueError(
                f"SlateQTFModel requires the user embedding size ({user_size}) "
                f"to equal the document embedding size ({self.embedding_size})."
            )

        # Setup the Q head output (i.e., model for get_q_values)
        self.user_in = tf.keras.layers.Input(
            shape=(self.embedding_size,), name="user_in"
        )
        self.docs_in = tf.keras.layers.Input(
            shape=(self.embedding_size * self.num_candidates,), name="docs_in"
        )

        self.num_outputs = num_outputs

        q_outs = []
        for i in range(self.num_candidates):
            # Slice candidate i's embedding out of the flattened docs input.
            doc = self.docs_in[
                :, self.embedding_size * i : self.embedding_size * (i + 1)
            ]
            out = tf.keras.layers.concatenate([self.user_in, doc], axis=1)
            for h in fcnet_hiddens_per_candidate:
                out = tf.keras.layers.Dense(h, activation=tf.nn.relu)(out)
            q_value = tf.keras.layers.Dense(1, name=f"q_value_{i}")(out)
            q_outs.append(q_value)
        # tf.concat (rather than a Keras Concatenate layer) also accepts a
        # single-element list, i.e. the one-candidate case.
        q_outs = tf.concat(q_outs, axis=1)

        self.q_value_head = tf.keras.Model([self.user_in, self.docs_in], q_outs)

    def get_q_values(self, user: TensorType, docs: List[TensorType]) -> TensorType:
        """Returns Q-values, 1 for each candidate document, given user and doc tensors.

        Args:
            user: [B x u] where u=embedding of user features.
            docs: List[[B x d]] where d=embedding of doc features. Each item in the
                list represents one document candidate. Items must be ordered
                by candidate index, since they are concatenated along axis 1
                and sliced back apart in that order by `q_value_head`.

        Returns:
            Tensor of shape [batch, num candidates] holding the Q-values,
            1 Q-value per document candidate.
        """
        return self.q_value_head([user, tf.concat(docs, 1)])