Mirror of https://github.com/vale981/ray, synced 2025-03-09 12:56:46 -04:00

* Get shared metrics, increment counter & set global vars for remote workers.
* Add unit test to test lr_schedule for DDPPO.
* Broadcast the local set of global vars to remote workers instead of independently setting the global vars on each rollout worker (sketched below).
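The broadcast described in the first and third bullets can be pictured with plain Ray actors. The following is a minimal sketch under stated assumptions: `RolloutWorkerStub` and its methods are hypothetical stand-ins for an RLlib rollout worker, not the library's actual API; only `ray.remote`, `.remote()` calls, and `ray.get` are real Ray primitives.

import ray

ray.init()


@ray.remote
class RolloutWorkerStub:
    """Hypothetical stand-in for an RLlib rollout worker."""

    def __init__(self):
        self.global_vars = {"timestep": 0}

    def set_global_vars(self, global_vars):
        # Overwrite local state with the driver's authoritative copy.
        self.global_vars = dict(global_vars)

    def get_timestep(self):
        return self.global_vars["timestep"]


# Driver side: advance the shared counter once, then broadcast the
# resulting global vars to every remote worker, rather than letting
# each worker advance its own counter independently.
workers = [RolloutWorkerStub.remote() for _ in range(4)]
global_vars = {"timestep": 1000}
ray.get([w.set_global_vars.remote(global_vars) for w in workers])
assert ray.get([w.get_timestep.remote() for w in workers]) == [1000] * 4
ray.shutdown()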
50 lines · 1.6 KiB · Python
import unittest

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.utils.test_utils import check_compute_single_action, \
    framework_iterator


class TestDDPPO(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_ddppo_compilation(self):
        """Test whether a DDPPOTrainer can be built with the torch framework."""
        config = ppo.ddppo.DEFAULT_CONFIG.copy()
        # Run on CPU only so the test also passes on machines without GPUs.
        config["num_gpus_per_worker"] = 0
        num_iterations = 2

        # DD-PPO only supports torch.
        for _ in framework_iterator(config, "torch"):
            trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
            for _ in range(num_iterations):
                trainer.train()
            check_compute_single_action(trainer)
            trainer.stop()

    def test_ddppo_schedule(self):
        """Test whether lr_schedule will anneal lr to 0."""
        config = ppo.ddppo.DEFAULT_CONFIG.copy()
        config["num_gpus_per_worker"] = 0
        # Anneal the learning rate linearly from its default value at
        # timestep 0 down to 0.0 at timestep 1000.
        config["lr_schedule"] = [[0, config["lr"]], [1000, 0.0]]
        num_iterations = 3

        for _ in framework_iterator(config, "torch"):
            trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
            for _ in range(num_iterations):
                result = trainer.train()
                lr = result["info"]["learner"]["default_policy"]["cur_lr"]
            trainer.stop()
            # After 3 iterations, well over 1000 timesteps have been
            # sampled, so the schedule must have reached its final value.
            assert lr == 0.0, "lr should anneal to 0.0"


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
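As a footnote on test_ddppo_schedule: lr_schedule is a list of [timestep, value] points, and the assertion relies on the value being held at the last point once training passes timestep 1000. Below is a minimal sketch of that piecewise-linear interpolation; piecewise_lr is a hypothetical helper, not RLlib's actual schedule class, and 0.0003 is an illustrative initial lr, not necessarily DDPPO's default.

def piecewise_lr(schedule, timestep):
    """Linearly interpolate a [[t0, v0], [t1, v1], ...] schedule."""
    t0, v0 = schedule[0]
    if timestep <= t0:
        return v0
    for t1, v1 in schedule[1:]:
        if timestep <= t1:
            # Linear interpolation between the two surrounding points.
            frac = (timestep - t0) / (t1 - t0)
            return v0 + frac * (v1 - v0)
        t0, v0 = t1, v1
    # Past the last point: hold the final value, as the test expects.
    return v0


schedule = [[0, 0.0003], [1000, 0.0]]
assert piecewise_lr(schedule, 0) == 0.0003
assert piecewise_lr(schedule, 500) == 0.00015  # halfway through the anneal
assert piecewise_lr(schedule, 5000) == 0.0     # annealed to zero and held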