[tune] Ray Tune API cleanup (#1454)

Remove rllib dep: trainable is now a standalone abstract class that can be easily subclassed.

Clean up hyperband: fix debug string and add an example.

Remove YAML API / ScriptRunner: this was never really used.

Move ray.init() out of run_experiments(): this gives callers more flexibility and is less confusing, since run_experiments() no longer performs an implicit init(). Note that this is a breaking API change for tune.
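
For downstream users the migration is mechanical; a minimal sketch of the new calling convention (the trainable name, experiment name, and reported fields are illustrative):

    import ray
    from ray.tune import register_trainable, run_experiments

    def my_func(config, reporter):
        # Report at least one result; timesteps_total must be specified.
        reporter(timesteps_total=1, mean_accuracy=1.0)

    register_trainable("my_func", my_func)

    ray.init()  # previously done implicitly inside run_experiments()
    run_experiments({
        "my_experiment": {
            "run": "my_func",
            "resources": {"cpu": 1},
        }
    })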
Eric Liang 2018-01-24 16:55:17 -08:00 committed by GitHub
parent a1b01ee7fb
commit 173f1d629a
24 changed files with 486 additions and 421 deletions


@ -262,6 +262,7 @@ in the ``config`` section of the experiments.
.. code-block:: python
import ray
from ray.tune.tune import run_experiments
from ray.tune.variant_generator import grid_search
@ -286,6 +287,7 @@ in the ``config`` section of the experiments.
# put additional experiments to run concurrently here
}
ray.init()
run_experiments(experiment)
Contributing to RLlib


@ -16,8 +16,9 @@ You can find the code for Ray Tune `here on GitHub <https://github.com/ray-proje
Getting Started
---------------
::
.. code-block:: python
import ray
from ray.tune import register_trainable, grid_search, run_experiments
def my_func(config, reporter):
@ -30,6 +31,7 @@ Getting Started
register_trainable("my_func", my_func)
ray.init()
run_experiments({
"my_experiment": {
"run": "my_func",
@ -67,7 +69,7 @@ Ray Tune logs trial results to a unique directory per experiment, e.g. ``~/ray_r
To visualize learning in tensorboard, run:
::
.. code-block:: bash
$ pip install tensorboard
$ tensorboard --logdir=~/ray_results/my_experiment
@ -76,7 +78,7 @@ To visualize learning in tensorboard, run:
To use rllab's VisKit (you may have to install some dependencies), run:
::
.. code-block:: bash
$ git clone https://github.com/rll/rllab.git
$ python rllab/rllab/viskit/frontend.py ~/ray_results/my_experiment
@ -85,7 +87,7 @@ To use rllab's VisKit (you may have to install some dependencies), run:
Finally, to view the results with a `parallel coordinates visualization <https://en.wikipedia.org/wiki/Parallel_coordinates>`__, open `ParallelCoordinatesVisualization.ipynb <https://github.com/ray-project/ray/blob/master/python/ray/tune/ParallelCoordinatesVisualization.ipynb>`__ as follows and run its cells:
::
.. code-block:: bash
$ cd $RAY_HOME/python/ray/tune
$ jupyter-notebook ParallelCoordinatesVisualization.ipynb
@ -97,7 +99,7 @@ In the above example, we specified a grid search over two parameters using the `
The following shows grid search over two nested parameters combined with random sampling from two lambda functions. Note that the value of ``beta`` depends on the value of ``alpha``, which is represented by referencing ``spec.config.alpha`` in the lambda function. This lets you specify conditional parameter distributions.
::
.. code-block:: python
"config": {
"alpha": lambda spec: np.random.uniform(100),
@ -118,60 +120,71 @@ Early Stopping
To reduce costs, long-running trials can often be early stopped if their initial performance is not promising. Ray Tune allows early stopping algorithms to be plugged in on top of existing grid or random searches. This can be enabled by setting the ``scheduler`` parameter of ``run_experiments``, e.g.
.. code-block:: python
run_experiments({...}, scheduler=HyperBandScheduler())
An example of this can be found in `hyperband_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. The progress of one such HyperBand run is shown below.
Note that some trial schedulers such as HyperBand require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster.
::
run_experiments({...}, scheduler=MedianStoppingRule())
== Status ==
Using HyperBand: num_stopped=0 total_brackets=5
Round #0:
Bracket(n=5, r=100, completed=80%): {'PAUSED': 4, 'PENDING': 1}
Bracket(n=8, r=33, completed=23%): {'PAUSED': 4, 'PENDING': 4}
Bracket(n=15, r=11, completed=4%): {'RUNNING': 2, 'PAUSED': 2, 'PENDING': 11}
Bracket(n=34, r=3, completed=0%): {'RUNNING': 2, 'PENDING': 32}
Bracket(n=81, r=1, completed=0%): {'PENDING': 38}
Resources used: 4/4 CPUs, 0/0 GPUs
Result logdir: /home/eric/ray_results/hyperband_test
PAUSED trials:
- my_class_0_height=99,width=43: PAUSED [pid=11664], 0 s, 100 ts, 97.1 rew
- my_class_11_height=85,width=81: PAUSED [pid=11771], 0 s, 33 ts, 32.8 rew
- my_class_12_height=0,width=52: PAUSED [pid=11785], 0 s, 33 ts, 0 rew
- my_class_19_height=44,width=88: PAUSED [pid=11811], 0 s, 11 ts, 5.47 rew
- my_class_27_height=96,width=84: PAUSED [pid=11840], 0 s, 11 ts, 12.5 rew
... 5 more not shown
PENDING trials:
- my_class_10_height=12,width=25: PENDING
- my_class_13_height=90,width=45: PENDING
- my_class_14_height=69,width=45: PENDING
- my_class_15_height=41,width=11: PENDING
- my_class_16_height=57,width=69: PENDING
... 81 more not shown
RUNNING trials:
- my_class_23_height=75,width=51: RUNNING [pid=11843], 0 s, 1 ts, 1.47 rew
- my_class_26_height=16,width=48: RUNNING
- my_class_31_height=40,width=10: RUNNING
- my_class_53_height=28,width=96: RUNNING
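
A scheduler that produces output like the above can be constructed as in the new ``hyperband_example.py`` (this mirrors the example added in this commit):

.. code-block:: python

    from ray.tune.hyperband import HyperBandScheduler

    hyperband = HyperBandScheduler(
        time_attr="timesteps_total", reward_attr="episode_reward_mean",
        max_t=100)

    run_experiments({...}, scheduler=hyperband)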
Currently we support the following early stopping algorithms, or you can write your own that implements the `TrialScheduler <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial_scheduler.py>`__ interface:
Currently we support the following early stopping algorithms, or you can write your own that implements the `TrialScheduler <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial_scheduler.py>`__ interface.
.. autoclass:: ray.tune.median_stopping_rule.MedianStoppingRule
.. autoclass:: ray.tune.hyperband.HyperBandScheduler
Checkpointing support
---------------------
Trial Checkpointing
-------------------
To enable checkpoint / resume, the full ``Trainable`` API must be implemented (though as shown in the examples above, you can get away with just supplying a ``train(config, reporter)`` func if you don't need checkpointing). Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand. For example, all `RLlib agents <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agent.py>`__ implement the ``Trainable`` API.
To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand.
.. autoclass:: ray.tune.trainable.Trainable
:members:
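
For reference, here is a condensed version of the ``MyTrainableClass`` added by this change in ``hyperband_example.py``, implementing ``_setup``, ``_train``, ``_save``, and ``_restore``:

.. code-block:: python

    import json
    import os

    import numpy as np

    from ray.tune import Trainable, TrainingResult

    class MyTrainableClass(Trainable):
        """Example trainable whose learning curve is a scaled tanh."""

        def _setup(self):
            self.timestep = 0

        def _train(self):
            self.timestep += 1
            v = self.config["height"] * np.tanh(
                float(self.timestep) / self.config["width"])
            return TrainingResult(
                episode_reward_mean=v, timesteps_this_iter=1)

        def _save(self):
            path = os.path.join(self.logdir, "checkpoint")
            with open(path, "w") as f:
                f.write(json.dumps({"timestep": self.timestep}))
            return path

        def _restore(self, checkpoint_path):
            with open(checkpoint_path) as f:
                self.timestep = json.loads(f.read())["timestep"]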
Resource Allocation
-------------------
Ray Tune runs each trial as a Ray actor, allocating the specified GPU and CPU ``resources`` to each actor (defaulting to 1 CPU per trial). A trial will not be scheduled unless at least that amount of resources is available in the cluster, preventing the cluster from being overloaded.
If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will also want to set ``driver_cpu_limit`` or ``driver_gpu_limit`` to tell Ray not to assign the entire resource reservation to your top-level trainable function, as described in `trial.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial.py>`__.
Command-line JSON/YAML API
--------------------------
The JSON config passed to ``run_experiments`` can also be put in a JSON or YAML file, and the experiments run using the ``tune.py`` script. This supports the same functionality as the Python API, e.g.:
::
cd ray/python/tune
./tune.py -f examples/tune_mnist_ray.yaml --scheduler=MedianStoppingRule
For more examples of experiments described by YAML files, see `RLlib tuned examples <https://github.com/ray-project/ray/tree/master/python/ray/rllib/tuned_examples>`__.
Running in a large cluster
--------------------------
The ``run_experiments`` also takes any arguments that ``ray.init()`` does. This can be used to pass in the redis address of a multi-node Ray cluster. For more details, check out the `tune.py script <https://github.com/ray-project/ray/blob/master/python/ray/tune/tune.py>`__.
If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will also want to set ``driver_cpu_limit`` or ``driver_gpu_limit`` to tell Ray not to assign the entire resource reservation to your top-level trainable function, as described in `trial.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial.py>`__. For example, if a trainable class requires 1 GPU itself, but will launch 4 actors each using another GPU, then it should set ``"gpu": 5, "driver_gpu_limit": 1``.
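
In experiment-spec form that example would look roughly as follows (a sketch; the trainable name is illustrative, and placing ``driver_gpu_limit`` inside ``resources`` is an assumption based on the description above):

.. code-block:: python

    run_experiments({
        "my_gpu_experiment": {
            "run": "my_trainable",
            "resources": {"cpu": 1, "gpu": 5, "driver_gpu_limit": 1},
        }
    })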
Client API
----------
You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, start your experiment with a flag, either from the command-line, e.g.:
You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, start your experiment with ``with_server=True``:
::
cd ray/python/tune
./tune.py -f examples/tune_mnist_ray.yaml --server=True --server-port=4321
Or within the Python API, e.g.:
::
.. code-block:: python
run_experiments({...}, with_server=True, server_port=4321)


@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from env import CarlaEnv, ENV_CONFIG
@ -25,6 +26,7 @@ env_config.update({
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
run_experiments({
"carla-a3c": {
"run": "A3C",


@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from env import CarlaEnv, ENV_CONFIG
@ -25,6 +26,7 @@ env_config.update({
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
run_experiments({
"carla-dqn": {
"run": "DQN",


@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from env import CarlaEnv, ENV_CONFIG
@ -25,6 +26,7 @@ env_config.update({
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
run_experiments({
"carla-ppo": {
"run": "PPO",


@ -27,6 +27,7 @@ register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
redis_address = ray.services.get_node_ip_address() + ":6379"
ray.init(redis_address=redis_address)
run_experiments({
"carla-a3c": {
"run": "A3C",
@ -50,4 +51,4 @@ run_experiments({
"num_workers": 2,
},
},
}, redis_address=redis_address)
})


@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from env import CarlaEnv, ENV_CONFIG
@ -23,6 +24,7 @@ env_config.update({
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init()
run_experiments({
"carla-dqn": {
"run": "DQN",


@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.tune import register_env, run_experiments
from env import CarlaEnv, ENV_CONFIG
@ -22,6 +23,7 @@ env_config.update({
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()
ray.init(redirect_output=True)
run_experiments({
"carla": {
"run": "PPO",
@ -55,4 +57,4 @@ run_experiments({
}
},
},
}, redirect_output=True)
})


@ -8,6 +8,7 @@ import gym
from gym.spaces import Discrete, Box
from gym.envs.registration import EnvSpec
import ray
from ray.tune import run_experiments
from ray.tune.registry import register_env
@ -41,6 +42,7 @@ class SimpleCorridor(gym.Env):
if __name__ == "__main__":
env_creator_name = "corridor"
register_env(env_creator_name, lambda config: SimpleCorridor(config))
ray.init()
run_experiments({
"demo": {
"run": "PPO",


@ -2,15 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Note: do not introduce unnecessary library dependencies here, e.g. gym
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
# This file is imported from the tune module in order to register RLlib agents.
from ray.tune.registry import register_trainable
from ray.rllib.agent import get_agent_class
def _register_all():
for key in [
"PPO", "ES", "DQN", "A3C", "BC", "__fake", "__sigmoid_fake_data"]:
try:
from ray.rllib.agent import get_agent_class
register_trainable(key, get_agent_class(key))
except ImportError as e:
print("Warning: could not import {}: {}".format(key, e))


@ -2,25 +2,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import logging
import numpy as np
import io
import os
import gzip
import pickle
import shutil
import tempfile
import time
import uuid
# Note: avoid introducing unnecessary library dependencies here, e.g. gym
# until https://github.com/ray-project/ray/issues/1144 is resolved
import tensorflow as tf
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import ENV_CREATOR, get_registry
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
logger = logging.getLogger(__name__)
@ -66,7 +55,7 @@ class Agent(Trainable):
env_creator (func): Function that creates a new training env.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Tune object registry, for registering user-defined
registry (obj): Tune object registry which holds user-registered
classes and objects by name.
"""
@ -83,183 +72,43 @@ class Agent(Trainable):
env (str): Name of the environment to use. Note that this can also
be specified as the `env` key in config.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, it will be assumed empty.
If unspecified, the default registry will be used.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
self._initialize_ok = False
self._experiment_id = uuid.uuid4().hex
env = env or config.get("env")
# Agents allow env ids to be passed directly to the constructor.
self._env_id = env or config.get("env")
Trainable.__init__(self, config, registry, logger_creator)
def _setup(self):
env = self._env_id
if env:
config["env"] = env
if registry and registry.contains(ENV_CREATOR, env):
self.env_creator = registry.get(ENV_CREATOR, env)
self.config["env"] = env
if self.registry and self.registry.contains(ENV_CREATOR, env):
self.env_creator = self.registry.get(ENV_CREATOR, env)
else:
import gym # soft dependency
self.env_creator = lambda env_config: gym.make(env)
else:
self.env_creator = lambda env_config: None
self.config = self._default_config.copy()
self.registry = registry
self.config = _deep_update(self.config, config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
if logger_creator:
self._result_logger = logger_creator(self.config)
self.logdir = self._result_logger.logdir
else:
logdir_suffix = "{}_{}_{}".format(
env, self._agent_name,
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
if not os.path.exists(DEFAULT_RESULTS_DIR):
os.makedirs(DEFAULT_RESULTS_DIR)
self.logdir = tempfile.mkdtemp(
prefix=logdir_suffix, dir=DEFAULT_RESULTS_DIR)
self._result_logger = UnifiedLogger(self.config, self.logdir, None)
self._iteration = 0
self._time_total = 0.0
self._timesteps_total = 0
# Merge the supplied config with the class default
merged_config = self._default_config.copy()
merged_config = _deep_update(merged_config, self.config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
self.config = merged_config
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
with tf.Graph().as_default():
self._init()
self._initialize_ok = True
def _init(self):
"""Subclasses should override this for custom initialization."""
raise NotImplementedError
def train(self):
"""Runs one logical iteration of training.
Returns:
A TrainingResult that describes training progress.
"""
if not self._initialize_ok:
raise ValueError(
"Agent initialization failed, see previous errors")
start = time.time()
result = self._train()
self._iteration += 1
if result.time_this_iter_s is not None:
time_this_iter = result.time_this_iter_s
else:
time_this_iter = time.time() - start
assert result.timesteps_this_iter is not None
self._time_total += time_this_iter
self._timesteps_total += result.timesteps_this_iter
now = datetime.today()
result = result._replace(
experiment_id=self._experiment_id,
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
timestamp=int(time.mktime(now.timetuple())),
training_iteration=self._iteration,
timesteps_total=self._timesteps_total,
time_this_iter_s=time_this_iter,
time_total_s=self._time_total,
pid=os.getpid(),
hostname=os.uname()[1])
self._result_logger.on_result(result)
return result
def save(self):
"""Saves the current model state to a checkpoint.
Returns:
Checkpoint path that may be passed to restore().
"""
checkpoint_path = self._save()
pickle.dump(
[self._experiment_id, self._iteration, self._timesteps_total,
self._time_total],
open(checkpoint_path + ".rllib_metadata", "wb"))
return checkpoint_path
def save_to_object(self):
"""Saves the current model state to a Python object. It also
saves to disk but does not return the checkpoint path.
Returns:
Object holding checkpoint data.
"""
checkpoint_prefix = self.save()
data = {}
base_dir = os.path.dirname(checkpoint_prefix)
for path in os.listdir(base_dir):
path = os.path.join(base_dir, path)
if path.startswith(checkpoint_prefix):
data[os.path.basename(path)] = open(path, "rb").read()
out = io.BytesIO()
with gzip.GzipFile(fileobj=out, mode="wb") as f:
compressed = pickle.dumps({
"checkpoint_name": os.path.basename(checkpoint_prefix),
"data": data,
})
print("Saving checkpoint to object store, {} bytes".format(
len(compressed)))
f.write(compressed)
return out.getvalue()
def restore(self, checkpoint_path):
"""Restores training state from a given model checkpoint.
These checkpoints are returned from calls to save().
"""
self._restore(checkpoint_path)
metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb"))
self._experiment_id = metadata[0]
self._iteration = metadata[1]
self._timesteps_total = metadata[2]
self._time_total = metadata[3]
def restore_from_object(self, obj):
"""Restores training state from a checkpoint object.
These checkpoints are returned from calls to save_to_object().
"""
out = io.BytesIO(obj)
info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
data = info["data"]
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
for file_name, file_contents in data.items():
with open(os.path.join(tmpdir, file_name), "wb") as f:
f.write(file_contents)
self.restore(checkpoint_path)
shutil.rmtree(tmpdir)
def stop(self):
"""Releases all resources used by this agent."""
if self._initialize_ok:
self._result_logger.close()
self._stop()
def _stop(self):
"""Subclasses should override this for custom stopping."""
pass
def compute_action(self, observation):
"""Computes an action using the current trained policy."""
@ -283,21 +132,6 @@ class Agent(Trainable):
raise NotImplementedError
def _train(self):
"""Subclasses should override this to implement train()."""
raise NotImplementedError
def _save(self):
"""Subclasses should override this to implement save()."""
raise NotImplementedError
def _restore(self, checkpoint_path):
"""Subclasses should override this to implement restore()."""
raise NotImplementedError
class _MockAgent(Agent):
"""Mock agent for use in tests"""


@ -8,6 +8,7 @@ import argparse
import sys
import yaml
import ray
from ray.tune.config_parser import make_parser, resources_to_json
from ray.tune.tune import _make_scheduler, run_experiments
@ -76,7 +77,7 @@ if __name__ == "__main__":
if not exp.get("env") and not exp.get("config", {}).get("env"):
parser.error("the following arguments are required: --env")
run_experiments(
experiments, scheduler=_make_scheduler(args),
ray.init(
redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
run_experiments(experiments, scheduler=_make_scheduler(args))


@ -6,13 +6,10 @@ from ray.tune.error import TuneError
from ray.tune.tune import run_experiments
from ray.tune.registry import register_env, register_trainable
from ray.tune.result import TrainingResult
from ray.tune.script_runner import ScriptRunner
from ray.tune.trainable import Trainable
from ray.tune.variant_generator import grid_search
register_trainable("script", ScriptRunner)
__all__ = [
"Trainable",
"TrainingResult",


@ -0,0 +1,70 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import random
import numpy as np
import ray
from ray.tune import Trainable, TrainingResult, register_trainable, \
run_experiments
from ray.tune.hyperband import HyperBandScheduler
class MyTrainableClass(Trainable):
"""Example agent whose learning curve is a random sigmoid.
The dummy hyperparameters "width" and "height" determine the slope and
maximum reward value reached.
"""
def _setup(self):
self.timestep = 0
def _train(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config["width"])
v *= self.config["height"]
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy (see tune/result.py).
return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1)
def _save(self):
path = os.path.join(self.logdir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": self.timestep}))
return path
def _restore(self, checkpoint_path):
with open(checkpoint_path) as f:
self.timestep = json.loads(f.read())["timestep"]
register_trainable("my_class", MyTrainableClass)
if __name__ == "__main__":
ray.init()
# Hyperband early stopping, configured with `episode_reward_mean` as the
# objective and `timesteps_total` as the time unit.
hyperband = HyperBandScheduler(
time_attr="timesteps_total", reward_attr="episode_reward_mean",
max_t=100)
run_experiments({
"hyperband_test": {
"run": "my_class",
"repeat": 100,
"resources": {"cpu": 1, "gpu": 0},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random()),
},
}
}, scheduler=hyperband)


@ -33,6 +33,7 @@ import sys
import tempfile
import time
import ray
from ray.tune import grid_search, run_experiments, register_trainable
from tensorflow.examples.tutorials.mnist import input_data
@ -222,4 +223,5 @@ if __name__ == '__main__':
if args.fast:
mnist_spec['stop']['training_iteration'] = 2
ray.init()
run_experiments({'tune_mnist_test': mnist_spec})


@ -1,14 +0,0 @@
tune_mnist:
run: script
repeat: 2
resources:
cpu: 1
stop:
mean_accuracy: 0.99
time_total_s: 600
config:
script_file_path: examples/tune_mnist_ray.py
script_entrypoint: train
script_min_iter_time_s: 1
activation:
grid_search: ['relu', 'elu', 'tanh']


@ -2,15 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import os
import sys
import time
import threading
import traceback
from ray.rllib.agent import Agent
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TrainingResult
@ -53,12 +50,6 @@ class StatusReporter(object):
DEFAULT_CONFIG = {
# path of the script to run
"script_file_path": "/path/to/file.py",
# name of train function in the file, e.g. train(config, status_reporter)
"script_entrypoint": "train",
# batch results to at least this granularity
"script_min_iter_time_s": 1,
}
@ -85,67 +76,37 @@ class _RunnerThread(threading.Thread):
self._status_reporter._done = True
def import_function(file_path, function_name):
# strong assumption here that we're in a new process
file_path = os.path.expanduser(file_path)
sys.path.insert(0, os.path.dirname(file_path))
if hasattr(importlib, "util"):
# Python 3.4+
spec = importlib.util.spec_from_file_location(
"external_file", file_path)
external_file = importlib.util.module_from_spec(spec)
spec.loader.exec_module(external_file)
elif hasattr(importlib, "machinery"):
# Python 3.3
from importlib.machinery import SourceFileLoader
external_file = SourceFileLoader(
"external_file", file_path).load_module()
else:
# Python 2.x
import imp
external_file = imp.load_source("external_file", file_path)
if not external_file:
raise TuneError("Unable to import file at {}".format(file_path))
return getattr(external_file, function_name)
class FunctionRunner(Trainable):
"""Trainable that runs a user function returning training results.
This mode of execution does not support checkpoint/restore."""
class ScriptRunner(Agent):
"""Agent that runs a user script returning training results."""
_agent_name = "script"
_name = "func"
_default_config = DEFAULT_CONFIG
_allow_unknown_configs = True
def _init(self):
def _setup(self):
entrypoint = self._trainable_func()
if not entrypoint:
entrypoint = import_function(
self.config["script_file_path"],
self.config["script_entrypoint"])
self._status_reporter = StatusReporter()
scrubbed_config = self.config.copy()
for k in self._default_config:
del scrubbed_config[k]
if k in scrubbed_config:
del scrubbed_config[k]
self._runner = _RunnerThread(
entrypoint, scrubbed_config, self._status_reporter)
self._start_time = time.time()
self._last_reported_time = self._start_time
self._last_reported_timestep = 0
self._runner.start()
# Subclasses can override this to set the trainable func
# TODO(ekl) this isn't a very clean layering, we should refactor it
def _trainable_func(self):
return None
"""Subclasses can override this to set the trainable func."""
def train(self):
if not self._initialize_ok:
raise ValueError(
"Agent initialization failed, see previous errors")
now = time.time()
time.sleep(self.config["script_min_iter_time_s"])
raise NotImplementedError
def _train(self):
time.sleep(
self.config.get(
"script_min_iter_time_s",
self._default_config["script_min_iter_time_s"]))
result = self._status_reporter._get_and_clear_status()
while result is None:
time.sleep(1)
@ -153,29 +114,10 @@ class ScriptRunner(Agent):
if result.timesteps_total is None:
raise TuneError("Must specify timesteps_total in result", result)
# Include the negative loss to use as a stopping condition
if result.mean_loss is not None:
neg_loss = -result.mean_loss
else:
neg_loss = result.neg_mean_loss
result = result._replace(
experiment_id=self._experiment_id,
neg_mean_loss=neg_loss,
training_iteration=self.iteration,
time_this_iter_s=now - self._last_reported_time,
timesteps_this_iter=(
result.timesteps_total - self._last_reported_timestep),
time_total_s=now - self._start_time,
pid=os.getpid(),
hostname=os.uname()[1])
if result.timesteps_total:
self._last_reported_timestep = result.timesteps_total
self._last_reported_time = now
self._iteration += 1
self._result_logger.on_result(result)
result.timesteps_total - self._last_reported_timestep))
self._last_reported_timestep = result.timesteps_total
return result


@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
@ -61,7 +62,8 @@ class HyperBandScheduler(FIFOScheduler):
max_t (int): max time units per trial. Trials will be stopped after
max_t time units (determined by time_attr) have passed.
The HyperBand scheduler automatically tries to determine a
reasonable number of brackets based on this.
reasonable number of brackets based on this. The scheduler will
terminate trials after this time has passed.
"""
def __init__(
@ -210,7 +212,8 @@ class HyperBandScheduler(FIFOScheduler):
List of trials not used since all trials are tracked as state
of scheduler. If iteration is occupied (ie, no trials to run),
then look into next iteration."""
then look into next iteration.
"""
for hyperband in self._hyperbands:
for bracket in sorted(hyperband,
@ -222,18 +225,14 @@ class HyperBandScheduler(FIFOScheduler):
return None
def debug_string(self):
# TODO(rliaw): This debug string needs work
brackets = [
"({0}/{1})".format(
len(bracket._live_trials), len(bracket._all_trials))
for band in self._hyperbands for bracket in band]
return " ".join([
"Using HyperBand:",
"num_stopped={}".format(self._num_stopped),
"total_brackets={}".format(
sum(len(band) for band in self._hyperbands)),
" ".join(brackets)
])
out = "Using HyperBand: "
out += "num_stopped={} total_brackets={}".format(
self._num_stopped, sum(len(band) for band in self._hyperbands))
for i, band in enumerate(self._hyperbands):
out += "\nRound #{}:".format(i)
for bracket in band:
out += "\n {}".format(bracket)
return out
class Bracket():
@ -370,10 +369,9 @@ class Bracket():
status = ", ".join([
"n={}".format(self._n),
"r={}".format(self._r),
"progress={}".format(self.completion_percentage())
"completed={}%".format(int(100 * self.completion_percentage()))
])
return "Bracket({})".format(status)
def debug_string(self):
trials = ", ".join([t.status for t in self._live_trials])
return "{}[{}]".format(self, trials)
counts = collections.Counter()
for t in self._all_trials:
counts[t.status] += 1
return "Bracket({}): {}".format(status, dict(counts))


@ -2,55 +2,261 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import gzip
import io
import os
import pickle
import shutil
import tempfile
import time
import uuid
from ray.tune import TuneError
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
class Trainable(object):
"""Interface for trainable models, functions, etc.
"""Abstract class for trainable models, functions, etc.
Implementing this interface is required to use Ray.tune's full
functionality, though you can also get away with supplying just a
`my_train(config, reporter)` function and calling:
A call to ``train()`` on a trainable will execute one logical iteration of
training. As a rule of thumb, the execution time of one train call should
be large enough to avoid overheads (i.e. more than a few seconds), but
short enough to report progress periodically (i.e. at most a few minutes).
Calling ``save()`` should save the training state of a trainable to disk,
and ``restore(path)`` should restore a trainable to the given state.
Generally you only need to implement ``_train``, ``_save``, and
``_restore`` here when subclassing Trainable.
Note that, if you don't require checkpoint/restore functionality, then
instead of implementing this class you can also get away with supplying
just a `my_train(config, reporter)` function and calling:
``register_trainable("my_func", train)``
to register it for use with tune. The function will be automatically
converted to this interface (sans checkpoint functionality)."""
to register it for use with Tune. The function will be automatically
converted to this interface (sans checkpoint functionality).
Attributes:
config (obj): The hyperparam configuration for this trial.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Tune object registry which holds user-registered
classes and objects by name.
"""
def __init__(self, config={}, registry=None, logger_creator=None):
"""Initialize an Trainable.
Subclasses should prefer defining ``_setup()`` instead of overriding
``__init__()`` directly.
Args:
config (dict): Trainable-specific configuration data.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, the default registry will be used.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
if registry is None:
from ray.tune.registry import get_registry
registry = get_registry()
self._initialize_ok = False
self._experiment_id = uuid.uuid4().hex
self.config = config
self.registry = registry
if logger_creator:
self._result_logger = logger_creator(self.config)
self.logdir = self._result_logger.logdir
else:
logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
if not os.path.exists(DEFAULT_RESULTS_DIR):
os.makedirs(DEFAULT_RESULTS_DIR)
self.logdir = tempfile.mkdtemp(
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
self._result_logger = UnifiedLogger(self.config, self.logdir, None)
self._iteration = 0
self._time_total = 0.0
self._timesteps_total = 0
self._setup()
self._initialize_ok = True
def train(self):
"""Runs one logical iteration of training.
Subclasses should override ``_train()`` instead to return results.
This method auto-fills many fields, so only ``timesteps_this_iter``
is required to be present.
Returns:
A TrainingResult that describes training progress.
"""
raise NotImplementedError
if not self._initialize_ok:
raise ValueError(
"Trainable initialization failed, see previous errors")
start = time.time()
result = self._train()
self._iteration += 1
if result.time_this_iter_s is not None:
time_this_iter = result.time_this_iter_s
else:
time_this_iter = time.time() - start
if result.timesteps_this_iter is None:
raise TuneError(
"Must specify timesteps_this_iter in result", result)
self._time_total += time_this_iter
self._timesteps_total += result.timesteps_this_iter
# Include the negative loss to use as a stopping condition
if result.mean_loss is not None:
neg_loss = -result.mean_loss
else:
neg_loss = result.neg_mean_loss
now = datetime.today()
result = result._replace(
experiment_id=self._experiment_id,
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
timestamp=int(time.mktime(now.timetuple())),
training_iteration=self._iteration,
timesteps_total=self._timesteps_total,
time_this_iter_s=time_this_iter,
time_total_s=self._time_total,
neg_mean_loss=neg_loss,
pid=os.getpid(),
hostname=os.uname()[1])
self._result_logger.on_result(result)
return result
def save(self):
"""Saves the current model state to a checkpoint.
Subclasses should override ``_save()`` instead to save state.
This method dumps additional metadata alongside the saved path.
Returns:
Checkpoint path that may be passed to restore().
"""
raise NotImplementedError
checkpoint_path = self._save()
pickle.dump(
[self._experiment_id, self._iteration, self._timesteps_total,
self._time_total],
open(checkpoint_path + ".tune_metadata", "wb"))
return checkpoint_path
def save_to_object(self):
"""Saves the current model state to a Python object. It also
saves to disk but does not return the checkpoint path.
Returns:
Object holding checkpoint data.
"""
checkpoint_prefix = self.save()
data = {}
base_dir = os.path.dirname(checkpoint_prefix)
for path in os.listdir(base_dir):
path = os.path.join(base_dir, path)
if path.startswith(checkpoint_prefix):
data[os.path.basename(path)] = open(path, "rb").read()
out = io.BytesIO()
with gzip.GzipFile(fileobj=out, mode="wb") as f:
compressed = pickle.dumps({
"checkpoint_name": os.path.basename(checkpoint_prefix),
"data": data,
})
print("Saving checkpoint to object store, {} bytes".format(
len(compressed)))
f.write(compressed)
return out.getvalue()
def restore(self, checkpoint_path):
"""Restores training state from a given model checkpoint.
These checkpoints are returned from calls to save().
Subclasses should override ``_restore()`` instead to restore state.
This method restores additional metadata saved with the checkpoint.
"""
self._restore(checkpoint_path)
metadata = pickle.load(open(checkpoint_path + ".tune_metadata", "rb"))
self._experiment_id = metadata[0]
self._iteration = metadata[1]
self._timesteps_total = metadata[2]
self._time_total = metadata[3]
def restore_from_object(self, obj):
"""Restores training state from a checkpoint object.
These checkpoints are returned from calls to save_to_object().
"""
out = io.BytesIO(obj)
info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
data = info["data"]
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
for file_name, file_contents in data.items():
with open(os.path.join(tmpdir, file_name), "wb") as f:
f.write(file_contents)
self.restore(checkpoint_path)
shutil.rmtree(tmpdir)
def stop(self):
"""Releases all resources used by this trainable."""
if self._initialize_ok:
self._result_logger.close()
self._stop()
def _train(self):
"""Subclasses should override this to implement train()."""
raise NotImplementedError
def stop(self):
"""Releases all resources used by this class."""
def _save(self):
"""Subclasses should override this to implement save()."""
raise NotImplementedError
def _restore(self, checkpoint_path):
"""Subclasses should override this to implement restore()."""
raise NotImplementedError
def _setup(self):
"""Subclasses should override this for custom initialization."""
pass
def _stop(self):
"""Subclasses should override this for any cleanup on stop."""
pass
def wrap_function(train_func):
from ray.tune.script_runner import ScriptRunner
from ray.tune.function_runner import FunctionRunner
class WrappedFunc(ScriptRunner):
class WrappedFunc(FunctionRunner):
def _trainable_func(self):
return train_func


@ -2,18 +2,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
from datetime import datetime
import tempfile
import time
import traceback
import ray
import os
from collections import namedtuple
from ray.utils import random_string, binary_to_hex
from ray.tune import TuneError
from ray.tune.logger import NoopLogger, UnifiedLogger
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
from ray.utils import random_string, binary_to_hex
DEBUG_PRINT_INTERVAL = 5
class Resources(
@ -106,6 +109,7 @@ class Trial(object):
self.location = None
self.logdir = None
self.result_logger = None
self.last_debug = 0
self.trial_id = binary_to_hex(random_string())[:8]
def start(self):
@ -293,8 +297,10 @@ class Trial(object):
def update_last_result(self, result, terminate=False):
if terminate:
result = result._replace(done=True)
print("TrainingResult for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
if terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL:
print("TrainingResult for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
self.last_result = result
self.result_logger.on_result(self.last_result)


@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import ray
import time
@ -13,6 +14,9 @@ from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
MAX_DEBUG_TRIALS = 20
class TrialRunner(object):
"""A TrialRunner implements the event loop for scheduling trials on Ray.
@ -127,9 +131,40 @@ class TrialRunner(object):
self._scheduler_alg.on_trial_add(self, trial)
self._trials.append(trial)
def debug_string(self):
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
"""Returns a human readable message for printing to the console."""
messages = self._debug_messages()
states = collections.defaultdict(set)
limit_per_state = collections.Counter()
for t in self._trials:
states[t.status].add(t)
# Show at most max_debug total, but divide the limit fairly
while max_debug > 0:
start_num = max_debug
for s in states:
if limit_per_state[s] >= len(states[s]):
continue
max_debug -= 1
limit_per_state[s] += 1
if max_debug == start_num:
break
for local_dir in sorted(set([t.local_dir for t in self._trials])):
messages.append("Result logdir: {}".format(local_dir))
for state, trials in sorted(states.items()):
limit = limit_per_state[state]
messages.append("{} trials:".format(state))
for t in sorted(
trials, key=lambda t: t.experiment_tag)[:limit]:
messages.append(" - {}:\t{}".format(t, t.progress_string()))
if len(trials) > limit:
messages.append(" ... {} more not shown".format(
len(trials) - limit))
return "\n".join(messages) + "\n"
def _debug_messages(self):
messages = ["== Status =="]
messages.append(self._scheduler_alg.debug_string())
if self._resources_initialized:
@ -139,13 +174,7 @@ class TrialRunner(object):
self._avail_resources.cpu,
self._committed_resources.gpu,
self._avail_resources.gpu))
for local_dir in sorted(set([t.local_dir for t in self._trials])):
messages.append("Result logdir: {}".format(local_dir))
for t in self._trials:
if t.local_dir == local_dir:
messages.append(
" - {}:\t{}".format(t, t.progress_string()))
return "\n".join(messages) + "\n"
return messages
def has_resources(self, resources):
"""Returns whether this runner has at least the specified resources."""

python/ray/tune/tune.py Executable file → Normal file

@ -1,55 +1,19 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import sys
import ray
import time
from ray.tune import TuneError
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.trial import Trial
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
from ray.tune.trial_runner import TrialRunner
from ray.tune.trial_scheduler import FIFOScheduler
from ray.tune.web_server import TuneServer
from ray.tune.variant_generator import generate_trials
EXAMPLE_USAGE = """
MNIST tuning example:
./tune.py -f examples/tune_mnist_ray.yaml
"""
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description="Tune hyperparameters with Ray.",
epilog=EXAMPLE_USAGE)
# See also the base parser definition in ray/tune/config_parser.py
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
parser.add_argument("--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
parser.add_argument("--scheduler", default="FIFO", type=str,
help="FIFO, MedianStopping, or HyperBand")
parser.add_argument("--scheduler-config", default="{}", type=json.loads,
help="Config options to pass to the scheduler.")
parser.add_argument("--server", default=False, type=bool,
help="Option to launch Tune Server")
parser.add_argument("--server-port", default=TuneServer.DEFAULT_PORT,
type=int, help="Option to launch Tune Server")
parser.add_argument("-f", "--config-file", required=True, type=str,
help="Read experiment options from this JSON/YAML file.")
_SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
@ -67,7 +31,11 @@ def _make_scheduler(args):
def run_experiments(experiments, scheduler=None, with_server=False,
server_port=TuneServer.DEFAULT_PORT, **ray_args):
server_port=TuneServer.DEFAULT_PORT):
# Make sure rllib agents are registered
from ray import rllib # noqa # pylint: disable=unused-import
if scheduler is None:
scheduler = FIFOScheduler()
@ -77,13 +45,16 @@ def run_experiments(experiments, scheduler=None, with_server=False,
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
runner.add_trial(trial)
print(runner.debug_string())
ray.init(**ray_args)
print(runner.debug_string(max_debug=99999))
last_debug = 0
while not runner.is_finished():
runner.step()
print(runner.debug_string())
if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
print(runner.debug_string())
last_debug = time.time()
print(runner.debug_string(max_debug=99999))
for trial in runner.get_trials():
# TODO(rliaw): What about errored?
@ -91,14 +62,3 @@ def run_experiments(experiments, scheduler=None, with_server=False,
raise TuneError("Trial did not complete", trial)
return runner.get_trials()
if __name__ == "__main__":
import yaml
args = parser.parse_args(sys.argv[1:])
with open(args.config_file) as f:
experiments = yaml.load(f)
run_experiments(
experiments, _make_scheduler(args), with_server=args.server,
server_port=args.server_port, redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)


@ -20,6 +20,9 @@ from ray.tune.variant_generator import generate_trials, grid_search, \
class TrainableFunctionApiTest(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects


@ -509,4 +509,6 @@ class HyperbandSuite(unittest.TestCase):
if __name__ == "__main__":
from ray.rllib import _register_all
_register_all()
unittest.main(verbosity=2)