mirror of
https://github.com/vale981/ray
synced 2025-04-23 06:25:52 -04:00
[tune] Ray Tune API cleanup (#1454)
- Remove rllib dep: Trainable is now a standalone abstract class that can be easily subclassed.
- Clean up HyperBand: fix the debug string and add an example.
- Remove the YAML API / ScriptRunner: this was never really used.
- Move ray.init() out of run_experiments(): this provides greater flexibility and should be less confusing, since there isn't an implicit init() done there. Note that this is a breaking API change for Tune.
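In practice, the last change means callers now initialize Ray themselves before calling ``run_experiments``. A minimal sketch of the new calling pattern (the trainable and experiment names are placeholders, not taken from this commit):

.. code-block:: python

    import ray
    from ray.tune import register_trainable, run_experiments

    def my_func(config, reporter):
        # Placeholder training function; the reporter keyword arguments
        # mirror fields of ray.tune's TrainingResult.
        reporter(timesteps_total=1, mean_accuracy=1.0)

    register_trainable("my_func", my_func)

    # Previously run_experiments() forwarded extra keyword arguments to an
    # implicit ray.init(); now Ray must be initialized explicitly first.
    ray.init()
    run_experiments({"my_experiment": {"run": "my_func"}})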
This commit is contained in:
parent
a1b01ee7fb
commit
173f1d629a
24 changed files with 486 additions and 421 deletions
@@ -262,6 +262,7 @@ in the ``config`` section of the experiments.

.. code-block:: python

import ray
from ray.tune.tune import run_experiments
from ray.tune.variant_generator import grid_search

@@ -286,6 +287,7 @@ in the ``config`` section of the experiments.

# put additional experiments to run concurrently here
}

ray.init()
run_experiments(experiment)

Contributing to RLlib
@@ -16,8 +16,9 @@ You can find the code for Ray Tune `here on GitHub <https://github.com/ray-proje

Getting Started
---------------

::
.. code-block:: python

import ray
from ray.tune import register_trainable, grid_search, run_experiments

def my_func(config, reporter):

@@ -30,6 +31,7 @@ Getting Started

register_trainable("my_func", my_func)

ray.init()
run_experiments({
"my_experiment": {
"run": "my_func",
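The hunks above show only fragments of the quick-start script. A minimal runnable version, assuming the reported fields follow Tune's ``TrainingResult`` names and with illustrative config and stopping criteria:

.. code-block:: python

    import random

    import ray
    from ray.tune import register_trainable, grid_search, run_experiments

    def my_func(config, reporter):
        accuracy = 0.0
        for i in range(100):
            # Stand-in for one step of real training.
            accuracy += config["alpha"] * random.random()
            reporter(timesteps_total=i, mean_accuracy=accuracy)

    register_trainable("my_func", my_func)

    ray.init()
    run_experiments({
        "my_experiment": {
            "run": "my_func",
            "stop": {"mean_accuracy": 10},
            "config": {"alpha": grid_search([0.2, 0.4, 0.6])},
        }
    })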
@@ -67,7 +69,7 @@ Ray Tune logs trial results to a unique directory per experiment, e.g. ``~/ray_r

To visualize learning in tensorboard, run:

::
.. code-block:: bash

$ pip install tensorboard
$ tensorboard --logdir=~/ray_results/my_experiment

@@ -76,7 +78,7 @@ To visualize learning in tensorboard, run:

To use rllab's VisKit (you may have to install some dependencies), run:

::
.. code-block:: bash

$ git clone https://github.com/rll/rllab.git
$ python rllab/rllab/viskit/frontend.py ~/ray_results/my_experiment
@@ -85,7 +87,7 @@ To use rllab's VisKit (you may have to install some dependencies), run:

Finally, to view the results with a `parallel coordinates visualization <https://en.wikipedia.org/wiki/Parallel_coordinates>`__, open `ParallelCoordinatesVisualization.ipynb <https://github.com/ray-project/ray/blob/master/python/ray/tune/ParallelCoordinatesVisualization.ipynb>`__ as follows and run its cells:

::
.. code-block:: bash

$ cd $RAY_HOME/python/ray/tune
$ jupyter-notebook ParallelCoordinatesVisualization.ipynb

@@ -97,7 +99,7 @@ In the above example, we specified a grid search over two parameters using the ``

The following shows grid search over two nested parameters combined with random sampling from two lambda functions. Note that the value of ``beta`` depends on the value of ``alpha``, which is represented by referencing ``spec.config.alpha`` in the lambda function. This lets you specify conditional parameter distributions.

::
.. code-block:: python

"config": {
"alpha": lambda spec: np.random.uniform(100),
@@ -118,60 +120,71 @@ Early Stopping

To reduce costs, long-running trials can often be early stopped if their initial performance is not promising. Ray Tune allows early stopping algorithms to be plugged in on top of existing grid or random searches. This can be enabled by setting the ``scheduler`` parameter of ``run_experiments``, e.g.

.. code-block:: python

run_experiments({...}, scheduler=HyperBandScheduler())

An example of this can be found in `hyperband_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. The progress of one such HyperBand run is shown below.

Note that some trial schedulers such as HyperBand require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster.

::

run_experiments({...}, scheduler=MedianStoppingRule())

== Status ==
Using HyperBand: num_stopped=0 total_brackets=5
Round #0:
Bracket(n=5, r=100, completed=80%): {'PAUSED': 4, 'PENDING': 1}
Bracket(n=8, r=33, completed=23%): {'PAUSED': 4, 'PENDING': 4}
Bracket(n=15, r=11, completed=4%): {'RUNNING': 2, 'PAUSED': 2, 'PENDING': 11}
Bracket(n=34, r=3, completed=0%): {'RUNNING': 2, 'PENDING': 32}
Bracket(n=81, r=1, completed=0%): {'PENDING': 38}
Resources used: 4/4 CPUs, 0/0 GPUs
Result logdir: /home/eric/ray_results/hyperband_test
PAUSED trials:
- my_class_0_height=99,width=43: PAUSED [pid=11664], 0 s, 100 ts, 97.1 rew
- my_class_11_height=85,width=81: PAUSED [pid=11771], 0 s, 33 ts, 32.8 rew
- my_class_12_height=0,width=52: PAUSED [pid=11785], 0 s, 33 ts, 0 rew
- my_class_19_height=44,width=88: PAUSED [pid=11811], 0 s, 11 ts, 5.47 rew
- my_class_27_height=96,width=84: PAUSED [pid=11840], 0 s, 11 ts, 12.5 rew
... 5 more not shown
PENDING trials:
- my_class_10_height=12,width=25: PENDING
- my_class_13_height=90,width=45: PENDING
- my_class_14_height=69,width=45: PENDING
- my_class_15_height=41,width=11: PENDING
- my_class_16_height=57,width=69: PENDING
... 81 more not shown
RUNNING trials:
- my_class_23_height=75,width=51: RUNNING [pid=11843], 0 s, 1 ts, 1.47 rew
- my_class_26_height=16,width=48: RUNNING
- my_class_31_height=40,width=10: RUNNING
- my_class_53_height=28,width=96: RUNNING

Currently we support the following early stopping algorithms, or you can write your own that implements the `TrialScheduler <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial_scheduler.py>`__ interface:
Currently we support the following early stopping algorithms, or you can write your own that implements the `TrialScheduler <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial_scheduler.py>`__ interface.

.. autoclass:: ray.tune.median_stopping_rule.MedianStoppingRule
.. autoclass:: ray.tune.hyperband.HyperBandScheduler
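As a concrete sketch of plugging in one of these schedulers, using the constructor arguments from the ``hyperband_example.py`` added in this commit:

.. code-block:: python

    from ray.tune.hyperband import HyperBandScheduler

    # Configure HyperBand with `episode_reward_mean` as the objective and
    # `timesteps_total` as the time unit; trials are stopped after max_t units.
    hyperband = HyperBandScheduler(
        time_attr="timesteps_total",
        reward_attr="episode_reward_mean",
        max_t=100)

    run_experiments({...}, scheduler=hyperband)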
Checkpointing support
---------------------
Trial Checkpointing
-------------------

To enable checkpoint / resume, the full ``Trainable`` API must be implemented (though as shown in the examples above, you can get away with just supplying a ``train(config, reporter)`` func if you don't need checkpointing). Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand. For example, all `RLlib agents <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agent.py>`__ implement the ``Trainable`` API.
To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand.

.. autoclass:: ray.tune.trainable.Trainable
:members:
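A minimal sketch of such a subclass, closely following the ``hyperband_example.py`` added in this commit (the class name and checkpoint file name are illustrative):

.. code-block:: python

    import json
    import os

    from ray.tune import Trainable, TrainingResult

    class MyTrainable(Trainable):
        """Toy trainable whose only model state is a step counter."""

        def _setup(self):
            self.timestep = 0

        def _train(self):
            # One logical iteration of training, reported as a TrainingResult.
            self.timestep += 1
            return TrainingResult(
                episode_reward_mean=self.timestep, timesteps_this_iter=1)

        def _save(self):
            # Persist enough state to resume later; return the checkpoint path.
            path = os.path.join(self.logdir, "checkpoint")
            with open(path, "w") as f:
                json.dump({"timestep": self.timestep}, f)
            return path

        def _restore(self, checkpoint_path):
            with open(checkpoint_path) as f:
                self.timestep = json.load(f)["timestep"]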
Resource Allocation
-------------------

Ray Tune runs each trial as a Ray actor, allocating the specified GPU and CPU ``resources`` to each actor (defaulting to 1 CPU per trial). A trial will not be scheduled unless at least that amount of resources is available in the cluster, preventing the cluster from being overloaded.

If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will also want to set ``driver_cpu_limit`` or ``driver_gpu_limit`` to tell Ray not to assign the entire resource reservation to your top-level trainable function, as described in `trial.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial.py>`__.

Command-line JSON/YAML API
--------------------------

The JSON config passed to ``run_experiments`` can also be put in a JSON or YAML file, and the experiments run using the ``tune.py`` script. This supports the same functionality as the Python API, e.g.:

::

cd ray/python/tune
./tune.py -f examples/tune_mnist_ray.yaml --scheduler=MedianStoppingRule

For more examples of experiments described by YAML files, see `RLlib tuned examples <https://github.com/ray-project/ray/tree/master/python/ray/rllib/tuned_examples>`__.

Running in a large cluster
--------------------------

``run_experiments`` also takes any arguments that ``ray.init()`` does. This can be used to pass in the redis address of a multi-node Ray cluster. For more details, check out the `tune.py script <https://github.com/ray-project/ray/blob/master/python/ray/tune/tune.py>`__.

If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will also want to set ``driver_cpu_limit`` or ``driver_gpu_limit`` to tell Ray not to assign the entire resource reservation to your top-level trainable function, as described in `trial.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/trial.py>`__. For example, if a trainable class requires 1 GPU itself, but will launch 4 actors each using another GPU, then it should set ``"gpu": 5, "driver_gpu_limit": 1``.
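A hedged sketch of that resource specification inside an experiment spec, assuming the ``driver_gpu_limit`` key sits alongside ``cpu`` and ``gpu`` in the ``resources`` dict as ``trial.py`` suggests (the experiment and trainable names are placeholders):

.. code-block:: python

    run_experiments({
        "resource_demo": {
            "run": "my_trainable",
            # Reserve 5 GPUs per trial, but charge only 1 to the top-level
            # trainable so the 4 actors it launches can use the remainder.
            "resources": {"cpu": 1, "gpu": 5, "driver_gpu_limit": 1},
        }
    })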
Client API
----------

You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, start your experiment with a flag, either from the command-line, e.g.:
You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, start your experiment with ``with_server=True``:

::

cd ray/python/tune
./tune.py -f examples/tune_mnist_ray.yaml --server=True --server-port=4321

Or within the Python API, e.g.:

::
.. code-block:: python

run_experiments({...}, with_server=True, server_port=4321)
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
|
@ -25,6 +26,7 @@ env_config.update({
|
|||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"carla-a3c": {
|
||||
"run": "A3C",
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
|
@ -25,6 +26,7 @@ env_config.update({
|
|||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"carla-dqn": {
|
||||
"run": "DQN",
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
|
@ -25,6 +26,7 @@ env_config.update({
|
|||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"carla-ppo": {
|
||||
"run": "PPO",
|
||||
|
|
|
@ -27,6 +27,7 @@ register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
|||
register_carla_model()
|
||||
redis_address = ray.services.get_node_ip_address() + ":6379"
|
||||
|
||||
ray.init(redis_address=redis_address)
|
||||
run_experiments({
|
||||
"carla-a3c": {
|
||||
"run": "A3C",
|
||||
|
@ -50,4 +51,4 @@ run_experiments({
|
|||
"num_workers": 2,
|
||||
},
|
||||
},
|
||||
}, redis_address=redis_address)
|
||||
})
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
|
@ -23,6 +24,7 @@ env_config.update({
|
|||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"carla-dqn": {
|
||||
"run": "DQN",
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments
|
||||
|
||||
from env import CarlaEnv, ENV_CONFIG
|
||||
|
@ -22,6 +23,7 @@ env_config.update({
|
|||
register_env(env_name, lambda env_config: CarlaEnv(env_config))
|
||||
register_carla_model()
|
||||
|
||||
ray.init(redirect_output=True)
|
||||
run_experiments({
|
||||
"carla": {
|
||||
"run": "PPO",
|
||||
|
@ -55,4 +57,4 @@ run_experiments({
|
|||
}
|
||||
},
|
||||
},
|
||||
}, redirect_output=True)
|
||||
})
|
||||
|
|
|
@ -8,6 +8,7 @@ import gym
|
|||
from gym.spaces import Discrete, Box
|
||||
from gym.envs.registration import EnvSpec
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
@ -41,6 +42,7 @@ class SimpleCorridor(gym.Env):
|
|||
if __name__ == "__main__":
|
||||
env_creator_name = "corridor"
|
||||
register_env(env_creator_name, lambda config: SimpleCorridor(config))
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"demo": {
|
||||
"run": "PPO",
|
||||
|
|
|
@ -2,15 +2,16 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Note: do not introduce unnecessary library dependencies here, e.g. gym
|
||||
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
|
||||
# This file is imported from the tune module in order to register RLlib agents.
|
||||
from ray.tune.registry import register_trainable
|
||||
from ray.rllib.agent import get_agent_class
|
||||
|
||||
|
||||
def _register_all():
|
||||
for key in [
|
||||
"PPO", "ES", "DQN", "A3C", "BC", "__fake", "__sigmoid_fake_data"]:
|
||||
try:
|
||||
from ray.rllib.agent import get_agent_class
|
||||
register_trainable(key, get_agent_class(key))
|
||||
except ImportError as e:
|
||||
print("Warning: could not import {}: {}".format(key, e))
|
||||
|
|
|
@ -2,25 +2,14 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import io
|
||||
import os
|
||||
import gzip
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# Note: avoid introducing unnecessary library dependencies here, e.g. gym
|
||||
# until https://github.com/ray-project/ray/issues/1144 is resolved
|
||||
import tensorflow as tf
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import ENV_CREATOR, get_registry
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -66,7 +55,7 @@ class Agent(Trainable):
|
|||
env_creator (func): Function that creates a new training env.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
registry (obj): Tune object registry, for registering user-defined
|
||||
registry (obj): Tune object registry which holds user-registered
|
||||
classes and objects by name.
|
||||
"""
|
||||
|
||||
|
@ -83,183 +72,43 @@ class Agent(Trainable):
|
|||
env (str): Name of the environment to use. Note that this can also
|
||||
be specified as the `env` key in config.
|
||||
registry (obj): Object registry for user-defined envs, models, etc.
|
||||
If unspecified, it will be assumed empty.
|
||||
If unspecified, the default registry will be used.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
self._initialize_ok = False
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
env = env or config.get("env")
|
||||
|
||||
# Agents allow env ids to be passed directly to the constructor.
|
||||
self._env_id = env or config.get("env")
|
||||
Trainable.__init__(self, config, registry, logger_creator)
|
||||
|
||||
def _setup(self):
|
||||
env = self._env_id
|
||||
if env:
|
||||
config["env"] = env
|
||||
if registry and registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = registry.get(ENV_CREATOR, env)
|
||||
self.config["env"] = env
|
||||
if self.registry and self.registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = self.registry.get(ENV_CREATOR, env)
|
||||
else:
|
||||
import gym # soft dependency
|
||||
self.env_creator = lambda env_config: gym.make(env)
|
||||
else:
|
||||
self.env_creator = lambda env_config: None
|
||||
self.config = self._default_config.copy()
|
||||
self.registry = registry
|
||||
|
||||
self.config = _deep_update(self.config, config,
|
||||
self._allow_unknown_configs,
|
||||
self._allow_unknown_subkeys)
|
||||
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(self.config)
|
||||
self.logdir = self._result_logger.logdir
|
||||
else:
|
||||
logdir_suffix = "{}_{}_{}".format(
|
||||
env, self._agent_name,
|
||||
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
if not os.path.exists(DEFAULT_RESULTS_DIR):
|
||||
os.makedirs(DEFAULT_RESULTS_DIR)
|
||||
self.logdir = tempfile.mkdtemp(
|
||||
prefix=logdir_suffix, dir=DEFAULT_RESULTS_DIR)
|
||||
self._result_logger = UnifiedLogger(self.config, self.logdir, None)
|
||||
|
||||
self._iteration = 0
|
||||
self._time_total = 0.0
|
||||
self._timesteps_total = 0
|
||||
# Merge the supplied config with the class default
|
||||
merged_config = self._default_config.copy()
|
||||
merged_config = _deep_update(merged_config, self.config,
|
||||
self._allow_unknown_configs,
|
||||
self._allow_unknown_subkeys)
|
||||
self.config = merged_config
|
||||
|
||||
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
|
||||
with tf.Graph().as_default():
|
||||
self._init()
|
||||
|
||||
self._initialize_ok = True
|
||||
|
||||
def _init(self):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
Returns:
|
||||
A TrainingResult that describes training progress.
|
||||
"""
|
||||
|
||||
if not self._initialize_ok:
|
||||
raise ValueError(
|
||||
"Agent initialization failed, see previous errors")
|
||||
|
||||
start = time.time()
|
||||
result = self._train()
|
||||
self._iteration += 1
|
||||
if result.time_this_iter_s is not None:
|
||||
time_this_iter = result.time_this_iter_s
|
||||
else:
|
||||
time_this_iter = time.time() - start
|
||||
|
||||
assert result.timesteps_this_iter is not None
|
||||
|
||||
self._time_total += time_this_iter
|
||||
self._timesteps_total += result.timesteps_this_iter
|
||||
|
||||
now = datetime.today()
|
||||
result = result._replace(
|
||||
experiment_id=self._experiment_id,
|
||||
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
|
||||
timestamp=int(time.mktime(now.timetuple())),
|
||||
training_iteration=self._iteration,
|
||||
timesteps_total=self._timesteps_total,
|
||||
time_this_iter_s=time_this_iter,
|
||||
time_total_s=self._time_total,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1])
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
return result
|
||||
|
||||
def save(self):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
Returns:
|
||||
Checkpoint path that may be passed to restore().
|
||||
"""
|
||||
|
||||
checkpoint_path = self._save()
|
||||
pickle.dump(
|
||||
[self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total],
|
||||
open(checkpoint_path + ".rllib_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
"""Saves the current model state to a Python object. It also
|
||||
saves to disk but does not return the checkpoint path.
|
||||
|
||||
Returns:
|
||||
Object holding checkpoint data.
|
||||
"""
|
||||
|
||||
checkpoint_prefix = self.save()
|
||||
|
||||
data = {}
|
||||
base_dir = os.path.dirname(checkpoint_prefix)
|
||||
for path in os.listdir(base_dir):
|
||||
path = os.path.join(base_dir, path)
|
||||
if path.startswith(checkpoint_prefix):
|
||||
data[os.path.basename(path)] = open(path, "rb").read()
|
||||
|
||||
out = io.BytesIO()
|
||||
with gzip.GzipFile(fileobj=out, mode="wb") as f:
|
||||
compressed = pickle.dumps({
|
||||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"data": data,
|
||||
})
|
||||
print("Saving checkpoint to object store, {} bytes".format(
|
||||
len(compressed)))
|
||||
f.write(compressed)
|
||||
|
||||
return out.getvalue()
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
These checkpoints are returned from calls to save().
|
||||
"""
|
||||
|
||||
self._restore(checkpoint_path)
|
||||
metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb"))
|
||||
self._experiment_id = metadata[0]
|
||||
self._iteration = metadata[1]
|
||||
self._timesteps_total = metadata[2]
|
||||
self._time_total = metadata[3]
|
||||
|
||||
def restore_from_object(self, obj):
|
||||
"""Restores training state from a checkpoint object.
|
||||
|
||||
These checkpoints are returned from calls to save_to_object().
|
||||
"""
|
||||
|
||||
out = io.BytesIO(obj)
|
||||
info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
|
||||
data = info["data"]
|
||||
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
|
||||
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
|
||||
|
||||
for file_name, file_contents in data.items():
|
||||
with open(os.path.join(tmpdir, file_name), "wb") as f:
|
||||
f.write(file_contents)
|
||||
|
||||
self.restore(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this agent."""
|
||||
|
||||
if self._initialize_ok:
|
||||
self._result_logger.close()
|
||||
self._stop()
|
||||
|
||||
def _stop(self):
|
||||
"""Subclasses should override this for custom stopping."""
|
||||
|
||||
pass
|
||||
|
||||
def compute_action(self, observation):
|
||||
"""Computes an action using the current trained policy."""
|
||||
|
||||
|
@ -283,21 +132,6 @@ class Agent(Trainable):
|
|||
|
||||
raise NotImplementedError
|
||||
|
||||
def _train(self):
|
||||
"""Subclasses should override this to implement train()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _save(self):
|
||||
"""Subclasses should override this to implement save()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
"""Subclasses should override this to implement restore()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _MockAgent(Agent):
|
||||
"""Mock agent for use in tests"""
|
||||
|
|
|
@ -8,6 +8,7 @@ import argparse
|
|||
import sys
|
||||
import yaml
|
||||
|
||||
import ray
|
||||
from ray.tune.config_parser import make_parser, resources_to_json
|
||||
from ray.tune.tune import _make_scheduler, run_experiments
|
||||
|
||||
|
@ -76,7 +77,7 @@ if __name__ == "__main__":
|
|||
if not exp.get("env") and not exp.get("config", {}).get("env"):
|
||||
parser.error("the following arguments are required: --env")
|
||||
|
||||
run_experiments(
|
||||
experiments, scheduler=_make_scheduler(args),
|
||||
ray.init(
|
||||
redis_address=args.redis_address,
|
||||
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
|
||||
run_experiments(experiments, scheduler=_make_scheduler(args))
|
||||
|
|
|
@ -6,13 +6,10 @@ from ray.tune.error import TuneError
|
|||
from ray.tune.tune import run_experiments
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.script_runner import ScriptRunner
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.variant_generator import grid_search
|
||||
|
||||
|
||||
register_trainable("script", ScriptRunner)
|
||||
|
||||
__all__ = [
|
||||
"Trainable",
|
||||
"TrainingResult",
|
||||
|
|
70
python/ray/tune/examples/hyperband_example.py
Executable file
|
@ -0,0 +1,70 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, TrainingResult, register_trainable, \
|
||||
run_experiments
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
"""Example agent whose learning curve is a random sigmoid.
|
||||
|
||||
The dummy hyperparameters "width" and "height" determine the slope and
|
||||
maximum reward value reached.
|
||||
"""
|
||||
|
||||
def _setup(self):
|
||||
self.timestep = 0
|
||||
|
||||
def _train(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy (see tune/result.py).
|
||||
return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1)
|
||||
|
||||
def _save(self):
|
||||
path = os.path.join(self.logdir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": self.timestep}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
|
||||
# Hyperband early stopping, configured with `episode_reward_mean` as the
|
||||
# objective and `timesteps_total` as the time unit.
|
||||
hyperband = HyperBandScheduler(
|
||||
time_attr="timesteps_total", reward_attr="episode_reward_mean",
|
||||
max_t=100)
|
||||
|
||||
run_experiments({
|
||||
"hyperband_test": {
|
||||
"run": "my_class",
|
||||
"repeat": 100,
|
||||
"resources": {"cpu": 1, "gpu": 0},
|
||||
"config": {
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
"height": lambda spec: int(100 * random.random()),
|
||||
},
|
||||
}
|
||||
}, scheduler=hyperband)
|
|
@ -33,6 +33,7 @@ import sys
|
|||
import tempfile
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune import grid_search, run_experiments, register_trainable
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
@ -222,4 +223,5 @@ if __name__ == '__main__':
|
|||
if args.fast:
|
||||
mnist_spec['stop']['training_iteration'] = 2
|
||||
|
||||
ray.init()
|
||||
run_experiments({'tune_mnist_test': mnist_spec})
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
tune_mnist:
|
||||
run: script
|
||||
repeat: 2
|
||||
resources:
|
||||
cpu: 1
|
||||
stop:
|
||||
mean_accuracy: 0.99
|
||||
time_total_s: 600
|
||||
config:
|
||||
script_file_path: examples/tune_mnist_ray.py
|
||||
script_entrypoint: train
|
||||
script_min_iter_time_s: 1
|
||||
activation:
|
||||
grid_search: ['relu', 'elu', 'tanh']
|
|
@ -2,15 +2,12 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
|
@ -53,12 +50,6 @@ class StatusReporter(object):
|
|||
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
# path of the script to run
|
||||
"script_file_path": "/path/to/file.py",
|
||||
|
||||
# name of train function in the file, e.g. train(config, status_reporter)
|
||||
"script_entrypoint": "train",
|
||||
|
||||
# batch results to at least this granularity
|
||||
"script_min_iter_time_s": 1,
|
||||
}
|
||||
|
@ -85,67 +76,37 @@ class _RunnerThread(threading.Thread):
|
|||
self._status_reporter._done = True
|
||||
|
||||
|
||||
def import_function(file_path, function_name):
|
||||
# strong assumption here that we're in a new process
|
||||
file_path = os.path.expanduser(file_path)
|
||||
sys.path.insert(0, os.path.dirname(file_path))
|
||||
if hasattr(importlib, "util"):
|
||||
# Python 3.4+
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"external_file", file_path)
|
||||
external_file = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(external_file)
|
||||
elif hasattr(importlib, "machinery"):
|
||||
# Python 3.3
|
||||
from importlib.machinery import SourceFileLoader
|
||||
external_file = SourceFileLoader(
|
||||
"external_file", file_path).load_module()
|
||||
else:
|
||||
# Python 2.x
|
||||
import imp
|
||||
external_file = imp.load_source("external_file", file_path)
|
||||
if not external_file:
|
||||
raise TuneError("Unable to import file at {}".format(file_path))
|
||||
return getattr(external_file, function_name)
|
||||
class FunctionRunner(Trainable):
|
||||
"""Trainable that runs a user function returning training results.
|
||||
|
||||
This mode of execution does not support checkpoint/restore."""
|
||||
|
||||
class ScriptRunner(Agent):
|
||||
"""Agent that runs a user script returning training results."""
|
||||
|
||||
_agent_name = "script"
|
||||
_name = "func"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_allow_unknown_configs = True
|
||||
|
||||
def _init(self):
|
||||
def _setup(self):
|
||||
entrypoint = self._trainable_func()
|
||||
if not entrypoint:
|
||||
entrypoint = import_function(
|
||||
self.config["script_file_path"],
|
||||
self.config["script_entrypoint"])
|
||||
self._status_reporter = StatusReporter()
|
||||
scrubbed_config = self.config.copy()
|
||||
for k in self._default_config:
|
||||
del scrubbed_config[k]
|
||||
if k in scrubbed_config:
|
||||
del scrubbed_config[k]
|
||||
self._runner = _RunnerThread(
|
||||
entrypoint, scrubbed_config, self._status_reporter)
|
||||
self._start_time = time.time()
|
||||
self._last_reported_time = self._start_time
|
||||
self._last_reported_timestep = 0
|
||||
self._runner.start()
|
||||
|
||||
# Subclasses can override this to set the trainable func
|
||||
# TODO(ekl) this isn't a very clean layering, we should refactor it
|
||||
def _trainable_func(self):
|
||||
return None
|
||||
"""Subclasses can override this to set the trainable func."""
|
||||
|
||||
def train(self):
|
||||
if not self._initialize_ok:
|
||||
raise ValueError(
|
||||
"Agent initialization failed, see previous errors")
|
||||
|
||||
now = time.time()
|
||||
time.sleep(self.config["script_min_iter_time_s"])
|
||||
raise NotImplementedError
|
||||
|
||||
def _train(self):
|
||||
time.sleep(
|
||||
self.config.get(
|
||||
"script_min_iter_time_s",
|
||||
self._default_config["script_min_iter_time_s"]))
|
||||
result = self._status_reporter._get_and_clear_status()
|
||||
while result is None:
|
||||
time.sleep(1)
|
||||
|
@ -153,29 +114,10 @@ class ScriptRunner(Agent):
|
|||
if result.timesteps_total is None:
|
||||
raise TuneError("Must specify timesteps_total in result", result)
|
||||
|
||||
# Include the negative loss to use as a stopping condition
|
||||
if result.mean_loss is not None:
|
||||
neg_loss = -result.mean_loss
|
||||
else:
|
||||
neg_loss = result.neg_mean_loss
|
||||
|
||||
result = result._replace(
|
||||
experiment_id=self._experiment_id,
|
||||
neg_mean_loss=neg_loss,
|
||||
training_iteration=self.iteration,
|
||||
time_this_iter_s=now - self._last_reported_time,
|
||||
timesteps_this_iter=(
|
||||
result.timesteps_total - self._last_reported_timestep),
|
||||
time_total_s=now - self._start_time,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1])
|
||||
|
||||
if result.timesteps_total:
|
||||
self._last_reported_timestep = result.timesteps_total
|
||||
self._last_reported_time = now
|
||||
self._iteration += 1
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
result.timesteps_total - self._last_reported_timestep))
|
||||
self._last_reported_timestep = result.timesteps_total
|
||||
|
||||
return result
|
||||
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
@ -61,7 +62,8 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
max_t (int): max time units per trial. Trials will be stopped after
|
||||
max_t time units (determined by time_attr) have passed.
|
||||
The HyperBand scheduler automatically tries to determine a
|
||||
reasonable number of brackets based on this.
|
||||
reasonable number of brackets based on this. The scheduler will
|
||||
terminate trials after this time has passed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -210,7 +212,8 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
|
||||
List of trials not used since all trials are tracked as state
|
||||
of scheduler. If iteration is occupied (i.e., no trials to run),
|
||||
then look into next iteration."""
|
||||
then look into next iteration.
|
||||
"""
|
||||
|
||||
for hyperband in self._hyperbands:
|
||||
for bracket in sorted(hyperband,
|
||||
|
@ -222,18 +225,14 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
return None
|
||||
|
||||
def debug_string(self):
|
||||
# TODO(rliaw): This debug string needs work
|
||||
brackets = [
|
||||
"({0}/{1})".format(
|
||||
len(bracket._live_trials), len(bracket._all_trials))
|
||||
for band in self._hyperbands for bracket in band]
|
||||
return " ".join([
|
||||
"Using HyperBand:",
|
||||
"num_stopped={}".format(self._num_stopped),
|
||||
"total_brackets={}".format(
|
||||
sum(len(band) for band in self._hyperbands)),
|
||||
" ".join(brackets)
|
||||
])
|
||||
out = "Using HyperBand: "
|
||||
out += "num_stopped={} total_brackets={}".format(
|
||||
self._num_stopped, sum(len(band) for band in self._hyperbands))
|
||||
for i, band in enumerate(self._hyperbands):
|
||||
out += "\nRound #{}:".format(i)
|
||||
for bracket in band:
|
||||
out += "\n {}".format(bracket)
|
||||
return out
|
||||
|
||||
|
||||
class Bracket():
|
||||
|
@ -370,10 +369,9 @@ class Bracket():
|
|||
status = ", ".join([
|
||||
"n={}".format(self._n),
|
||||
"r={}".format(self._r),
|
||||
"progress={}".format(self.completion_percentage())
|
||||
"completed={}%".format(int(100 * self.completion_percentage()))
|
||||
])
|
||||
return "Bracket({})".format(status)
|
||||
|
||||
def debug_string(self):
|
||||
trials = ", ".join([t.status for t in self._live_trials])
|
||||
return "{}[{}]".format(self, trials)
|
||||
counts = collections.Counter()
|
||||
for t in self._all_trials:
|
||||
counts[t.status] += 1
|
||||
return "Bracket({}): {}".format(status, dict(counts))
|
||||
|
|
|
@ -2,55 +2,261 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import gzip
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
|
||||
class Trainable(object):
|
||||
"""Interface for trainable models, functions, etc.
|
||||
"""Abstract class for trainable models, functions, etc.
|
||||
|
||||
Implementing this interface is required to use Ray.tune's full
|
||||
functionality, though you can also get away with supplying just a
|
||||
`my_train(config, reporter)` function and calling:
|
||||
A call to ``train()`` on a trainable will execute one logical iteration of
|
||||
training. As a rule of thumb, the execution time of one train call should
|
||||
be large enough to avoid overheads (i.e. more than a few seconds), but
|
||||
short enough to report progress periodically (i.e. at most a few minutes).
|
||||
|
||||
Calling ``save()`` should save the training state of a trainable to disk,
|
||||
and ``restore(path)`` should restore a trainable to the given state.
|
||||
|
||||
Generally you only need to implement ``_train``, ``_save``, and
|
||||
``_restore`` here when subclassing Trainable.
|
||||
|
||||
Note that, if you don't require checkpoint/restore functionality, then
|
||||
instead of implementing this class you can also get away with supplying
|
||||
just a `my_train(config, reporter)` function and calling:
|
||||
|
||||
``register_trainable("my_func", train)``
|
||||
|
||||
to register it for use with tune. The function will be automatically
|
||||
converted to this interface (sans checkpoint functionality)."""
|
||||
to register it for use with Tune. The function will be automatically
|
||||
converted to this interface (sans checkpoint functionality).
|
||||
|
||||
Attributes:
|
||||
config (obj): The hyperparam configuration for this trial.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
registry (obj): Tune object registry which holds user-registered
|
||||
classes and objects by name.
|
||||
"""
|
||||
|
||||
def __init__(self, config={}, registry=None, logger_creator=None):
|
||||
"""Initialize an Trainable.
|
||||
|
||||
Subclasses should prefer defining ``_setup()`` instead of overriding
|
||||
``__init__()`` directly.
|
||||
|
||||
Args:
|
||||
config (dict): Trainable-specific configuration data.
|
||||
registry (obj): Object registry for user-defined envs, models, etc.
|
||||
If unspecified, the default registry will be used.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
|
||||
if registry is None:
|
||||
from ray.tune.registry import get_registry
|
||||
registry = get_registry()
|
||||
|
||||
self._initialize_ok = False
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
self.config = config
|
||||
self.registry = registry
|
||||
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(self.config)
|
||||
self.logdir = self._result_logger.logdir
|
||||
else:
|
||||
logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
if not os.path.exists(DEFAULT_RESULTS_DIR):
|
||||
os.makedirs(DEFAULT_RESULTS_DIR)
|
||||
self.logdir = tempfile.mkdtemp(
|
||||
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
|
||||
self._result_logger = UnifiedLogger(self.config, self.logdir, None)
|
||||
|
||||
self._iteration = 0
|
||||
self._time_total = 0.0
|
||||
self._timesteps_total = 0
|
||||
self._setup()
|
||||
self._initialize_ok = True
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
Subclasses should override ``_train()`` instead to return results.
|
||||
This method auto-fills many fields, so only ``timesteps_this_iter``
|
||||
is required to be present.
|
||||
|
||||
Returns:
|
||||
A TrainingResult that describes training progress.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
if not self._initialize_ok:
|
||||
raise ValueError(
|
||||
"Trainable initialization failed, see previous errors")
|
||||
|
||||
start = time.time()
|
||||
result = self._train()
|
||||
self._iteration += 1
|
||||
if result.time_this_iter_s is not None:
|
||||
time_this_iter = result.time_this_iter_s
|
||||
else:
|
||||
time_this_iter = time.time() - start
|
||||
|
||||
if result.timesteps_this_iter is None:
|
||||
raise TuneError(
|
||||
"Must specify timesteps_this_iter in result", result)
|
||||
|
||||
self._time_total += time_this_iter
|
||||
self._timesteps_total += result.timesteps_this_iter
|
||||
|
||||
# Include the negative loss to use as a stopping condition
|
||||
if result.mean_loss is not None:
|
||||
neg_loss = -result.mean_loss
|
||||
else:
|
||||
neg_loss = result.neg_mean_loss
|
||||
|
||||
now = datetime.today()
|
||||
result = result._replace(
|
||||
experiment_id=self._experiment_id,
|
||||
date=now.strftime("%Y-%m-%d_%H-%M-%S"),
|
||||
timestamp=int(time.mktime(now.timetuple())),
|
||||
training_iteration=self._iteration,
|
||||
timesteps_total=self._timesteps_total,
|
||||
time_this_iter_s=time_this_iter,
|
||||
time_total_s=self._time_total,
|
||||
neg_mean_loss=neg_loss,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1])
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
return result
|
||||
|
||||
def save(self):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
Subclasses should override ``_save()`` instead to save state.
|
||||
This method dumps additional metadata alongside the saved path.
|
||||
|
||||
Returns:
|
||||
Checkpoint path that may be passed to restore().
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
checkpoint_path = self._save()
|
||||
pickle.dump(
|
||||
[self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total],
|
||||
open(checkpoint_path + ".tune_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
"""Saves the current model state to a Python object. It also
|
||||
saves to disk but does not return the checkpoint path.
|
||||
|
||||
Returns:
|
||||
Object holding checkpoint data.
|
||||
"""
|
||||
|
||||
checkpoint_prefix = self.save()
|
||||
|
||||
data = {}
|
||||
base_dir = os.path.dirname(checkpoint_prefix)
|
||||
for path in os.listdir(base_dir):
|
||||
path = os.path.join(base_dir, path)
|
||||
if path.startswith(checkpoint_prefix):
|
||||
data[os.path.basename(path)] = open(path, "rb").read()
|
||||
|
||||
out = io.BytesIO()
|
||||
with gzip.GzipFile(fileobj=out, mode="wb") as f:
|
||||
compressed = pickle.dumps({
|
||||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"data": data,
|
||||
})
|
||||
print("Saving checkpoint to object store, {} bytes".format(
|
||||
len(compressed)))
|
||||
f.write(compressed)
|
||||
|
||||
return out.getvalue()
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
These checkpoints are returned from calls to save().
|
||||
|
||||
Subclasses should override ``_restore()`` instead to restore state.
|
||||
This method restores additional metadata saved with the checkpoint.
|
||||
"""
|
||||
|
||||
self._restore(checkpoint_path)
|
||||
metadata = pickle.load(open(checkpoint_path + ".tune_metadata", "rb"))
|
||||
self._experiment_id = metadata[0]
|
||||
self._iteration = metadata[1]
|
||||
self._timesteps_total = metadata[2]
|
||||
self._time_total = metadata[3]
|
||||
|
||||
def restore_from_object(self, obj):
|
||||
"""Restores training state from a checkpoint object.
|
||||
|
||||
These checkpoints are returned from calls to save_to_object().
|
||||
"""
|
||||
|
||||
out = io.BytesIO(obj)
|
||||
info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
|
||||
data = info["data"]
|
||||
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
|
||||
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
|
||||
|
||||
for file_name, file_contents in data.items():
|
||||
with open(os.path.join(tmpdir, file_name), "wb") as f:
|
||||
f.write(file_contents)
|
||||
|
||||
self.restore(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this trainable."""
|
||||
|
||||
if self._initialize_ok:
|
||||
self._result_logger.close()
|
||||
self._stop()
|
||||
|
||||
def _train(self):
|
||||
"""Subclasses should override this to implement train()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this class."""
|
||||
def _save(self):
|
||||
"""Subclasses should override this to implement save()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
"""Subclasses should override this to implement restore()."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _setup(self):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
pass
|
||||
|
||||
def _stop(self):
|
||||
"""Subclasses should override this for any cleanup on stop."""
|
||||
pass
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
from ray.tune.script_runner import ScriptRunner
|
||||
from ray.tune.function_runner import FunctionRunner
|
||||
|
||||
class WrappedFunc(ScriptRunner):
|
||||
class WrappedFunc(FunctionRunner):
|
||||
def _trainable_func(self):
|
||||
return train_func
|
||||
|
||||
|
|
|
@ -2,18 +2,21 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import ray
|
||||
import os
|
||||
|
||||
from collections import namedtuple
|
||||
from ray.utils import random_string, binary_to_hex
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import NoopLogger, UnifiedLogger
|
||||
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
|
||||
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
|
||||
from ray.utils import random_string, binary_to_hex
|
||||
|
||||
DEBUG_PRINT_INTERVAL = 5
|
||||
|
||||
|
||||
class Resources(
|
||||
|
@ -106,6 +109,7 @@ class Trial(object):
|
|||
self.location = None
|
||||
self.logdir = None
|
||||
self.result_logger = None
|
||||
self.last_debug = 0
|
||||
self.trial_id = binary_to_hex(random_string())[:8]
|
||||
|
||||
def start(self):
|
||||
|
@ -293,8 +297,10 @@ class Trial(object):
|
|||
def update_last_result(self, result, terminate=False):
|
||||
if terminate:
|
||||
result = result._replace(done=True)
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
if terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL:
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_debug = time.time()
|
||||
self.last_result = result
|
||||
self.result_logger.on_result(self.last_result)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import ray
|
||||
import time
|
||||
|
@ -13,6 +14,9 @@ from ray.tune.trial import Trial, Resources
|
|||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
|
||||
|
||||
class TrialRunner(object):
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
|
@ -127,9 +131,40 @@ class TrialRunner(object):
|
|||
self._scheduler_alg.on_trial_add(self, trial)
|
||||
self._trials.append(trial)
|
||||
|
||||
def debug_string(self):
|
||||
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
messages = self._debug_messages()
|
||||
states = collections.defaultdict(set)
|
||||
limit_per_state = collections.Counter()
|
||||
for t in self._trials:
|
||||
states[t.status].add(t)
|
||||
|
||||
# Show at most max_debug total, but divide the limit fairly
|
||||
while max_debug > 0:
|
||||
start_num = max_debug
|
||||
for s in states:
|
||||
if limit_per_state[s] >= len(states[s]):
|
||||
continue
|
||||
max_debug -= 1
|
||||
limit_per_state[s] += 1
|
||||
if max_debug == start_num:
|
||||
break
|
||||
|
||||
for local_dir in sorted(set([t.local_dir for t in self._trials])):
|
||||
messages.append("Result logdir: {}".format(local_dir))
|
||||
for state, trials in sorted(states.items()):
|
||||
limit = limit_per_state[state]
|
||||
messages.append("{} trials:".format(state))
|
||||
for t in sorted(
|
||||
trials, key=lambda t: t.experiment_tag)[:limit]:
|
||||
messages.append(" - {}:\t{}".format(t, t.progress_string()))
|
||||
if len(trials) > limit:
|
||||
messages.append(" ... {} more not shown".format(
|
||||
len(trials) - limit))
|
||||
return "\n".join(messages) + "\n"
|
||||
|
||||
def _debug_messages(self):
|
||||
messages = ["== Status =="]
|
||||
messages.append(self._scheduler_alg.debug_string())
|
||||
if self._resources_initialized:
|
||||
|
@ -139,13 +174,7 @@ class TrialRunner(object):
|
|||
self._avail_resources.cpu,
|
||||
self._committed_resources.gpu,
|
||||
self._avail_resources.gpu))
|
||||
for local_dir in sorted(set([t.local_dir for t in self._trials])):
|
||||
messages.append("Result logdir: {}".format(local_dir))
|
||||
for t in self._trials:
|
||||
if t.local_dir == local_dir:
|
||||
messages.append(
|
||||
" - {}:\t{}".format(t, t.progress_string()))
|
||||
return "\n".join(messages) + "\n"
|
||||
return messages
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
|
|
68
python/ray/tune/tune.py
Executable file → Normal file
|
@ -1,55 +1,19 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import ray
|
||||
import time
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.trial_scheduler import FIFOScheduler
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.variant_generator import generate_trials
|
||||
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
MNIST tuning example:
|
||||
./tune.py -f examples/tune_mnist_ray.yaml
|
||||
"""
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
description="Tune hyperparameters with Ray.",
|
||||
epilog=EXAMPLE_USAGE)
|
||||
|
||||
# See also the base parser definition in ray/tune/config_parser.py
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument("--num-cpus", default=None, type=int,
|
||||
help="Number of CPUs to allocate to Ray.")
|
||||
parser.add_argument("--num-gpus", default=None, type=int,
|
||||
help="Number of GPUs to allocate to Ray.")
|
||||
parser.add_argument("--scheduler", default="FIFO", type=str,
|
||||
help="FIFO, MedianStopping, or HyperBand")
|
||||
parser.add_argument("--scheduler-config", default="{}", type=json.loads,
|
||||
help="Config options to pass to the scheduler.")
|
||||
parser.add_argument("--server", default=False, type=bool,
|
||||
help="Option to launch Tune Server")
|
||||
parser.add_argument("--server-port", default=TuneServer.DEFAULT_PORT,
|
||||
type=int, help="Option to launch Tune Server")
|
||||
parser.add_argument("-f", "--config-file", required=True, type=str,
|
||||
help="Read experiment options from this JSON/YAML file.")
|
||||
|
||||
|
||||
_SCHEDULERS = {
|
||||
"FIFO": FIFOScheduler,
|
||||
"MedianStopping": MedianStoppingRule,
|
||||
|
@ -67,7 +31,11 @@ def _make_scheduler(args):
|
|||
|
||||
|
||||
def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT, **ray_args):
|
||||
server_port=TuneServer.DEFAULT_PORT):
|
||||
|
||||
# Make sure rllib agents are registered
|
||||
from ray import rllib # noqa # pylint: disable=unused-import
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = FIFOScheduler()
|
||||
|
||||
|
@ -77,13 +45,16 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
|||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
runner.add_trial(trial)
|
||||
print(runner.debug_string())
|
||||
|
||||
ray.init(**ray_args)
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
last_debug = 0
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
print(runner.debug_string())
|
||||
if time.time() - last_debug > DEBUG_PRINT_INTERVAL:
|
||||
print(runner.debug_string())
|
||||
last_debug = time.time()
|
||||
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
for trial in runner.get_trials():
|
||||
# TODO(rliaw): What about errored?
|
||||
|
@ -91,14 +62,3 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
|||
raise TuneError("Trial did not complete", trial)
|
||||
|
||||
return runner.get_trials()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import yaml
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
with open(args.config_file) as f:
|
||||
experiments = yaml.load(f)
|
||||
run_experiments(
|
||||
experiments, _make_scheduler(args), with_server=args.server,
|
||||
server_port=args.server_port, redis_address=args.redis_address,
|
||||
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
|
||||
|
|
|
@ -20,6 +20,9 @@ from ray.tune.variant_generator import generate_trials, grid_search, \
|
|||
|
||||
|
||||
class TrainableFunctionApiTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
||||
def tearDown(self):
|
||||
ray.worker.cleanup()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
|
|
@ -509,4 +509,6 @@ class HyperbandSuite(unittest.TestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ray.rllib import _register_all
|
||||
_register_all()
|
||||
unittest.main(verbosity=2)
|
||||
|
|