# ray/rllib/algorithms/alpha_star/tests/test_alpha_star.py
#
# (Original listing metadata: 68 lines, 2 KiB, Python.)

import pyspiel
import unittest
import ray
import ray.rllib.algorithms.alpha_star as alpha_star
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.utils.test_utils import (
check_compute_single_action,
check_train_results,
framework_iterator,
)
from ray.tune import register_env
# Connect-4 OpenSpiel env, registered under the name "connect_four" so that
# configs can select it by string. The creator ignores its single argument
# (presumably the env config).
register_env(
    "connect_four",
    lambda _env_config: OpenSpielEnv(pyspiel.load_game("connect_four")),
)
class TestAlphaStar(unittest.TestCase):
    """Smoke tests for AlphaStar on the registered "connect_four" env."""

    @classmethod
    def setUpClass(cls):
        # One local Ray instance, shared by every test in this class.
        ray.init(num_cpus=20)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_alpha_star_compilation(self):
        """Test whether AlphaStar can be built with all frameworks."""
        config = (
            alpha_star.AlphaStarConfig()
            .environment(env="connect_four")
            .training(
                gamma=1.0,
                model={"fcnet_hiddens": [256, 256, 256]},
                vf_loss_coeff=0.01,
                entropy_coeff=0.004,
                league_builder_config={
                    "win_rate_threshold_for_new_snapshot": 0.8,
                    "num_random_policies": 2,
                    "num_learning_league_exploiters": 1,
                    "num_learning_main_exploiters": 1,
                },
                grad_clip=10.0,
                replay_buffer_capacity=10,
                replay_buffer_replay_ratio=0.0,
                use_kl_loss=True,
            )
            .rollouts(num_rollout_workers=4, num_envs_per_worker=5)
            # NOTE(review): `_fake_gpus=True` presumably lets the 4 requested
            # GPUs be simulated on a CPU-only machine -- confirm in RLlib docs.
            .resources(num_gpus=4, _fake_gpus=True)
        )

        num_iterations = 2

        # The loop value is unused; `framework_iterator` presumably switches
        # `config` to a different framework on each pass, so the algorithm is
        # rebuilt from `config` every time.
        for _ in framework_iterator(config, with_eager_tracing=True):
            trainer = config.build()
            # Always stop the algorithm, even when training or a check fails,
            # so its rollout workers don't linger into the next iteration.
            try:
                for _ in range(num_iterations):
                    results = trainer.train()
                    print(results)
                    check_train_results(results)
                check_compute_single_action(trainer)
            finally:
                trainer.stop()
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests verbosely; the process exits with pytest's code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)