import unittest

import ray
import ray.rllib.agents.a3c as a3c
import ray.rllib.agents.dqn as dqn
from ray.rllib.utils.test_utils import framework_iterator
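
# Tests for RLlib's Trainer evaluation path: periodic evaluation during
# training (via `evaluation_interval`) and on-demand `Trainer.evaluate()`
# calls.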


class TestTrainerEvaluation(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_evaluation_option(self):
        config = dqn.DEFAULT_CONFIG.copy()
        config.update({
            "env": "CartPole-v0",
            "evaluation_interval": 2,
            "evaluation_num_episodes": 2,
            "evaluation_config": {
                "gamma": 0.98,
            },
        })
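
        # `framework_iterator` sets config["framework"] to each entry in
        # `frameworks` in turn, so the loop body runs the same checks once
        # under TF and once under PyTorch.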
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            agent = dqn.DQNTrainer(config=config)
            # Given evaluation_interval=2, r0 and r2 should not contain
            # evaluation metrics, while r1 and r3 should.
            r0 = agent.train()
            print(r0)
            r1 = agent.train()
            print(r1)
            r2 = agent.train()
            print(r2)
            r3 = agent.train()
            print(r3)
            agent.stop()
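
            # On evaluation iterations, the train() result carries an extra
            # "evaluation" dict holding the evaluation-rollout metrics.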
            self.assertFalse("evaluation" in r0)
            self.assertTrue("evaluation" in r1)
            self.assertFalse("evaluation" in r2)
            self.assertTrue("evaluation" in r3)
            self.assertTrue("episode_reward_mean" in r1["evaluation"])
            self.assertNotEqual(r1["evaluation"], r3["evaluation"])
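
    # Without an evaluation worker set, Trainer.evaluate() should raise;
    # creating the env on the driver instead makes evaluation possible.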
    def test_evaluation_wo_evaluation_worker_set(self):
        config = a3c.DEFAULT_CONFIG.copy()
        config.update({
            "env": "CartPole-v0",
            # Switch off evaluation (this should already be the default).
            "evaluation_interval": None,
        })
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            # Set up a Trainer w/o an evaluation worker set, then call
            # evaluate() -> expect an error.
            agent_wo_env_on_driver = a3c.A3CTrainer(config=config)
            self.assertRaisesRegex(
                ValueError, "Cannot evaluate w/o an evaluation worker set",
                agent_wo_env_on_driver.evaluate)
            agent_wo_env_on_driver.stop()

            # Try again using `create_env_on_driver=True`.
            # This force-adds the env on the local worker, so the Trainer
            # can `evaluate` even though it doesn't have an evaluation
            # worker set.
            config["create_env_on_driver"] = True
            agent_w_env_on_driver = a3c.A3CTrainer(config=config)
            results = agent_w_env_on_driver.evaluate()
            assert "evaluation" in results
            assert "episode_reward_mean" in results["evaluation"]
            agent_w_env_on_driver.stop()
            config["create_env_on_driver"] = False


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))