import unittest

import ray
import ray.rllib.algorithms.a2c as a2c
from ray.rllib.utils.test_utils import (
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)


class TestA2C(unittest.TestCase):
    """Sanity tests for A2C exec impl."""

    def setUp(self):
        """Initialize Ray with 4 CPUs before each test."""
        ray.init(num_cpus=4)

    def tearDown(self):
        """Shut Ray down after each test."""
        ray.shutdown()

    def _train_and_check(self, trainer, num_iterations=1):
        """Train `trainer`, validate its results, and always stop it.

        Args:
            trainer: A built A2C trainer (as returned by `config.build()`).
            num_iterations: How many times to call `trainer.train()`.

        Each iteration's results are passed to `check_train_results` and
        printed; afterwards `check_compute_single_action` is run on the
        trainer.
        """
        try:
            for _ in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(trainer)
        finally:
            # Stop the trainer even when a check above fails, so a failing
            # iteration does not leave it running while the remaining
            # cleanup (or later loop iterations' setup) proceeds.
            trainer.stop()

    def test_a2c_compilation(self):
        """Test whether an A2C can be built with all frameworks."""
        config = a2c.A2CConfig().rollouts(num_rollout_workers=2, num_envs_per_worker=2)

        num_iterations = 1

        # Test against all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            for env in ["CartPole-v0", "Pendulum-v1", "PongDeterministic-v0"]:
                self._train_and_check(config.build(env=env), num_iterations)

    def test_a2c_exec_impl(self):
        """Run one training iteration on CartPole-v0 per framework.

        The config sets `min_time_s_per_iteration=0`.
        """
        config = (
            a2c.A2CConfig()
            .environment(env="CartPole-v0")
            .reporting(min_time_s_per_iteration=0)
        )

        for _ in framework_iterator(config):
            self._train_and_check(config.build())

    def test_a2c_exec_impl_microbatch(self):
        """Same as `test_a2c_exec_impl`, but with `microbatch_size=10`."""
        config = (
            a2c.A2CConfig()
            .environment(env="CartPole-v0")
            .reporting(min_time_s_per_iteration=0)
            .training(microbatch_size=10)
        )

        for _ in framework_iterator(config):
            self._train_and_check(config.build())


if __name__ == "__main__":
    # When executed directly, hand this file to pytest in verbose mode and
    # exit the process with pytest's return code.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)