From 316f9e2bb7602fae4116d278e1fa1a32c92fc050 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 20 Nov 2017 17:52:43 -0800 Subject: [PATCH] [tune] Support user-defined trainable functions / classes / envs with a shared object registry (#1226) --- doc/source/conf.py | 9 + doc/source/example-a3c.rst | 2 +- doc/source/example-evolution-strategies.rst | 4 +- doc/source/example-policy-gradient.rst | 2 +- doc/source/rllib.rst | 4 +- python/ray/rllib/README.rst | 2 +- python/ray/rllib/__init__.py | 19 ++ python/ray/rllib/agent.py | 48 ++-- python/ray/rllib/dqn/dqn.py | 2 +- .../ray/rllib/test/test_checkpoint_restore.py | 4 +- python/ray/rllib/train.py | 45 ++-- .../cartpole-grid-search-example.yaml | 2 +- .../ray/rllib/tuned_examples/hopper-ppo.yaml | 2 +- .../ray/rllib/tuned_examples/humanoid-es.yaml | 2 +- .../tuned_examples/humanoid-ppo-gae.yaml | 2 +- .../rllib/tuned_examples/humanoid-ppo.yaml | 2 +- .../tuned_examples/hyperband-cartpole.yaml | 2 +- python/ray/rllib/tuned_examples/pong-a3c.yaml | 2 +- python/ray/rllib/tuned_examples/pong-dqn.yaml | 4 +- .../rllib/tuned_examples/walker2d-ppo.yaml | 2 +- python/ray/tune/README.rst | 8 +- python/ray/tune/__init__.py | 24 ++ python/ray/tune/config_parser.py | 71 ++++-- python/ray/tune/error.py | 8 + python/ray/tune/examples/tune_mnist_ray.py | 47 ++-- python/ray/tune/examples/tune_mnist_ray.yaml | 1 + python/ray/tune/registry.py | 87 +++++++ python/ray/tune/result.py | 3 + python/ray/tune/script_runner.py | 108 +++++---- python/ray/tune/trainable.py | 57 +++++ python/ray/tune/trial.py | 103 ++++---- python/ray/tune/trial_runner.py | 21 +- python/ray/tune/tune.py | 26 +- python/ray/tune/variant_generator.py | 29 ++- python/ray/tune/visual_utils.py | 4 +- test/jenkins_tests/run_multi_node_tests.sh | 24 +- test/trial_runner_test.py | 226 +++++++++++++++--- test/trial_scheduler_test.py | 30 +-- 38 files changed, 739 insertions(+), 299 deletions(-) create mode 100644 python/ray/tune/error.py create mode 100644 python/ray/tune/registry.py create mode 100644 python/ray/tune/trainable.py diff --git a/doc/source/conf.py b/doc/source/conf.py index c27242a57..9c878782c 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -19,10 +19,19 @@ import shlex # These lines added to enable Sphinx to work without installing Ray. import mock MOCK_MODULES = ["gym", + "gym.spaces", + "scipy", + "scipy.signal", "tensorflow", "tensorflow.contrib", + "tensorflow.contrib.layers", "tensorflow.contrib.slim", "tensorflow.contrib.rnn", + "tensorflow.core", + "tensorflow.core.util", + "tensorflow.python", + "tensorflow.python.client", + "tensorflow.python.util", "pyarrow", "pyarrow.plasma", "smart_open", diff --git a/doc/source/example-a3c.rst b/doc/source/example-a3c.rst index e9f0d4510..38fc9600f 100644 --- a/doc/source/example-a3c.rst +++ b/doc/source/example-a3c.rst @@ -25,7 +25,7 @@ You can run the code with .. code-block:: bash - python/ray/rllib/train.py --env=Pong-ram-v4 --alg=A3C --config='{"num_workers": N}' + python/ray/rllib/train.py --env=Pong-ram-v4 --run=A3C --config='{"num_workers": N}' Reinforcement Learning ---------------------- diff --git a/doc/source/example-evolution-strategies.rst b/doc/source/example-evolution-strategies.rst index e2e2fd113..16cdc3126 100644 --- a/doc/source/example-evolution-strategies.rst +++ b/doc/source/example-evolution-strategies.rst @@ -18,7 +18,7 @@ on the ``Humanoid-v1`` gym environment. .. 
code-block:: bash - python/ray/rllib/train.py --env=Humanoid-v1 --alg=ES + python/ray/rllib/train.py --env=Humanoid-v1 --run=ES To train a policy on a cluster (e.g., using 900 workers), run the following. @@ -26,7 +26,7 @@ To train a policy on a cluster (e.g., using 900 workers), run the following. python ray/python/ray/rllib/train.py \ --env=Humanoid-v1 \ - --alg=ES \ + --run=ES \ --redis-address=<redis-address> \ --config='{"num_workers": 900, "episodes_per_batch": 10000, "timesteps_per_batch": 100000}' diff --git a/doc/source/example-policy-gradient.rst b/doc/source/example-policy-gradient.rst index 313b6080f..a0d98821f 100644 --- a/doc/source/example-policy-gradient.rst +++ b/doc/source/example-policy-gradient.rst @@ -16,7 +16,7 @@ Then you can run the example as follows. .. code-block:: bash - python/ray/rllib/train.py --env=Pong-ram-v4 --alg=PPO + python/ray/rllib/train.py --env=Pong-ram-v4 --run=PPO This will train an agent on the ``Pong-ram-v4`` Atari environment. You can also try passing in the ``Pong-v0`` environment or the ``CartPole-v0`` environment. diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 98aec1523..205de4406 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -30,7 +30,7 @@ You can run training with :: - python ray/python/ray/rllib/train.py --env CartPole-v0 --alg PPO --config '{"timesteps_per_batch": 10000}' + python ray/python/ray/rllib/train.py --env CartPole-v0 --run PPO --config '{"timesteps_per_batch": 10000}' By default, the results will be logged to a subdirectory of ``/tmp/ray``. This subdirectory will contain a file ``config.json`` which contains the @@ -51,7 +51,7 @@ The ``train.py`` script has a number of options you can show by running The most important options are for choosing the environment with ``--env`` (any OpenAI gym environment including ones registered by the user -can be used) and for choosing the algorithm with ``--run`` (available options are ``PPO``, ``A3C``, ``ES`` and ``DQN``).
Each algorithm has specific hyperparameters that can be set with ``--config``, see the ``DEFAULT_CONFIG`` variable in diff --git a/python/ray/rllib/README.rst b/python/ray/rllib/README.rst index 4834aac03..074185fbb 100644 --- a/python/ray/rllib/README.rst +++ b/python/ray/rllib/README.rst @@ -8,7 +8,7 @@ You can run training with :: - python train.py --env CartPole-v0 --alg PPO + python train.py --env CartPole-v0 --run PPO The available algorithms are: diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index e69de29bb..c6d4b5ab0 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.tune.registry import register_trainable +from ray.rllib import ppo, es, dqn, a3c +from ray.rllib.agent import _MockAgent, _SigmoidFakeData + + +def _register_all(): + register_trainable("PPO", ppo.PPOAgent) + register_trainable("ES", es.ESAgent) + register_trainable("DQN", dqn.DQNAgent) + register_trainable("A3C", a3c.A3CAgent) + register_trainable("__fake", _MockAgent) + register_trainable("__sigmoid_fake_data", _SigmoidFakeData) + + +_register_all() diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index 19b3529d1..8ca499b2d 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -17,13 +17,15 @@ import uuid import tensorflow as tf from ray.tune.logger import UnifiedLogger +from ray.tune.registry import ENV_CREATOR from ray.tune.result import TrainingResult +from ray.tune.trainable import Trainable logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -class Agent(object): +class Agent(Trainable): """All RLlib agents extend this base class. Agent objects retain internal model state between calls to train(), so @@ -33,39 +35,40 @@ class Agent(object): env_creator (func): Function that creates a new training env. config (obj): Algorithm-specific configuration data. logdir (str): Directory in which training outputs should be placed. + registry (obj): Object registry. """ _allow_unknown_configs = False _default_logdir = "/tmp/ray" def __init__( - self, env_creator, config, logger_creator=None): + self, config={}, env=None, registry=None, logger_creator=None): """Initialize an RLLib agent. Args: - env_creator (str|func): Name of the OpenAI gym environment to train - against, or a function that creates such an env. config (dict): Algorithm-specific configuration data. + env (str): Name of the environment to use. Note that this can also + be specified as the `env` key in config. + registry (obj): Object registry for user-defined envs, models, etc. + If unspecified, it will be assumed empty. logger_creator (func): Function that creates a ray.tune.Logger object. If unspecified, a default logger is created. 
""" self._initialize_ok = False self._experiment_id = uuid.uuid4().hex - if type(env_creator) is str: - import gym - env_name = env_creator - self.env_creator = lambda: gym.make(env_name) + env = env or config.get("env") + if env: + config["env"] = env + if registry and registry.contains(ENV_CREATOR, env): + self.env_creator = registry.get(ENV_CREATOR, env) else: - if hasattr(env_creator, "env_name"): - env_name = env_creator.env_name - else: - env_name = "custom" - self.env_creator = env_creator - + import gym + self.env_creator = lambda: gym.make(env) self.config = self._default_config.copy() + self.registry = registry if not self._allow_unknown_configs: for k in config.keys(): - if k not in self.config: + if k not in self.config and k != "env": raise Exception( "Unknown agent config `{}`, " "all agent configs: {}".format(k, self.config.keys())) @@ -76,8 +79,7 @@ class Agent(object): self.logdir = self._result_logger.logdir else: logdir_suffix = "{}_{}_{}".format( - env_name, - self._agent_name, + env, self._agent_name, datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) if not os.path.exists(self._default_logdir): os.makedirs(self._default_logdir) @@ -214,7 +216,14 @@ class Agent(object): def stop(self): """Releases all resources used by this agent.""" - self._result_logger.close() + if self._initialize_ok: + self._result_logger.close() + self._stop() + + def _stop(self): + """Subclasses should override this for custom stopping.""" + + pass def compute_action(self, observation): """Computes an action using the current trained policy.""" @@ -336,5 +345,4 @@ def get_agent_class(alg): return _SigmoidFakeData else: raise Exception( - ("Unknown algorithm {}, check --alg argument. Valid choices " + - "are PPO, ES, DQN, and A3C.").format(alg)) + ("Unknown algorithm {}.").format(alg)) diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 9f7e43b14..14fc9edaa 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -355,7 +355,7 @@ class DQNAgent(Agent): _agent_name = "DQN" _default_config = DEFAULT_CONFIG - def stop(self): + def _stop(self): for w in self.workers: w.stop.remote() diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index a9a6b9197..1e782485e 100755 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -29,8 +29,8 @@ CONFIGS = { def test(use_object_store, alg_name): cls = get_agent_class(alg_name) - alg1 = cls("CartPole-v0", CONFIGS[name]) - alg2 = cls("CartPole-v0", CONFIGS[name]) + alg1 = cls(config=CONFIGS[name], env="CartPole-v0") + alg2 = cls(config=CONFIGS[name], env="CartPole-v0") for _ in range(3): res = alg1.train() diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index 17cf2f43d..bb4625a93 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -9,12 +9,12 @@ import sys import yaml from ray.tune.config_parser import make_parser, resources_to_json -from ray.tune.tune import make_scheduler, run_experiments +from ray.tune.tune import _make_scheduler, run_experiments EXAMPLE_USAGE = """ Training example: - ./train.py --alg DQN --env CartPole-v0 + ./train.py --run DQN --env CartPole-v0 Grid search example: ./train.py -f tuned_examples/cartpole-grid-search-example.yaml @@ -29,16 +29,24 @@ parser = make_parser( epilog=EXAMPLE_USAGE) # See also the base parser definition in ray/tune/config_parser.py -parser.add_argument("--redis-address", default=None, type=str, - help="The Redis 
address of the cluster.") -parser.add_argument("--num-cpus", default=None, type=int, - help="Number of CPUs to allocate to Ray.") -parser.add_argument("--num-gpus", default=None, type=int, - help="Number of GPUs to allocate to Ray.") -parser.add_argument("--experiment-name", default="default", type=str, - help="Name of experiment dir.") -parser.add_argument("-f", "--config-file", default=None, type=str, - help="If specified, use config options from this file.") +parser.add_argument( + "--redis-address", default=None, type=str, + help="The Redis address of the cluster.") +parser.add_argument( + "--num-cpus", default=None, type=int, + help="Number of CPUs to allocate to Ray.") +parser.add_argument( + "--num-gpus", default=None, type=int, + help="Number of GPUs to allocate to Ray.") +parser.add_argument( + "--experiment-name", default="default", type=str, + help="Name of the subdirectory under `local_dir` to put results in.") +parser.add_argument( + "--env", default=None, type=str, help="The gym environment to use.") +parser.add_argument( + "-f", "--config-file", default=None, type=str, + help="If specified, use config options from this file. Note that this " + "overrides any trial-specific options set via flags above.") if __name__ == "__main__": @@ -50,13 +58,12 @@ if __name__ == "__main__": # Note: keep this in sync with tune/config_parser.py experiments = { args.experiment_name: { # i.e. log to /tmp/ray/default - "alg": args.alg, + "run": args.run, "checkpoint_freq": args.checkpoint_freq, "local_dir": args.local_dir, - "env": args.env, "resources": resources_to_json(args.resources), "stop": args.stop, - "config": args.config, + "config": dict(args.config, env=args.env), "restore": args.restore, "repeat": args.repeat, "upload_dir": args.upload_dir, @@ -64,12 +71,12 @@ if __name__ == "__main__": } for exp in experiments.values(): - if not exp.get("alg"): - parser.error("the following arguments are required: --alg") - if not exp.get("env"): + if not exp.get("run"): + parser.error("the following arguments are required: --run") + if not exp.get("env") and not exp.get("config", {}).get("env"): parser.error("the following arguments are required: --env") run_experiments( - experiments, scheduler=make_scheduler(args), + experiments, scheduler=_make_scheduler(args), redis_address=args.redis_address, num_cpus=args.num_cpus, num_gpus=args.num_gpus) diff --git a/python/ray/rllib/tuned_examples/cartpole-grid-search-example.yaml b/python/ray/rllib/tuned_examples/cartpole-grid-search-example.yaml index 0da30d41a..c5033c712 100644 --- a/python/ray/rllib/tuned_examples/cartpole-grid-search-example.yaml +++ b/python/ray/rllib/tuned_examples/cartpole-grid-search-example.yaml @@ -1,6 +1,6 @@ cartpole-ppo: env: CartPole-v0 - alg: PPO + run: PPO stop: episode_reward_mean: 200 time_total_s: 180 diff --git a/python/ray/rllib/tuned_examples/hopper-ppo.yaml b/python/ray/rllib/tuned_examples/hopper-ppo.yaml index 8f33d283e..b256a119d 100644 --- a/python/ray/rllib/tuned_examples/hopper-ppo.yaml +++ b/python/ray/rllib/tuned_examples/hopper-ppo.yaml @@ -1,6 +1,6 @@ hopper-ppo: env: Hopper-v1 - alg: PPO + run: PPO resources: cpu: 64 gpu: 4 diff --git a/python/ray/rllib/tuned_examples/humanoid-es.yaml b/python/ray/rllib/tuned_examples/humanoid-es.yaml index 0b824b36d..a3cda3ca7 100644 --- a/python/ray/rllib/tuned_examples/humanoid-es.yaml +++ b/python/ray/rllib/tuned_examples/humanoid-es.yaml @@ -1,6 +1,6 @@ humanoid-es: env: Humanoid-v1 - alg: ES + run: ES resources: cpu: 100 driver_cpu_limit: 4 diff --git 
a/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml b/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml index c9db3094a..b7ce6c1cc 100644 --- a/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml +++ b/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml @@ -1,6 +1,6 @@ humanoid-ppo-gae: env: Humanoid-v1 - alg: PPO + run: PPO stop: episode_reward_mean: 6000 resources: diff --git a/python/ray/rllib/tuned_examples/humanoid-ppo.yaml b/python/ray/rllib/tuned_examples/humanoid-ppo.yaml index 844ae1e19..c58f96bca 100644 --- a/python/ray/rllib/tuned_examples/humanoid-ppo.yaml +++ b/python/ray/rllib/tuned_examples/humanoid-ppo.yaml @@ -1,6 +1,6 @@ humanoid-ppo: env: Humanoid-v1 - alg: PPO + run: PPO stop: episode_reward_mean: 6000 resources: diff --git a/python/ray/rllib/tuned_examples/hyperband-cartpole.yaml b/python/ray/rllib/tuned_examples/hyperband-cartpole.yaml index fa2d168c4..a6cb718db 100644 --- a/python/ray/rllib/tuned_examples/hyperband-cartpole.yaml +++ b/python/ray/rllib/tuned_examples/hyperband-cartpole.yaml @@ -1,6 +1,6 @@ cartpole-ppo: env: CartPole-v0 - alg: PPO + run: PPO repeat: 3 stop: episode_reward_mean: 200 diff --git a/python/ray/rllib/tuned_examples/pong-a3c.yaml b/python/ray/rllib/tuned_examples/pong-a3c.yaml index ecf488420..03dafe6de 100644 --- a/python/ray/rllib/tuned_examples/pong-a3c.yaml +++ b/python/ray/rllib/tuned_examples/pong-a3c.yaml @@ -1,6 +1,6 @@ pong-a3c: env: PongDeterministic-v4 - alg: A3C + run: A3C resources: cpu: 16 driver_cpu_limit: 1 diff --git a/python/ray/rllib/tuned_examples/pong-dqn.yaml b/python/ray/rllib/tuned_examples/pong-dqn.yaml index a8f965837..849fe0430 100644 --- a/python/ray/rllib/tuned_examples/pong-dqn.yaml +++ b/python/ray/rllib/tuned_examples/pong-dqn.yaml @@ -1,6 +1,6 @@ pong-deterministic-dqn: env: PongDeterministic-v4 - alg: DQN + run: DQN resources: cpu: 1 gpu: 1 @@ -28,7 +28,7 @@ pong-deterministic-dqn: ] pong-noframeskip-dqn: env: PongNoFrameskip-v4 - alg: DQN + run: DQN resources: cpu: 1 gpu: 1 diff --git a/python/ray/rllib/tuned_examples/walker2d-ppo.yaml b/python/ray/rllib/tuned_examples/walker2d-ppo.yaml index e27de0d7e..4f712a79a 100644 --- a/python/ray/rllib/tuned_examples/walker2d-ppo.yaml +++ b/python/ray/rllib/tuned_examples/walker2d-ppo.yaml @@ -1,6 +1,6 @@ walker2d-v1-ppo: env: Walker2d-v1 - alg: PPO + run: PPO resources: cpu: 64 gpu: 4 diff --git a/python/ray/tune/README.rst b/python/ray/tune/README.rst index e62092a65..58477bddb 100644 --- a/python/ray/tune/README.rst +++ b/python/ray/tune/README.rst @@ -86,7 +86,7 @@ expression. cartpole-ppo: env: CartPole-v0 - alg: PPO + run: PPO repeat: 2 stop: episode_reward_mean: 200 @@ -119,7 +119,7 @@ When using the Python API, the above is equivalent to the following program: spec = { "env": "CartPole-v0", - "alg": "PPO", + "run": "PPO", "repeat": 2, "stop": { "episode_reward_mean": 200, @@ -166,9 +166,9 @@ Using ray.tune with Ray RLlib Another way to use ray.tune is through RLlib's ``python/ray/rllib/train.py`` script. This script allows you to select between different RL algorithms with -the ``--alg`` option. For example, to train pong with the A3C algorithm, run: +the ``--run`` option. 
For example, to train pong with the A3C algorithm, run: -- ``./train.py --env=PongDeterministic-v4 --alg=A3C --stop '{"time_total_s": 3200}' --resources '{"cpu": 8}' --config '{"num_workers": 8}'`` +- ``./train.py --env=PongDeterministic-v4 --run=A3C --stop '{"time_total_s": 3200}' --resources '{"cpu": 8}' --config '{"num_workers": 8}'`` or diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index e69de29bb..295cc3c76 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -0,0 +1,24 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.tune.error import TuneError +from ray.tune.tune import run_experiments +from ray.tune.registry import register_env, register_trainable +from ray.tune.result import TrainingResult +from ray.tune.script_runner import ScriptRunner +from ray.tune.trainable import Trainable +from ray.tune.variant_generator import grid_search + + +register_trainable("script", ScriptRunner) + +__all__ = [ + "Trainable", + "TrainingResult", + "TuneError", + "grid_search", + "register_env", + "register_trainable", + "run_experiments", +] diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index a70309b02..484b47362 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -6,12 +6,18 @@ from __future__ import print_function import argparse import json +from ray.tune import TuneError from ray.tune.trial import Resources def json_to_resources(data): if type(data) is str: data = json.loads(data) + for k in data: + if k not in Resources._fields: + raise TuneError( + "Unknown resource type {}, must be one of {}".format( + k, Resources._fields)) return Resources( data.get("cpu", 0), data.get("gpu", 0), data.get("driver_cpu_limit"), data.get("driver_gpu_limit")) @@ -32,34 +38,49 @@ def make_parser(**kwargs): parser = argparse.ArgumentParser(**kwargs) # Note: keep this in sync with rllib/train.py - parser.add_argument("--alg", default=None, type=str, - help="The learning algorithm to train.") - parser.add_argument("--stop", default="{}", type=json.loads, - help="The stopping criteria, specified in JSON.") - parser.add_argument("--config", default="{}", type=json.loads, - help="The config of the algorithm, specified in JSON.") - parser.add_argument("--resources", default='{"cpu": 1}', - type=json_to_resources, - help="Amount of resources to allocate per trial.") - parser.add_argument("--repeat", default=1, type=int, - help="Number of times to repeat each trial.") - parser.add_argument("--local-dir", default="/tmp/ray", type=str, - help="Local dir to save training results to.") - parser.add_argument("--upload-dir", default="", type=str, - help="URI to upload training results to.") - parser.add_argument("--checkpoint-freq", default=0, type=int, - help="How many iterations between checkpoints.") - parser.add_argument("--scheduler", default="FIFO", type=str, - help="FIFO, MedianStopping, or HyperBand") - parser.add_argument("--scheduler-config", default="{}", type=json.loads, - help="Config options to pass to the scheduler.") + parser.add_argument( + "--run", default=None, type=str, + help="The algorithm or model to train. This may refer to the name " + "of a built-in algorithm (e.g. RLlib's DQN or PPO), or a " + "user-defined trainable function or class registered in the " + "tune registry.") + parser.add_argument( + "--stop", default="{}", type=json.loads, + help="The stopping criteria, specified in JSON.
The keys may be any " + "field in TrainingResult, e.g. " + "'{\"time_total_s\": 600, \"timesteps_total\": 100000}' to stop " + "after 600 seconds or 100k timesteps, whichever is reached first.") + parser.add_argument( + "--config", default="{}", type=json.loads, + help="Algorithm-specific configuration (e.g. env, hyperparams), " + "specified in JSON.") + parser.add_argument( + "--resources", default='{"cpu": 1}', type=json_to_resources, + help="Machine resources to allocate per trial, e.g. " + "'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned " + "unless you specify them here.") + parser.add_argument( + "--repeat", default=1, type=int, + help="Number of times to repeat each trial.") + parser.add_argument( + "--local-dir", default="/tmp/ray", type=str, + help="Local dir to save training results to. Defaults to '/tmp/ray'.") + parser.add_argument( + "--upload-dir", default="", type=str, + help="Optional URI to upload training results to.") + parser.add_argument( + "--checkpoint-freq", default=0, type=int, + help="How many training iterations between checkpoints. " + "A value of 0 (default) disables checkpointing.") + parser.add_argument( + "--scheduler", default="FIFO", type=str, + help="FIFO (default), MedianStopping, or HyperBand.") + parser.add_argument( + "--scheduler-config", default="{}", type=json.loads, + help="Config options to pass to the scheduler.") # Note: this currently only makes sense when running a single trial parser.add_argument("--restore", default=None, type=str, help="If specified, restore from this checkpoint.") - # TODO(ekl) environments are RL specific - parser.add_argument("--env", default=None, type=str, - help="The gym environment to use.") - return parser diff --git a/python/ray/tune/error.py b/python/ray/tune/error.py new file mode 100644 index 000000000..badf60a08 --- /dev/null +++ b/python/ray/tune/error.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class TuneError(Exception): + """General error class raised by ray.tune.""" + pass diff --git a/python/ray/tune/examples/tune_mnist_ray.py b/python/ray/tune/examples/tune_mnist_ray.py index 0aeb5706f..1efcceeab 100755 --- a/python/ray/tune/examples/tune_mnist_ray.py +++ b/python/ray/tune/examples/tune_mnist_ray.py @@ -31,12 +31,9 @@ from __future__ import print_function import argparse import sys import tempfile -import os +import time -import ray -from ray.tune.result import TrainingResult -from ray.tune.trial_runner import TrialRunner -from ray.tune.variant_generator import grid_search, generate_trials +from ray.tune import grid_search, run_experiments, register_trainable from tensorflow.examples.tutorials.mnist import input_data @@ -135,7 +132,12 @@ def bias_variable(shape): def main(_): # Import data - mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) + for _ in range(10): + try: + mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) + break + except Exception: + time.sleep(5) # Create the model x = tf.placeholder(tf.float32, [None, 784]) @@ -174,9 +176,8 @@ def main(_): # !!! Report status to ray.tune !!! if status_reporter: - status_reporter.report(TrainingResult( - timesteps_total=i, - mean_accuracy=train_accuracy)) + status_reporter( + timesteps_total=i, mean_accuracy=train_accuracy) print('step %d, training accuracy %g' % (i, train_accuracy)) train_step.run( @@ -201,34 +202,24 @@ def train(config={'activation': 'relu'}, reporter=None): # !!! 
Example of using the ray.tune Python API !!! if __name__ == '__main__': - runner = TrialRunner() + parser = argparse.ArgumentParser() + parser.add_argument( + '--fast', action='store_true', help='Finish quickly for testing') + args, _ = parser.parse_known_args() - spec = { + register_trainable('train_mnist', train) + mnist_spec = { + 'run': 'train_mnist', 'stop': { 'mean_accuracy': 0.99, 'time_total_s': 600, }, 'config': { - 'script_file_path': os.path.abspath(__file__), - 'script_min_iter_time_s': 1, 'activation': grid_search(['relu', 'elu', 'tanh']), }, } - # These arguments are only for testing purposes. - parser = argparse.ArgumentParser() - parser.add_argument('--fast', action='store_true', - help='Run minimal iterations.') - args, _ = parser.parse_known_args() - if args.fast: - spec['stop']['training_iteration'] = 2 + mnist_spec['stop']['training_iteration'] = 2 - for trial in generate_trials(spec): - runner.add_trial(trial) - - ray.init() - - while not runner.is_finished(): - runner.step() - print(runner.debug_string()) + run_experiments({'tune_mnist_test': mnist_spec}) diff --git a/python/ray/tune/examples/tune_mnist_ray.yaml b/python/ray/tune/examples/tune_mnist_ray.yaml index d80f30d53..9cf12e533 100644 --- a/python/ray/tune/examples/tune_mnist_ray.yaml +++ b/python/ray/tune/examples/tune_mnist_ray.yaml @@ -1,4 +1,5 @@ tune_mnist: + run: script repeat: 2 resources: cpu: 1 diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py new file mode 100644 index 000000000..083a0c5d3 --- /dev/null +++ b/python/ray/tune/registry.py @@ -0,0 +1,87 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from types import FunctionType + +import ray +from ray.tune import TuneError +from ray.local_scheduler import ObjectID +from ray.tune.trainable import Trainable, wrap_function + +TRAINABLE_CLASS = "trainable_class" +ENV_CREATOR = "env_creator" +KNOWN_CATEGORIES = [TRAINABLE_CLASS, ENV_CREATOR] + + +def register_trainable(name, trainable): + """Register a trainable function or class. + + Args: + name (str): Name to register. + trainable (obj): Function or tune.Trainable class. Functions must + take (config, status_reporter) as arguments and will be + automatically converted into a class during registration. + """ + + if isinstance(trainable, FunctionType): + trainable = wrap_function(trainable) + if not issubclass(trainable, Trainable): + raise TypeError( + "Second argument must be convertible to Trainable", trainable) + _default_registry.register(TRAINABLE_CLASS, name, trainable) + + +def register_env(name, env_creator): + """Register a custom environment for use with RLlib. + + Args: + name (str): Name to register. + env_creator (obj): Function that creates an env. + """ + + if not isinstance(env_creator, FunctionType): + raise TypeError( + "Second argument must be a function.", env_creator) + _default_registry.register(ENV_CREATOR, name, env_creator) + + +def get_registry(): + """Use this to access the registry.
This requires ray to be initialized.""" + + _default_registry.flush_values_to_object_store() + + # returns a registry copy that doesn't include the hard refs + return _Registry(_default_registry._all_objects) + + +class _Registry(object): + def __init__(self, objs={}): + self._all_objects = objs + self._refs = [] # hard refs that prevent eviction of objects + + def register(self, category, key, value): + if category not in KNOWN_CATEGORIES: + raise TuneError("Unknown category {} not among {}".format( + category, KNOWN_CATEGORIES)) + self._all_objects[(category, key)] = value + + def contains(self, category, key): + return (category, key) in self._all_objects + + def get(self, category, key): + value = self._all_objects[(category, key)] + if type(value) == ObjectID: + return ray.get(value) + else: + return value + + def flush_values_to_object_store(self): + for k, v in self._all_objects.items(): + if type(v) != ObjectID: + obj = ray.put(v) + self._all_objects[k] = obj + self._refs.append(ray.get(obj)) + + +_default_registry = _Registry() diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 02f3b4655..08ef62552 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -18,6 +18,9 @@ TrainingResult = namedtuple("TrainingResult", [ # (Required) Accumulated timesteps for this entire experiment. "timesteps_total", + # (Optional) If training is finished. + "done", + # (Optional) Custom metadata to report for this iteration. "info", diff --git a/python/ray/tune/script_runner.py b/python/ray/tune/script_runner.py index 5de6b8675..c3a86ccf3 100644 --- a/python/ray/tune/script_runner.py +++ b/python/ray/tune/script_runner.py @@ -10,6 +10,8 @@ import threading import traceback from ray.rllib.agent import Agent +from ray.tune import TuneError +from ray.tune.result import TrainingResult class StatusReporter(object): @@ -17,33 +19,30 @@ class StatusReporter(object): def __init__(self): self._latest_result = None + self._last_result = None self._lock = threading.Lock() self._error = None + self._done = False - def report(self, result): + def __call__(self, **kwargs): """Report updated training status. Args: - result (TrainingResult): Latest training result status. You must + kwargs (TrainingResult): Latest training result status. You must at least define `timesteps_total`, but probably want to report some of the other metrics as well. """ with self._lock: - self._latest_result = result - - def set_error(self, error): - """Report an error. - - Args: - error (obj): Error object or string. 
- """ - - self._error = error + self._latest_result = self._last_result = TrainingResult(**kwargs) def _get_and_clear_status(self): if self._error: - raise Exception("Error running script: " + str(self._error)) + raise TuneError("Error running trial: " + str(self._error)) + if self._done and not self._latest_result: + if not self._last_result: + raise TuneError("Trial finished without reporting result!") + return self._last_result._replace(done=True) with self._lock: res = self._latest_result self._latest_result = None @@ -61,7 +60,7 @@ DEFAULT_CONFIG = { "script_entrypoint": "train", # batch results to at least this granularity - "script_min_iter_time_s": 5, + "script_min_iter_time_s": 1, } @@ -79,9 +78,35 @@ class _RunnerThread(threading.Thread): try: self._entrypoint(*self._entrypoint_args) except Exception as e: - self._status_reporter.set_error(e) + self._status_reporter._error = e print("Runner thread raised: {}".format(traceback.format_exc())) raise e + finally: + self._status_reporter._done = True + + +def import_function(file_path, function_name): + # strong assumption here that we're in a new process + file_path = os.path.expanduser(file_path) + sys.path.insert(0, os.path.dirname(file_path)) + if hasattr(importlib, "util"): + # Python 3.4+ + spec = importlib.util.spec_from_file_location( + "external_file", file_path) + external_file = importlib.util.module_from_spec(spec) + spec.loader.exec_module(external_file) + elif hasattr(importlib, "machinery"): + # Python 3.3 + from importlib.machinery import SourceFileLoader + external_file = SourceFileLoader( + "external_file", file_path).load_module() + else: + # Python 2.x + import imp + external_file = imp.load_source("external_file", file_path) + if not external_file: + raise TuneError("Unable to import file at {}".format(file_path)) + return getattr(external_file, function_name) class ScriptRunner(Agent): @@ -92,51 +117,41 @@ class ScriptRunner(Agent): _allow_unknown_configs = True def _init(self): - # strong assumption here that we're in a new process - file_path = os.path.expanduser(self.config["script_file_path"]) - sys.path.insert(0, os.path.dirname(file_path)) - if hasattr(importlib, "util"): - # Python 3.4+ - spec = importlib.util.spec_from_file_location( - "external_file", file_path) - external_file = importlib.util.module_from_spec(spec) - spec.loader.exec_module(external_file) - elif hasattr(importlib, "machinery"): - # Python 3.3 - from importlib.machinery import SourceFileLoader - external_file = SourceFileLoader( - "external_file", file_path).load_module() - else: - # Python 2.x - import imp - external_file = imp.load_source("external_file", file_path) - if not external_file: - raise Exception( - "Unable to import file at {}".format( - self.config["script_file_path"])) - entrypoint = getattr(external_file, self.config["script_entrypoint"]) + entrypoint = self._trainable_func() + if not entrypoint: + entrypoint = import_function( + self.config["script_file_path"], + self.config["script_entrypoint"]) self._status_reporter = StatusReporter() + scrubbed_config = self.config.copy() + for k in self._default_config: + del scrubbed_config[k] self._runner = _RunnerThread( - entrypoint, self.config, self._status_reporter) + entrypoint, scrubbed_config, self._status_reporter) self._start_time = time.time() self._last_reported_time = self._start_time self._last_reported_timestep = 0 self._runner.start() + # Subclasses can override this to set the trainable func + # TODO(ekl) this isn't a very clean layering, we should refactor it + 
def _trainable_func(self): + return None + def train(self): if not self._initialize_ok: raise ValueError( "Agent initialization failed, see previous errors") - poll_start = time.time() + now = time.time() + time.sleep(self.config["script_min_iter_time_s"]) + result = self._status_reporter._get_and_clear_status() - while result is None or \ - time.time() - poll_start < \ - self.config["script_min_iter_time_s"]: + while result is None: time.sleep(1) result = self._status_reporter._get_and_clear_status() - - now = time.time() + if result.timesteps_total is None: + raise TuneError("Must specify timesteps_total in result", result) # Include the negative loss to use as a stopping condition if result.mean_loss is not None: @@ -164,6 +179,5 @@ return result - def stop(self): + def _stop(self): self._status_reporter._stop() - Agent.stop(self) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py new file mode 100644 index 000000000..66c659449 --- /dev/null +++ b/python/ray/tune/trainable.py @@ -0,0 +1,57 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class Trainable(object): + """Interface for trainable models, functions, etc. + + Implementing this interface is required to use ray.tune's full + functionality, though you can also get away with supplying just a + `my_train(config, reporter)` function and calling: + + register_trainable("my_func", my_train) + + to register it for use with tune. The function will be automatically + converted to this interface (sans checkpoint functionality).""" + + def train(self): + """Runs one logical iteration of training. + + Returns: + A TrainingResult that describes training progress. + """ + + raise NotImplementedError + + def save(self): + """Saves the current model state to a checkpoint. + + Returns: + Checkpoint path that may be passed to restore(). + """ + + raise NotImplementedError + + def restore(self, checkpoint_path): + """Restores training state from a given model checkpoint. + + These checkpoints are returned from calls to save(). + """ + + raise NotImplementedError + + def stop(self): + """Releases all resources used by this class.""" + + pass + + +def wrap_function(train_func): + from ray.tune.script_runner import ScriptRunner + + class WrappedFunc(ScriptRunner): + def _trainable_func(self): + return train_func + + return WrappedFunc diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index c40981636..02661be1f 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -8,8 +8,10 @@ import ray import os from collections import namedtuple -from ray.rllib.agent import get_agent_class +from ray.tune import TuneError from ray.tune.logger import NoopLogger, UnifiedLogger +from ray.tune.result import TrainingResult +from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS class Resources( @@ -60,7 +62,7 @@ class Trial(object): ERROR = "ERROR" def __init__( - self, env_creator, alg, config={}, local_dir='/tmp/ray', + self, trainable_name, config={}, local_dir='/tmp/ray', experiment_tag=None, resources=Resources(cpu=1, gpu=0), stopping_criterion={}, checkpoint_freq=0, restore_path=None, upload_dir=None): @@ -70,16 +72,18 @@ class Trial(object): in ray.tune.config_parser.
""" + if not _default_registry.contains( + TRAINABLE_CLASS, trainable_name): + raise TuneError("Unknown trainable: " + trainable_name) + + for k in stopping_criterion: + if k not in TrainingResult._fields: + raise TuneError( + "Stopping condition key `{}` must be one of {}".format( + k, TrainingResult._fields)) + # Immutable config - self.env_creator = env_creator - if type(env_creator) is str: - self.env_name = env_creator - else: - if hasattr(env_creator, "env_name"): - self.env_name = env_creator.env_name - else: - self.env_name = "custom" - self.alg = alg + self.trainable_name = trainable_name self.config = config self.local_dir = local_dir self.experiment_tag = experiment_tag @@ -92,7 +96,7 @@ class Trial(object): self.last_result = None self._checkpoint_path = restore_path self._checkpoint_obj = None - self.agent = None + self.runner = None self.status = Trial.PENDING self.location = None self.logdir = None @@ -105,7 +109,7 @@ class Trial(object): be thrown. """ - self._setup_agent() + self._setup_runner() if self._checkpoint_path: self.restore_from_path(self._checkpoint_path) elif self._checkpoint_obj: @@ -128,11 +132,11 @@ class Trial(object): self.status = Trial.TERMINATED try: - if self.agent: + if self.runner: stop_tasks = [] - stop_tasks.append(self.agent.stop.remote()) - stop_tasks.append(self.agent.__ray_terminate__.remote( - self.agent._ray_actor_id.id())) + stop_tasks.append(self.runner.stop.remote()) + stop_tasks.append(self.runner.__ray_terminate__.remote( + self.runner._ray_actor_id.id())) # TODO(ekl) seems like wait hangs when killing actors _, unfinished = ray.wait( stop_tasks, num_returns=2, timeout=250) @@ -140,10 +144,10 @@ class Trial(object): print(("Stopping %s Actor timed out, " "but moving on...") % self) except Exception: - print("Error stopping agent:", traceback.format_exc()) + print("Error stopping runner:", traceback.format_exc()) self.status = Trial.ERROR finally: - self.agent = None + self.runner = None if stop_logger and self.result_logger: self.result_logger.close() @@ -159,7 +163,7 @@ class Trial(object): self.stop(stop_logger=False) self.status = Trial.PAUSED except Exception: - print("Error pausing agent:", traceback.format_exc()) + print("Error pausing runner:", traceback.format_exc()) self.status = Trial.ERROR def unpause(self): @@ -177,11 +181,14 @@ class Trial(object): """Returns Ray future for one iteration of training.""" assert self.status == Trial.RUNNING, self.status - return self.agent.train.remote() + return self.runner.train.remote() def should_stop(self, result): """Whether the given result meets this trial's stopping criteria.""" + if result.done: + return True + for criteria, stop_value in self.stopping_criterion.items(): if getattr(result, criteria) >= stop_value: return True @@ -240,9 +247,9 @@ class Trial(object): obj = None path = None if to_object_store: - obj = self.agent.save_to_object.remote() + obj = self.runner.save_to_object.remote() else: - path = ray.get(self.agent.save.remote()) + path = ray.get(self.runner.save.remote()) self._checkpoint_path = path self._checkpoint_obj = obj @@ -250,39 +257,40 @@ class Trial(object): return path or obj def restore_from_path(self, path): - """Restores agent state from specified path. + """Restores runner state from specified path. Args: path (str): A path where state will be restored. 
""" - if self.agent is None: - print("Unable to restore - no agent") + if self.runner is None: + print("Unable to restore - no runner") else: try: - ray.get(self.agent.restore.remote(path)) + ray.get(self.runner.restore.remote(path)) except Exception: - print("Error restoring agent:", traceback.format_exc()) + print("Error restoring runner:", traceback.format_exc()) self.status = Trial.ERROR def restore_from_obj(self, obj): - """Restores agent state from the specified object.""" + """Restores runner state from the specified object.""" - if self.agent is None: - print("Unable to restore - no agent") + if self.runner is None: + print("Unable to restore - no runner") else: try: - ray.get(self.agent.restore_from_object.remote(obj)) + ray.get(self.runner.restore_from_object.remote(obj)) except Exception: - print("Error restoring agent:", traceback.format_exc()) + print("Error restoring runner:", traceback.format_exc()) self.status = Trial.ERROR - def _setup_agent(self): + def _setup_runner(self): self.status = Trial.RUNNING - agent_cls = get_agent_class(self.alg) + trainable_cls = get_registry().get( + TRAINABLE_CLASS, self.trainable_name) cls = ray.remote( num_cpus=self.resources.driver_cpu_limit, - num_gpus=self.resources.driver_gpu_limit)(agent_cls) + num_gpus=self.resources.driver_gpu_limit)(trainable_cls) if not self.result_logger: if not os.path.exists(self.local_dir): os.makedirs(self.local_dir) @@ -292,19 +300,18 @@ class Trial(object): self.config, self.logdir, self.upload_dir) remote_logdir = self.logdir # Logging for trials is handled centrally by TrialRunner, so - # configure the remote agent to use a noop-logger. - self.agent = cls.remote( - self.env_creator, self.config, - lambda config: NoopLogger(config, remote_logdir)) + # configure the remote runner to use a noop-logger. + self.runner = cls.remote( + config=self.config, + registry=get_registry(), + logger_creator=lambda config: NoopLogger(config, remote_logdir)) def __str__(self): - identifier = '{}_{}'.format(self.alg, self.env_name) + if "env" in self.config: + identifier = "{}_{}".format( + self.trainable_name, self.config["env"]) + else: + identifier = self.trainable_name if self.experiment_tag: - identifier += '_' + self.experiment_tag + identifier += "_" + self.experiment_tag return identifier - - def __eq__(self, other): - return str(self) == str(other) - - def __hash__(self): - return hash(str(self)) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 8eef585e9..c8f5ad912 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -7,6 +7,7 @@ import ray import time import traceback +from ray.tune import TuneError from ray.tune.trial import Trial, Resources from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler @@ -27,7 +28,7 @@ class TrialRunner(object): While Ray itself provides resource management for tasks and actors, this is not sufficient when scheduling trials that may instantiate multiple actors. - This is because if insufficient resources are available, concurrent agents + This is because if insufficient resources are available, concurrent trials could deadlock waiting for new resources to become available. Furthermore, oversubscribing the cluster could degrade training performance, leading to misleading benchmark results. 
@@ -77,13 +78,15 @@ class TrialRunner(object): else: for trial in self._trials: if trial.status == Trial.PENDING: - assert self.has_resources(trial.resources), \ - ("Insufficient cluster resources to launch trial", - (trial.resources, self._avail_resources)) + if not self.has_resources(trial.resources): + raise TuneError( + "Insufficient cluster resources to launch trial", + (trial.resources, self._avail_resources)) elif trial.status == Trial.PAUSED: - assert False, "There are paused trials, but no more "\ - "pending trials with sufficient resources." - assert False, "Called step when all trials finished?" + raise TuneError( + "There are paused trials, but no more pending " + "trials with sufficient resources.") + raise TuneError("Called step when all trials finished?") def get_trials(self): """Returns the list of trials managed by this TrialRunner. @@ -141,14 +144,14 @@ class TrialRunner(object): trial.start() self._running[trial.train_remote()] = trial except Exception: - print("Error starting agent, retrying:", traceback.format_exc()) + print("Error starting runner, retrying:", traceback.format_exc()) time.sleep(2) trial.stop(error=True) try: trial.start() self._running[trial.train_remote()] = trial except Exception: - print("Error starting agent, abort:", traceback.format_exc()) + print("Error starting runner, abort:", traceback.format_exc()) trial.stop(error=True) # note that we don't return the resources, since they may # have been lost diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 0a2fcaa24..acff5a264 100755 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -7,9 +7,10 @@ from __future__ import print_function import argparse import json import sys -import yaml import ray + +from ray.tune import TuneError from ray.tune.hyperband import HyperBandScheduler from ray.tune.median_stopping_rule import MedianStoppingRule from ray.tune.trial import Trial @@ -44,24 +45,25 @@ parser.add_argument("-f", "--config-file", required=True, type=str, help="Read experiment options from this JSON/YAML file.") -SCHEDULERS = { +_SCHEDULERS = { "FIFO": FIFOScheduler, "MedianStopping": MedianStoppingRule, "HyperBand": HyperBandScheduler, } -def make_scheduler(args): - if args.scheduler in SCHEDULERS: - return SCHEDULERS[args.scheduler](**args.scheduler_config) +def _make_scheduler(args): + if args.scheduler in _SCHEDULERS: + return _SCHEDULERS[args.scheduler](**args.scheduler_config) else: - assert False, "Unknown scheduler: {}, should be one of {}".format( - args.scheduler, SCHEDULERS.keys()) + raise TuneError( + "Unknown scheduler: {}, should be one of {}".format( + args.scheduler, _SCHEDULERS.keys())) def run_experiments(experiments, scheduler=None, **ray_args): if scheduler is None: - scheduler = make_scheduler(args) + scheduler = FIFOScheduler() runner = TrialRunner(scheduler) for name, spec in experiments.items(): @@ -77,16 +79,16 @@ def run_experiments(experiments, scheduler=None, **ray_args): for trial in runner.get_trials(): if trial.status != Trial.TERMINATED: - print("Exit 1") - sys.exit(1) + raise TuneError("Trial did not complete", trial) - print("Exit 0") + return runner.get_trials() if __name__ == "__main__": + import yaml args = parser.parse_args(sys.argv[1:]) with open(args.config_file) as f: experiments = yaml.load(f) run_experiments( - experiments, make_scheduler(args), redis_address=args.redis_address, + experiments, _make_scheduler(args), redis_address=args.redis_address, num_cpus=args.num_cpus, num_gpus=args.num_gpus) diff --git 
a/python/ray/tune/variant_generator.py b/python/ray/tune/variant_generator.py index 434b50ed3..e6ea15fe6 100644 --- a/python/ray/tune/variant_generator.py +++ b/python/ray/tune/variant_generator.py @@ -5,6 +5,7 @@ import os import random import types +from ray.tune import TuneError from ray.tune.trial import Trial from ray.tune.config_parser import make_parser, json_to_resources @@ -20,6 +21,9 @@ def generate_trials(unresolved_spec, output_path=''): output_path (str): Path where to store experiment outputs. """ + if "run" not in unresolved_spec: + raise TuneError("Must specify `run` in {}".format(unresolved_spec)) + def to_argv(config): argv = [] for k, v in config.items(): @@ -34,15 +38,23 @@ def generate_trials(unresolved_spec, output_path=''): i = 0 for _ in range(unresolved_spec.get("repeat", 1)): for resolved_vars, spec in generate_variants(unresolved_spec): - args = parser.parse_args(to_argv(spec)) + try: + # Special case the `env` param for RLlib by automatically + # moving it into the `config` section. + if "env" in spec: + spec["config"] = spec.get("config", {}) + spec["config"]["env"] = spec["env"] + del spec["env"] + args = parser.parse_args(to_argv(spec)) + except SystemExit: + raise TuneError("Error parsing args, see above message", spec) if resolved_vars: experiment_tag = "{}_{}".format(i, resolved_vars) else: experiment_tag = str(i) i += 1 yield Trial( - env_creator=spec.get("env", lambda: None), - alg=spec.get("alg", "script"), + trainable_name=spec["run"], config=spec.get("config", {}), local_dir=os.path.join(args.local_dir, output_path), experiment_tag=experiment_tag, @@ -105,8 +117,8 @@ _MAX_RESOLUTION_PASSES = 20 def _format_vars(resolved_vars): out = [] for path, value in sorted(resolved_vars.items()): - if path[0] in ["alg", "env", "resources"]: - continue # these settings aren't usually search parameters + if path[0] in ["run", "env", "resources"]: + continue # TrialRunner already has these in the experiment_tag pieces = [] last_string = True for k in path[::-1]: @@ -229,9 +241,10 @@ def _try_resolve(v): elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v: # Grid search values grid_values = v["grid_search"] - assert isinstance(grid_values, list), \ - "Grid search expected list of values, got: {}".format( - grid_values) + if not isinstance(grid_values, list): + raise TuneError( + "Grid search expected list of values, got: {}".format( + grid_values)) return False, grid_values return True, v diff --git a/python/ray/tune/visual_utils.py b/python/ray/tune/visual_utils.py index 833079942..561097519 100644 --- a/python/ray/tune/visual_utils.py +++ b/python/ray/tune/visual_utils.py @@ -35,7 +35,7 @@ def _parse_results(res_path): pass res_dict = _flatten_dict(json.loads(line.strip())) except Exception as e: - print("Importing %s failed...Perhaps empty?" % res_path) + print("Importing %s failed...Perhaps empty?" 
% res_path, e) return res_dict @@ -60,7 +60,7 @@ def _resolve(directory, result_fname): def load_results_to_df(directory, result_name="result.json"): exp_directories = [dirpath for dirpath, dirs, files in os.walk(directory) for f in files if f == result_name] - data = [_resolve(directory, result_name) for directory in exp_directories] + data = [_resolve(d, result_name) for d in exp_directories] return pd.DataFrame(data) diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index b382ffda3..e3454a90b 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -59,83 +59,83 @@ python $ROOT_DIR/multi_node_docker_test.py \ docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v0 \ - --alg A3C \ + --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 16}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ - --alg PPO \ + --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"free_log_std": true}}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ - --alg PPO \ + --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "use_gae": false}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pendulum-v0 \ - --alg ES \ + --run ES \ --stop '{"training_iteration": 2}' \ --config '{"stepsize": 0.01, "episodes_per_batch": 20, "timesteps_per_batch": 100}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pong-v0 \ - --alg ES \ + --run ES \ --stop '{"training_iteration": 2}' \ --config '{"stepsize": 0.01, "episodes_per_batch": 20, "timesteps_per_batch": 100}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ - --alg A3C \ + --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"use_lstm": false}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ - --alg DQN \ + --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"lr": 1e-3, "schedule_max_timesteps": 100000, "exploration_fraction": 0.1, "exploration_final_eps": 0.02, "dueling": false, "hiddens": [], "model": {"fcnet_hiddens": [64], "fcnet_activation": "relu"}}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env FrozenLake-v0 \ - --alg DQN \ + --run DQN \ --stop '{"training_iteration": 2}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env FrozenLake-v0 \ - --alg PPO \ + --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"num_sgd_iter": 10, "sgd_batchsize": 64, "timesteps_per_batch": 1000, "num_workers": 1}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v4 \ - --alg DQN \ + --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"lr": 1e-4, "schedule_max_timesteps": 2000000, "buffer_size": 10000, "exploration_fraction": 0.1, "exploration_final_eps": 0.01, "sample_batch_size": 4, 
"learning_starts": 10000, "target_network_update_freq": 1000, "gamma": 0.99, "prioritized_replay": true}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env MontezumaRevenge-v0 \ - --alg PPO \ + --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}' docker run --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v4 \ - --alg A3C \ + --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}' diff --git a/test/trial_runner_test.py b/test/trial_runner_test.py index 021005489..1a7c56a1a 100644 --- a/test/trial_runner_test.py +++ b/test/trial_runner_test.py @@ -2,38 +2,201 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import unittest import os +import time +import unittest import ray +from ray.rllib import _register_all + +from ray.tune import Trainable, TuneError +from ray.tune import register_env, register_trainable, run_experiments +from ray.tune.registry import _default_registry, TRAINABLE_CLASS from ray.tune.trial import Trial, Resources from ray.tune.trial_runner import TrialRunner from ray.tune.variant_generator import generate_trials, grid_search, \ RecursiveDependencyError +class TrainableFunctionApiTest(unittest.TestCase): + def tearDown(self): + ray.worker.cleanup() + _register_all() # re-register the evicted objects + + def testRegisterEnv(self): + register_env("foo", lambda: None) + self.assertRaises(TypeError, lambda: register_env("foo", 2)) + + def testRegisterTrainable(self): + def train(config, reporter): + pass + + class A(object): + pass + + class B(Trainable): + pass + + register_trainable("foo", train) + register_trainable("foo", B) + self.assertRaises(TypeError, lambda: register_trainable("foo", B())) + self.assertRaises(TypeError, lambda: register_trainable("foo", A)) + + def testRewriteEnv(self): + def train(config, reporter): + reporter(timesteps_total=1) + register_trainable("f1", train) + + [trial] = run_experiments({"foo": { + "run": "f1", + "env": "CartPole-v0", + }}) + self.assertEqual(trial.config["env"], "CartPole-v0") + + def testConfigPurity(self): + def train(config, reporter): + assert config == {"a": "b"}, config + reporter(timesteps_total=1) + register_trainable("f1", train) + run_experiments({"foo": { + "run": "f1", + "config": {"a": "b"}, + }}) + + def testBadParams(self): + def f(): + run_experiments({"foo": {}}) + self.assertRaises(TuneError, f) + + def testBadParams2(self): + def f(): + run_experiments({"foo": { + "bah": "this param is not allowed", + }}) + self.assertRaises(TuneError, f) + + def testBadParams3(self): + def f(): + run_experiments({"foo": { + "run": grid_search("invalid grid search"), + }}) + self.assertRaises(TuneError, f) + + def testBadParams4(self): + def f(): + run_experiments({"foo": { + "run": "asdf", + }}) + self.assertRaises(TuneError, f) + + def testBadParams5(self): + def f(): + run_experiments({"foo": { + "run": "PPO", + "stop": {"asdf": 1} + }}) + self.assertRaises(TuneError, f) + + def testBadParams6(self): + def f(): + run_experiments({"foo": { + "run": "PPO", + 
"resources": {"asdf": 1} + }}) + self.assertRaises(TuneError, f) + + def testBadReturn(self): + def train(config, reporter): + reporter() + register_trainable("f1", train) + + def f(): + run_experiments({"foo": { + "run": "f1", + "config": { + "script_min_iter_time_s": 0, + }, + }}) + self.assertRaises(TuneError, f) + + def testEarlyReturn(self): + def train(config, reporter): + reporter(timesteps_total=100, done=True) + time.sleep(99999) + register_trainable("f1", train) + [trial] = run_experiments({"foo": { + "run": "f1", + "config": { + "script_min_iter_time_s": 0, + }, + }}) + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertEqual(trial.last_result.timesteps_total, 100) + + def testAbruptReturn(self): + def train(config, reporter): + reporter(timesteps_total=100) + register_trainable("f1", train) + [trial] = run_experiments({"foo": { + "run": "f1", + "config": { + "script_min_iter_time_s": 0, + }, + }}) + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertEqual(trial.last_result.timesteps_total, 100) + + def testErrorReturn(self): + def train(config, reporter): + raise Exception("uh oh") + register_trainable("f1", train) + + def f(): + run_experiments({"foo": { + "run": "f1", + "config": { + "script_min_iter_time_s": 0, + }, + }}) + self.assertRaises(TuneError, f) + + def testSuccess(self): + def train(config, reporter): + for i in range(100): + reporter(timesteps_total=i) + register_trainable("f1", train) + [trial] = run_experiments({"foo": { + "run": "f1", + "config": { + "script_min_iter_time_s": 0, + }, + }}) + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertEqual(trial.last_result.timesteps_total, 99) + + class VariantGeneratorTest(unittest.TestCase): def testParseToTrials(self): trials = generate_trials({ - "env": "Pong-v0", - "alg": "PPO", + "run": "PPO", "repeat": 2, "config": { + "env": "Pong-v0", "foo": "bar" }, }, "tune-pong") trials = list(trials) self.assertEqual(len(trials), 2) - self.assertEqual(trials[0].env_name, "Pong-v0") - self.assertEqual(trials[0].config, {"foo": "bar"}) - self.assertEqual(trials[0].alg, "PPO") + self.assertEqual(str(trials[0]), "PPO_Pong-v0_0") + self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"}) + self.assertEqual(trials[0].trainable_name, "PPO") self.assertEqual(trials[0].experiment_tag, "0") self.assertEqual(trials[0].local_dir, "/tmp/ray/tune-pong") self.assertEqual(trials[1].experiment_tag, "1") def testEval(self): trials = generate_trials({ - "env": "Pong-v0", + "run": "PPO", "config": { "foo": { "eval": "2 + 2" @@ -48,7 +211,7 @@ class VariantGeneratorTest(unittest.TestCase): def testGridSearch(self): trials = generate_trials({ - "env": "Pong-v0", + "run": "PPO", "config": { "bar": { "grid_search": [True, False] @@ -71,7 +234,7 @@ class VariantGeneratorTest(unittest.TestCase): def testGridSearchAndEval(self): trials = generate_trials({ - "env": "Pong-v0", + "run": "PPO", "config": { "qux": lambda spec: 2 + 2, "bar": grid_search([True, False]), @@ -85,7 +248,7 @@ class VariantGeneratorTest(unittest.TestCase): def testConditionResolution(self): trials = generate_trials({ - "env": "Pong-v0", + "run": "PPO", "config": { "x": 1, "y": lambda spec: spec.config.x + 1, @@ -98,7 +261,7 @@ class VariantGeneratorTest(unittest.TestCase): def testDependentLambda(self): trials = generate_trials({ - "env": "Pong-v0", + "run": "PPO", "config": { "x": grid_search([1, 2]), "y": lambda spec: spec.config.x * 100, @@ -111,7 +274,7 @@ class VariantGeneratorTest(unittest.TestCase): def 
testDependentGridSearch(self): trials = generate_trials({ - "env": "Pong-v0", + "run": "PPO", "config": { "x": grid_search([ lambda spec: spec.config.y * 100, @@ -128,7 +291,7 @@ class VariantGeneratorTest(unittest.TestCase): def testRecursiveDep(self): try: list(generate_trials({ - "env": "Pong-v0", + "run": "PPO", "config": { "foo": lambda spec: spec.config.foo, }, @@ -142,10 +305,11 @@ class VariantGeneratorTest(unittest.TestCase): class TrialRunnerTest(unittest.TestCase): def tearDown(self): ray.worker.cleanup() + _register_all() # re-register the evicted objects def testTrialStatus(self): ray.init() - trial = Trial("CartPole-v0", "__fake") + trial = Trial("__fake") self.assertEqual(trial.status, Trial.PENDING) trial.start() self.assertEqual(trial.status, Trial.RUNNING) @@ -156,11 +320,12 @@ class TrialRunnerTest(unittest.TestCase): def testTrialErrorOnStart(self): ray.init() - trial = Trial("CartPole-v0", "asdf") + _default_registry.register(TRAINABLE_CLASS, "asdf", None) + trial = Trial("asdf") try: trial.start() except Exception as e: - self.assertIn("Unknown algorithm", str(e)) + self.assertIn("a class", str(e)) def testResourceScheduler(self): ray.init(num_cpus=4, num_gpus=1) @@ -170,8 +335,8 @@ class TrialRunnerTest(unittest.TestCase): "resources": Resources(cpu=1, gpu=1), } trials = [ - Trial("CartPole-v0", "__fake", **kwargs), - Trial("CartPole-v0", "__fake", **kwargs)] + Trial("__fake", **kwargs), + Trial("__fake", **kwargs)] for t in trials: runner.add_trial(t) @@ -199,8 +364,8 @@ class TrialRunnerTest(unittest.TestCase): "resources": Resources(cpu=1, gpu=1), } trials = [ - Trial("CartPole-v0", "__fake", **kwargs), - Trial("CartPole-v0", "__fake", **kwargs)] + Trial("__fake", **kwargs), + Trial("__fake", **kwargs)] for t in trials: runner.add_trial(t) @@ -227,9 +392,10 @@ class TrialRunnerTest(unittest.TestCase): "stopping_criterion": {"training_iteration": 1}, "resources": Resources(cpu=1, gpu=1), } + _default_registry.register(TRAINABLE_CLASS, "asdf", None) trials = [ - Trial("CartPole-v0", "asdf", **kwargs), - Trial("CartPole-v0", "__fake", **kwargs)] + Trial("asdf", **kwargs), + Trial("__fake", **kwargs)] for t in trials: runner.add_trial(t) @@ -248,17 +414,17 @@ class TrialRunnerTest(unittest.TestCase): "stopping_criterion": {"training_iteration": 1}, "resources": Resources(cpu=1, gpu=1), } - runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs)) + runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) - self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1) + self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) path = trials[0].checkpoint() kwargs["restore_path"] = path - runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs)) + runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() runner.step() @@ -268,7 +434,7 @@ class TrialRunnerTest(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.TERMINATED) self.assertEqual(trials[1].status, Trial.RUNNING) - self.assertEqual(ray.get(trials[1].agent.get_info.remote()), 1) + self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1) self.addCleanup(os.remove, path) def testPauseThenResume(self): @@ -278,14 +444,14 @@ class TrialRunnerTest(unittest.TestCase): "stopping_criterion": {"training_iteration": 2}, "resources": Resources(cpu=1, gpu=1), } - runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs)) + runner.add_trial(Trial("__fake", **kwargs)) trials = 
runner.get_trials() runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) - self.assertEqual(ray.get(trials[0].agent.get_info.remote()), None) + self.assertEqual(ray.get(trials[0].runner.get_info.remote()), None) - self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1) + self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) trials[0].pause() self.assertEqual(trials[0].status, Trial.PAUSED) @@ -295,7 +461,7 @@ class TrialRunnerTest(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) - self.assertEqual(ray.get(trials[0].agent.get_info.remote()), 1) + self.assertEqual(ray.get(trials[0].runner.get_info.remote()), 1) runner.step() self.assertEqual(trials[0].status, Trial.TERMINATED) diff --git a/test/trial_scheduler_test.py b/test/trial_scheduler_test.py index 568d93f20..a31b0959b 100644 --- a/test/trial_scheduler_test.py +++ b/test/trial_scheduler_test.py @@ -20,8 +20,8 @@ def result(t, rew): class EarlyStoppingSuite(unittest.TestCase): def basicSetup(self, rule): - t1 = Trial("t1", "PPO") # mean is 450, max 900, t_max=10 - t2 = Trial("t2", "PPO") # mean is 450, max 450, t_max=5 + t1 = Trial("PPO") # mean is 450, max 900, t_max=10 + t2 = Trial("PPO") # mean is 450, max 450, t_max=5 for i in range(10): self.assertEqual( rule.on_trial_result(None, t1, result(i, i * 100)), @@ -62,7 +62,7 @@ class EarlyStoppingSuite(unittest.TestCase): t1, t2 = self.basicSetup(rule) rule.on_trial_complete(None, t1, result(10, 1000)) rule.on_trial_complete(None, t2, result(10, 1000)) - t3 = Trial("t3", "PPO") + t3 = Trial("PPO") self.assertEqual( rule.on_trial_result(None, t3, result(1, 10)), TrialScheduler.CONTINUE) @@ -77,7 +77,7 @@ class EarlyStoppingSuite(unittest.TestCase): rule = MedianStoppingRule(grace_period=0, min_samples_required=2) t1, t2 = self.basicSetup(rule) rule.on_trial_complete(None, t1, result(10, 1000)) - t3 = Trial("t3", "PPO") + t3 = Trial("PPO") self.assertEqual( rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.CONTINUE) @@ -91,7 +91,7 @@ class EarlyStoppingSuite(unittest.TestCase): t1, t2 = self.basicSetup(rule) rule.on_trial_complete(None, t1, result(10, 1000)) rule.on_trial_complete(None, t2, result(10, 1000)) - t3 = Trial("t3", "PPO") + t3 = Trial("PPO") self.assertEqual( rule.on_trial_result(None, t3, result(1, 260)), TrialScheduler.CONTINUE) @@ -105,7 +105,7 @@ class EarlyStoppingSuite(unittest.TestCase): t1, t2 = self.basicSetup(rule) rule.on_trial_complete(None, t1, result(10, 1000)) rule.on_trial_complete(None, t2, result(10, 1000)) - t3 = Trial("t3", "PPO") + t3 = Trial("PPO") self.assertEqual( rule.on_trial_result(None, t3, result(1, 260)), TrialScheduler.CONTINUE) @@ -120,8 +120,8 @@ class EarlyStoppingSuite(unittest.TestCase): rule = MedianStoppingRule( grace_period=0, min_samples_required=1, time_attr='training_iteration', reward_attr='neg_mean_loss') - t1 = Trial("t1", "PPO") # mean is 450, max 900, t_max=10 - t2 = Trial("t2", "PPO") # mean is 450, max 450, t_max=5 + t1 = Trial("PPO") # mean is 450, max 900, t_max=10 + t2 = Trial("PPO") # mean is 450, max 450, t_max=5 for i in range(10): self.assertEqual( rule.on_trial_result(None, t1, result2(i, i * 100)), @@ -166,7 +166,7 @@ class HyperbandSuite(unittest.TestCase): (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);""" sched = HyperBandScheduler() for i in range(num_trials): - t = Trial("t%d" % i, "__fake") + t = Trial("__fake") sched.on_trial_add(None, t) runner = _MockTrialRunner() return sched, runner @@ -211,7 +211,7 @@ class 
HyperbandSuite(unittest.TestCase): def advancedSetup(self): sched = self.basicSetup() for i in range(4): - t = Trial("t%d" % (i + 20), "__fake") + t = Trial("__fake") sched.on_trial_add(None, t) self.assertEqual(sched._cur_band_filled(), False) @@ -232,7 +232,7 @@ class HyperbandSuite(unittest.TestCase): sched = HyperBandScheduler() i = 0 while not sched._cur_band_filled(): - t = Trial("t%d" % (i), "__fake") + t = Trial("__fake") sched.on_trial_add(None, t) i += 1 self.assertEqual(len(sched._hyperbands[0]), 5) @@ -244,7 +244,7 @@ class HyperbandSuite(unittest.TestCase): sched = HyperBandScheduler(max_t=810) i = 0 while not sched._cur_band_filled(): - t = Trial("t%d" % (i), "__fake") + t = Trial("__fake") sched.on_trial_add(None, t) i += 1 self.assertEqual(len(sched._hyperbands[0]), 5) @@ -257,7 +257,7 @@ class HyperbandSuite(unittest.TestCase): sched = HyperBandScheduler(max_t=1) i = 0 while len(sched._hyperbands) < 2: - t = Trial("t%d" % (i), "__fake") + t = Trial("__fake") sched.on_trial_add(None, t) i += 1 self.assertEqual(len(sched._hyperbands[0]), 5) @@ -415,7 +415,7 @@ class HyperbandSuite(unittest.TestCase): status = sched.on_trial_result( mock_runner, t, result(init_units, i)) self.assertEqual(status, TrialScheduler.CONTINUE) - t = Trial("t%d" % 100, "__fake") + t = Trial("__fake") sched.on_trial_add(None, t) mock_runner._launch_trial(t) self.assertEqual(len(sched._state["bracket"].current_trials()), 2) @@ -440,7 +440,7 @@ class HyperbandSuite(unittest.TestCase): stats = self.default_statistics() for i in range(stats["max_trials"]): - t = Trial("t%d" % i, "__fake") + t = Trial("__fake") sched.on_trial_add(None, t) runner = _MockTrialRunner()
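
For reference, a minimal sketch of the registry-based function API that the new ``TrainableFunctionApiTest`` cases exercise. Names such as ``my_train`` and ``my_experiment`` are illustrative only; the ``script_min_iter_time_s`` option and the reporter contract (report at least one field, pass ``done=True`` or return to finish) are taken from the tests above.

.. code-block:: python

    from ray.tune import register_trainable, run_experiments

    def my_train(config, reporter):
        # Report progress each iteration; reporter(..., done=True) or simply
        # returning ends the trial, as testEarlyReturn/testAbruptReturn check.
        for i in range(100):
            reporter(timesteps_total=i)

    register_trainable("my_train", my_train)

    run_experiments({
        "my_experiment": {
            "run": "my_train",  # looked up in the shared object registry
            "config": {"script_min_iter_time_s": 0},
        },
    })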
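
Similarly, a sketch of env registration and the new experiment spec shape: ``env`` now sits next to ``run`` and is folded into the trial's ``config`` (see ``testRewriteEnv``), replacing the old top-level ``alg``/``env`` fields. The zero-argument creator follows ``register_env("foo", lambda: None)`` in ``testRegisterEnv``; ``my_cartpole`` is an illustrative name, not part of the patch.

.. code-block:: python

    import gym
    from ray.tune import register_env, register_trainable, run_experiments

    # Creator takes no arguments here, mirroring the form used in testRegisterEnv.
    register_env("my_cartpole", lambda: gym.make("CartPole-v0"))

    def my_train(config, reporter):
        reporter(timesteps_total=1)

    register_trainable("f1", my_train)

    run_experiments({
        "foo": {
            "run": "f1",
            "env": "my_cartpole",  # copied into the trial's config["env"]
        },
    })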
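
Finally, a sketch of driving trials directly with the changed ``Trial`` constructor: the environment is no longer a positional argument, only the registered trainable name is passed, and the remote handle is ``trial.runner`` rather than ``trial.agent``. This is pieced together from the ``TrialRunnerTest`` cases above and uses the ``__fake`` stub trainable that ``_register_all()`` provides for testing.

.. code-block:: python

    import ray
    from ray.rllib import _register_all
    from ray.tune.trial import Trial, Resources
    from ray.tune.trial_runner import TrialRunner

    ray.init(num_cpus=4, num_gpus=1)
    _register_all()  # registers the rllib algorithms plus the "__fake" stub

    runner = TrialRunner()
    kwargs = {
        "stopping_criterion": {"training_iteration": 1},
        "resources": Resources(cpu=1, gpu=1),
    }
    runner.add_trial(Trial("__fake", **kwargs))

    runner.step()  # launches the trial
    [trial] = runner.get_trials()
    # The remote trainable handle is now trial.runner (formerly trial.agent).
    print(ray.get(trial.runner.set_info.remote(1)))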