mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* fix yaml bug * add ext agent * gpus * update * tuning * docs * Sun Oct 15 21:09:25 PDT 2017 * lint * update * Sun Oct 15 22:39:55 PDT 2017 * Sun Oct 15 22:40:17 PDT 2017 * Sun Oct 15 22:43:06 PDT 2017 * Sun Oct 15 22:46:06 PDT 2017 * Sun Oct 15 22:46:21 PDT 2017 * Sun Oct 15 22:48:11 PDT 2017 * Sun Oct 15 22:48:44 PDT 2017 * Sun Oct 15 22:49:23 PDT 2017 * Sun Oct 15 22:50:21 PDT 2017 * Sun Oct 15 22:53:00 PDT 2017 * Sun Oct 15 22:53:34 PDT 2017 * Sun Oct 15 22:54:33 PDT 2017 * Sun Oct 15 22:54:50 PDT 2017 * Sun Oct 15 22:55:20 PDT 2017 * Sun Oct 15 22:56:56 PDT 2017 * Sun Oct 15 22:59:03 PDT 2017 * fix * Update tune_mnist_ray.py * remove script trial * fix * reorder * fix ex * py2 support * upd * comments * comments * cleanup readme * fix trial * annotate * Update rllib.rst * init pausing * Docs, Lint * fix danglings and restore endpoint moved to trialrunner * renaming * nit * start always starts from checkpoint * smalls * nits * lint * last change
260 lines
8.6 KiB
Python
260 lines
8.6 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import unittest
|
|
import os
|
|
|
|
import ray
|
|
from ray.tune.trial import Trial, Resources
|
|
from ray.tune.trial_runner import TrialRunner
|
|
from ray.tune.config_parser import parse_to_trials
|
|
|
|
|
|
class ConfigParserTest(unittest.TestCase):
    """Checks how parse_to_trials expands an experiment spec into trials.

    Per-test notes are `#` comments rather than docstrings so that verbose
    unittest output keeps listing the test method names.
    """

    def testParseToTrials(self):
        # A plain spec with num_trials=2: both trials carry the same
        # env/alg/config and are told apart by their experiment tag.
        spec = {
            "tune-pong": {
                "env": "Pong-v0",
                "alg": "PPO",
                "num_trials": 2,
                "config": {
                    "foo": "bar"
                },
            },
        }
        parsed = parse_to_trials(spec)
        self.assertEqual(len(parsed), 2)

        first = parsed[0]
        self.assertEqual(first.env_name, "Pong-v0")
        self.assertEqual(first.config, {"foo": "bar"})
        self.assertEqual(first.alg, "PPO")
        self.assertEqual(first.experiment_tag, "0")
        self.assertEqual(first.local_dir, "/tmp/ray/tune-pong")
        self.assertEqual(parsed[1].experiment_tag, "1")

    def testEval(self):
        # An {"eval": ...} entry is expected to be replaced by the value of
        # the expression, and that value shows up in the experiment tag.
        evaluated_entry = {"eval": "2 + 2"}
        spec = {
            "tune-pong": {
                "env": "Pong-v0",
                "config": {
                    "foo": evaluated_entry,
                },
            },
        }
        parsed = parse_to_trials(spec)
        self.assertEqual(len(parsed), 1)

        only = parsed[0]
        self.assertEqual(only.config, {"foo": 4})
        self.assertEqual(only.experiment_tag, "0_foo=4")

    def testGridSearch(self):
        # Two grid-searched keys (2 x 3 values) with num_trials=6: the test
        # expects all six combinations, with "bar" varying fastest.
        spec = {
            "tune-pong": {
                "env": "Pong-v0",
                "num_trials": 6,
                "config": {
                    "bar": {
                        "grid_search": [True, False]
                    },
                    "foo": {
                        "grid_search": [1, 2, 3]
                    },
                },
            },
        }
        parsed = parse_to_trials(spec)
        self.assertEqual(len(parsed), 6)

        # Expected configs in trial order; tags are only pinned for the
        # first two trials.
        expected_configs = [
            {"bar": bar_value, "foo": foo_value}
            for foo_value in (1, 2, 3)
            for bar_value in (True, False)
        ]
        expected_tags = {
            0: "0_bar=True_foo=1",
            1: "1_bar=False_foo=1",
        }
        for index, wanted in enumerate(expected_configs):
            self.assertEqual(parsed[index].config, wanted)
            if index in expected_tags:
                self.assertEqual(
                    parsed[index].experiment_tag, expected_tags[index])

    def testGridSearchAndEval(self):
        # Grid search combined with an eval entry, but num_trials=1: only
        # the first grid combination is expected to be produced.
        spec = {
            "tune-pong": {
                "env": "Pong-v0",
                "num_trials": 1,
                "config": {
                    "qux": {
                        "eval": "2 + 2"
                    },
                    "bar": {
                        "grid_search": [True, False]
                    },
                    "foo": {
                        "grid_search": [1, 2, 3]
                    },
                },
            },
        }
        parsed = parse_to_trials(spec)
        self.assertEqual(len(parsed), 1)

        only = parsed[0]
        self.assertEqual(only.config, {"bar": True, "foo": 1, "qux": 4})
        self.assertEqual(only.experiment_tag, "0_bar=True_foo=1_qux=4")
|
|
|
|
|
|
class TrialRunnerTest(unittest.TestCase):
    """Exercises Trial status transitions and TrialRunner scheduling.

    Each test calls ray.init() itself with the resources it needs;
    tearDown calls ray.worker.cleanup() after every test.

    Per-test notes are `#` comments rather than docstrings so that verbose
    unittest output keeps listing the test method names.
    """

    def tearDown(self):
        ray.worker.cleanup()

    def testTrialStatus(self):
        # Walks one trial through PENDING -> RUNNING -> TERMINATED, then
        # checks that stop(error=True) marks it as ERROR.
        ray.init()
        trial = Trial("CartPole-v0", "__fake")
        self.assertEqual(trial.status, Trial.PENDING)
        trial.start()
        self.assertEqual(trial.status, Trial.RUNNING)
        trial.stop()
        self.assertEqual(trial.status, Trial.TERMINATED)
        trial.stop(error=True)
        self.assertEqual(trial.status, Trial.ERROR)

    def testTrialErrorOnStart(self):
        # Starting a trial with an unrecognised algorithm name must raise.
        # assertRaises is used (instead of try/except) so the test fails
        # when start() unexpectedly succeeds rather than passing silently.
        ray.init()
        trial = Trial("CartPole-v0", "asdf")
        with self.assertRaises(Exception) as caught:
            trial.start()
        self.assertIn("Unknown algorithm", str(caught.exception))

    def testResourceScheduler(self):
        # One GPU is available and each trial asks for one, so the test
        # expects the trials to run strictly one after the other.
        ray.init(num_cpus=4, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 1},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("CartPole-v0", "__fake", **kwargs),
            Trial("CartPole-v0", "__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.RUNNING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.TERMINATED)

    def testMultiStepRun(self):
        # Two GPUs are available, so both trials are expected to be running
        # from the second step on; with a 5-iteration stopping criterion
        # neither is expected to finish within the four steps taken here.
        ray.init(num_cpus=4, num_gpus=2)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 5},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("CartPole-v0", "__fake", **kwargs),
            Trial("CartPole-v0", "__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.RUNNING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.RUNNING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(trials[1].status, Trial.RUNNING)

    def testErrorHandling(self):
        # The first trial uses an unrecognised algorithm name: the runner
        # is expected to mark it ERROR and still go on to start the second.
        ray.init(num_cpus=4, num_gpus=2)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 1},
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("CartPole-v0", "asdf", **kwargs),
            Trial("CartPole-v0", "__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

        runner.step()
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.ERROR)
        self.assertEqual(trials[1].status, Trial.RUNNING)

    def testCheckpointing(self):
        # Stores a value on the first trial's agent, checkpoints it, and
        # checks that a second trial created with restore_path set to that
        # checkpoint reports the same value.
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 1},
            "resources": Resources(cpu=1, gpu=1),
        }
        runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1)

        path = trials[0].checkpoint()
        # Register the removal right away so the checkpoint file is deleted
        # even if one of the assertions below fails.
        self.addCleanup(os.remove, path)
        kwargs["restore_path"] = path

        runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.PENDING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[1].agent.get_info.remote()), 1)

    def testPauseThenResume(self):
        # Checks that a value stored on the agent before pause() is still
        # there after resume(), and that the trial then runs to completion.
        ray.init(num_cpus=1, num_gpus=1)
        runner = TrialRunner()
        kwargs = {
            "stopping_criterion": {"training_iteration": 2},
            "resources": Resources(cpu=1, gpu=1),
        }
        runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        # Nothing has been stored on the agent yet.
        self.assertIsNone(ray.get(trials[0].agent.get_info.remote()))

        self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1)

        trials[0].pause()
        self.assertEqual(trials[0].status, Trial.PAUSED)

        trials[0].resume()
        self.assertEqual(trials[0].status, Trial.RUNNING)

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[0].agent.get_info.remote()), 1)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
|
|
|
|
|
|
# When executed as a script, run every test case in this module and print
# one result line per test (verbosity=2).
if __name__ == "__main__":
    unittest.main(verbosity=2)
|