ray/rllib/algorithms/crr/tests/test_crr.py

from pathlib import Path
import os
import unittest

import ray
from ray.rllib.algorithms.crr import CRRConfig
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
from ray.rllib.utils.test_utils import (
    check_compute_single_action,
    check_train_results,
)

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
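
# CRR (Critic Regularized Regression) is an offline RL algorithm: the test below
# builds it purely from a pre-recorded Pendulum-v1 JSON dataset and only touches
# the live environment for evaluation rollouts.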


class TestCRR(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()
    def test_crr_compilation(self):
        """Test whether a CRR algorithm can be built with all supported frameworks."""
        # TODO: terrible asset management style
        rllib_dir = Path(__file__).parent.parent.parent.parent
        print("rllib dir={}".format(rllib_dir))
        data_file = os.path.join(rllib_dir, "tests/data/pendulum/large.json")
        print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
        config = (
            CRRConfig()
            .environment(env="Pendulum-v1", clip_actions=True)
            .framework("torch")
            .offline_data(input_=[data_file], actions_in_input_normalized=True)
            .training(
                twin_q=True,
                train_batch_size=256,
                replay_buffer_config={
                    "type": MultiAgentReplayBuffer,
                    "learning_starts": 0,
                    "capacity": 100000,
                },
                weight_type="bin",
                advantage_type="mean",
                n_action_sample=4,
                target_update_grad_intervals=10000,
                tau=1.0,
            )
            .evaluation(
                evaluation_interval=2,
                evaluation_num_workers=2,
                evaluation_duration=10,
                evaluation_duration_unit="episodes",
                evaluation_parallel_to_training=True,
                evaluation_config={"input": "sampler", "explore": False},
            )
            .rollouts(num_rollout_workers=0)
        )
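
        # CRR-specific knobs above, roughly: `weight_type` ("bin" vs. "exp") and
        # `advantage_type` select how the advantage-weighted regression weights
        # are computed, and `n_action_sample` is the number of actions sampled
        # to estimate the advantage (see `CRRConfig.training()` for details).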

        num_iterations = 4

        for _ in ["torch"]:
            algorithm = config.build()
            # Check that 4 training iterations run without raising any errors.
            for i in range(num_iterations):
                results = algorithm.train()
                check_train_results(results)
                print(results)
                if (i + 1) % 2 == 0:
                    # Evaluation happens every 2 iterations.
                    eval_results = results["evaluation"]
                    print(
                        f"iter={algorithm.iteration} "
                        f"R={eval_results['episode_reward_mean']}"
                    )
            check_compute_single_action(algorithm)
            algorithm.stop()


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))