mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
70 lines
2.4 KiB
Python
70 lines
2.4 KiB
Python
import unittest
|
|
|
|
import ray
|
|
from ray.rllib.agents.a3c import A2CTrainer
|
|
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER
|
|
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
|
|
|
|
|
class TestDistributedExecution(unittest.TestCase):
    """General tests for the distributed execution API."""

    @classmethod
    def setUpClass(cls):
        """Start one local Ray instance (4 CPUs) shared by all tests in this class."""
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls):
        """Shut the shared Ray instance down once all tests have run."""
        ray.shutdown()

    def test_exec_plan_stats(self):
        """Check that a single train() result exposes the expected stats keys.

        Runs once per framework yielded by ``framework_iterator`` and verifies
        that the result dict contains learner info, the sampled/trained step
        counters, and the learn/sample/update timer entries.
        """
        for fw in framework_iterator(frameworks=("torch", "tf")):
            trainer = A2CTrainer(
                env="CartPole-v0",
                config={
                    "min_time_s_per_reporting": 0,
                    "framework": fw,
                    # NOTE(review): presumably keeps the trainer on the
                    # execution-plan code path this test targets - confirm.
                    "_disable_execution_plan_api": False,
                },
            )
            try:
                result = trainer.train()
            finally:
                # Release the trainer's resources before the next framework
                # iteration instead of leaking them until ray.shutdown().
                trainer.stop()

            assert isinstance(result, dict)

            assert "info" in result
            for key in (LEARNER_INFO, STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER):
                assert key in result["info"], key

            assert "timers" in result
            for key in (
                "learn_time_ms",
                "learn_throughput",
                "sample_time_ms",
                "sample_throughput",
                "update_time_ms",
            ):
                assert key in result["timers"], key

    def test_exec_plan_save_restore(self):
        """Check that restoring a checkpoint rewinds the timesteps counter.

        A checkpoint is saved after the first train() call, two more
        iterations are run, and then the checkpoint is restored. One further
        train() call must report fewer total timesteps than were reached
        before the restore.
        """
        for fw in framework_iterator(frameworks=("torch", "tf")):
            trainer = A2CTrainer(
                env="CartPole-v0",
                config={
                    "min_time_s_per_reporting": 0,
                    "framework": fw,
                    # NOTE(review): presumably keeps the trainer on the
                    # execution-plan code path this test targets - confirm.
                    "_disable_execution_plan_api": False,
                },
            )
            try:
                res1 = trainer.train()
                checkpoint = trainer.save()
                for _ in range(2):
                    res2 = trainer.train()
                assert res2["timesteps_total"] > res1["timesteps_total"], (
                    res1,
                    res2,
                )
                trainer.restore(checkpoint)

                # The checkpoint was taken right after res1, so a single
                # train() call after restoring is expected to report fewer
                # total timesteps than res2 (which had two extra iterations).
                res3 = trainer.train()
                assert res3["timesteps_total"] < res2["timesteps_total"], (
                    res2,
                    res3,
                )
            finally:
                # Release the trainer's resources even if an assertion fails.
                trainer.stop()
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: run this file's tests verbosely through pytest
    # and propagate pytest's return value as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|