import argparse
from gym.spaces import Box, Discrete
import numpy as np

from ray.rllib.examples.models.custom_model_api import (
    DuelingQModel,
    TorchDuelingQModel,
    ContActionQModel,
    TorchContActionQModel,
)
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch

# Command line interface: one option that selects the deep learning framework
# the example models are built with.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    default="tf",
    choices=["tf", "tf2", "tfe", "torch"],
    help="The DL framework specifier.",
)

# Obtain the TensorFlow and PyTorch modules through RLlib's `try_import_*`
# helpers rather than importing them directly.
# NOTE(review): presumably these yield None placeholders when a framework is
# not installed -- confirm against the helpers' documentation.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

if __name__ == "__main__":
    # Script entry point: builds two catalog models, each wrapped with a
    # custom model API class, runs a random batch of observations through
    # them and checks the shapes of the resulting tensors.
    args = parser.parse_args()

    # Run in eager mode for value checking and debugging.
    # Only touch TensorFlow if it is actually available: `tf1` is expected to
    # be None when TF could not be imported (e.g. torch-only installs), and
    # calling a method on it would crash even for `--framework=torch`.
    if tf1 is not None:
        tf1.enable_eager_execution()

    # Number of observations (and actions) fed through each model below.
    batch_size = 10

    # Test API wrapper for dueling Q-head.

    obs_space = Box(-1.0, 1.0, (3,))
    action_space = Discrete(3)

    # __sphinx_doc_model_construct_1_begin__
    my_dueling_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (DuelingQModel). This way, both `forward` and `get_q_values`
        # are available in the returned class.
        model_interface=DuelingQModel
        if args.framework != "torch"
        else TorchDuelingQModel,
        name="dueling_q_model",
    )
    # __sphinx_doc_model_construct_1_end__

    input_ = np.array([obs_space.sample() for _ in range(batch_size)])
    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_dueling_model(input_dict=input_dict)
    # NOTE(review): 256 is presumably the size of the last hidden layer
    # configured in MODEL_DEFAULTS -- confirm if the defaults change.
    assert out.shape == (batch_size, 256)
    # Pass `out` into `get_q_values`
    q_values = my_dueling_model.get_q_values(out)
    assert q_values.shape == (batch_size, action_space.n)

    # Test API wrapper for single value Q-head from obs/action input.

    obs_space = Box(-1.0, 1.0, (3,))
    # Use a proper [-1.0, 1.0] range; with low == high every sampled action
    # would be the same constant vector.
    action_space = Box(-1.0, 1.0, (2,))

    # __sphinx_doc_model_construct_2_begin__
    my_cont_action_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=2,
        model_config=MODEL_DEFAULTS,
        framework=args.framework,
        # Providing the `model_interface` arg will make the factory
        # wrap the chosen default model with our new model API class
        # (ContActionQModel). This way, both `forward` and
        # `get_single_q_value` are available in the returned class.
        model_interface=ContActionQModel
        if args.framework != "torch"
        else TorchContActionQModel,
        name="cont_action_q_model",
    )
    # __sphinx_doc_model_construct_2_end__

    input_ = np.array([obs_space.sample() for _ in range(batch_size)])

    # Note that for PyTorch, you will have to provide torch tensors here.
    if args.framework == "torch":
        input_ = torch.from_numpy(input_)

    input_dict = SampleBatch(obs=input_, _is_training=False)
    out, state_outs = my_cont_action_q_model(input_dict=input_dict)
    assert out.shape == (batch_size, 256)
    # Pass `out` and an action into `my_cont_action_q_model`
    action = np.array([action_space.sample() for _ in range(batch_size)])
    # The sampled actions need the same numpy -> torch conversion as the
    # observations above.
    if args.framework == "torch":
        action = torch.from_numpy(action)

    q_value = my_cont_action_q_model.get_single_q_value(out, action)
    assert q_value.shape == (batch_size, 1)