# Source: ray/rllib/algorithms/a2c/tests/test_a2c.py (74 lines, 2.1 KiB, Python)

import unittest
import ray
import ray.rllib.algorithms.a2c as a2c
from ray.rllib.utils.test_utils import (
check_compute_single_action,
check_train_results,
framework_iterator,
)
class TestA2C(unittest.TestCase):
    """Sanity tests for A2C exec impl."""

    def setUp(self):
        """Initialize Ray (with ``num_cpus=4``) before every test."""
        ray.init(num_cpus=4)

    def tearDown(self):
        """Shut Ray down after every test."""
        ray.shutdown()

    def _build_train_and_check(self, config, num_iterations=1, **build_kwargs):
        """Build a trainer from ``config``, train it, and run the checks.

        Args:
            config: The A2C config object to call ``build()`` on.
            num_iterations: How many times to call ``trainer.train()``.
            **build_kwargs: Forwarded verbatim to ``config.build()``
                (e.g. ``env=...``).
        """
        trainer = config.build(**build_kwargs)
        # Always stop the trainer, even when training or one of the checks
        # raises, so a failing assertion does not leak the trainer.
        try:
            for _ in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(trainer)
        finally:
            trainer.stop()

    def test_a2c_compilation(self):
        """Test whether an A2C can be built with both frameworks."""
        config = a2c.A2CConfig().rollouts(num_rollout_workers=2, num_envs_per_worker=2)

        num_iterations = 1

        # Test against all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            for env in ["CartPole-v0", "Pendulum-v1", "PongDeterministic-v0"]:
                self._build_train_and_check(
                    config, num_iterations=num_iterations, env=env
                )

    def test_a2c_exec_impl(self):
        """Run one training iteration of A2C on CartPole-v0 per framework."""
        config = (
            a2c.A2CConfig()
            .environment(env="CartPole-v0")
            # NOTE(review): presumably removes the minimum wall-clock time per
            # iteration so the test finishes quickly -- confirm against RLlib.
            .reporting(min_time_s_per_iteration=0)
        )
        for _ in framework_iterator(config):
            self._build_train_and_check(config)

    def test_a2c_exec_impl_microbatch(self):
        """Same as ``test_a2c_exec_impl`` but with ``microbatch_size=10``."""
        config = (
            a2c.A2CConfig()
            .environment(env="CartPole-v0")
            .reporting(min_time_s_per_iteration=0)
            .training(microbatch_size=10)
        )
        for _ in framework_iterator(config):
            self._build_train_and_check(config)
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode and
    # uses pytest's return value as the process exit status.
    import sys

    import pytest

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)