from typing import Type

from ray.rllib.agents.sac import SACTrainer, DEFAULT_CONFIG as SAC_DEFAULT_CONFIG
from ray.rllib.agents.sac.rnnsac_torch_policy import RNNSACTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

DEFAULT_CONFIG = SACTrainer.merge_trainer_configs(
    SAC_DEFAULT_CONFIG,
    {
        # Batch mode (see common config)
        "batch_mode": "complete_episodes",
        # If True, assume a zero-initialized state input (no matter where in
        # the episode the sequence is located).
        # If False, store the initial states along with each SampleBatch, use
        # it (as initial state when running through the network for training),
        # and update that initial state during training (from the internal
        # state outputs of the immediately preceding sequence).
        "zero_init_states": True,
        "replay_buffer_config": {
            # If > 0, use the first `replay_burn_in` steps of each
            # replay-sampled sequence (starting either from all 0.0-values
            # if `zero_init_states=True` or from the already stored values)
            # to calculate an even more accurate initial state for the
            # actual sequence (starting after this burn-in window). In the
            # burn-in case, the actual length of the sequence used for loss
            # calculation is `n - replay_burn_in` time steps
            # (n=LSTM's/attention net's max_seq_len).
            "replay_burn_in": 0,
            # Set automatically: The number of contiguous environment steps
            # to replay at once. Calculated as the model's `max_seq_len` +
            # `replay_burn_in`.
            # Do not set this manually; leave it at its default of -1.
            "replay_sequence_length": -1,
        },
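        # Deprecated: Use `replay_buffer_config["replay_burn_in"]` instead.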
        "burn_in": DEPRECATED_VALUE,
    },
    _allow_unknown_configs=True,
)
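
# Usage sketch (hedged; the values below are illustrative placeholders, not
# defaults of this module). Users should only set `zero_init_states` and
# `replay_buffer_config["replay_burn_in"]`; `replay_sequence_length` must be
# left at -1 and is derived in `RNNSACTrainer.validate_config()` below:
#
#     import copy
#
#     config = copy.deepcopy(DEFAULT_CONFIG)
#     config["framework"] = "torch"  # This trainer is torch-only.
#     config["model"]["max_seq_len"] = 20
#     config["zero_init_states"] = False
#     config["replay_buffer_config"]["replay_burn_in"] = 4
#     # validate_config() will then set
#     # config["replay_buffer_config"]["replay_sequence_length"] to 24.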


class RNNSACTrainer(SACTrainer):
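    """SAC Trainer for recurrent (RNN/attention) models. Torch-only.

    Samples and replays whole sequences (optionally preceded by a burn-in
    window) so the model's internal state can be recomputed during training.
    """
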
    @classmethod
    @override(SACTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(SACTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        # Add the replay buffer's `replay_burn_in` to the model's
        # `max_seq_len` to get the final replay sequence length.
        replay_sequence_length = (
            config["replay_buffer_config"]["replay_burn_in"]
            + config["model"]["max_seq_len"]
        )
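        # E.g. (illustrative numbers): max_seq_len=20 and replay_burn_in=4
        # -> each replayed sequence is 24 timesteps long, the first 4 of
        # which are only used to compute a better initial internal state.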
        # Check whether the user tried to set `replay_sequence_length`
        # (to anything other than the auto-computed value).
        if config["replay_buffer_config"]["replay_sequence_length"] not in [
            -1,
            replay_sequence_length,
        ]:
            raise ValueError(
                "`replay_sequence_length` is calculated automatically to be "
                "config['model']['max_seq_len'] + "
                "config['replay_buffer_config']['replay_burn_in']. Leave "
                "config['replay_buffer_config']['replay_sequence_length'] "
                "at its default of -1 to avoid this error."
            )
        # Set `replay_sequence_length` to the auto-computed value.
        config["replay_buffer_config"][
            "replay_sequence_length"
        ] = replay_sequence_length

        if config["framework"] != "torch":
            raise ValueError(
                "Only `framework=torch` supported so far for RNNSACTrainer!"
            )

    @override(SACTrainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        return RNNSACTorchPolicy
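

# Minimal end-to-end sketch (hedged: assumes `ray` is installed and that
# "CartPole-v1" is available via gym; `policy_model`/`Q_model` are the
# pre-2.0 SAC model config keys used here to enable an LSTM; adjust for
# your RLlib version):
#
#     import ray
#
#     ray.init()
#     trainer = RNNSACTrainer(
#         env="CartPole-v1",
#         config={
#             "framework": "torch",
#             "model": {"max_seq_len": 20},
#             "policy_model": {"use_lstm": True, "lstm_cell_size": 64},
#             "Q_model": {"use_lstm": True, "lstm_cell_size": 64},
#         },
#     )
#     for _ in range(3):
#         print(trainer.train()["episode_reward_mean"])
#     trainer.stop()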