mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] clean up population based training prototype (#1478)
* patch up pbt * Sat Jan 27 01:00:03 PST 2018 * Sat Jan 27 01:04:14 PST 2018 * Sat Jan 27 01:04:21 PST 2018 * Sat Jan 27 01:15:15 PST 2018 * Sat Jan 27 01:15:42 PST 2018 * Sat Jan 27 01:16:14 PST 2018 * Sat Jan 27 01:38:42 PST 2018 * Sat Jan 27 01:39:21 PST 2018 * add pbt * Sat Jan 27 01:41:19 PST 2018 * Sat Jan 27 01:44:21 PST 2018 * Sat Jan 27 01:45:46 PST 2018 * Sat Jan 27 16:54:42 PST 2018 * Sat Jan 27 16:57:53 PST 2018 * clean up test * Sat Jan 27 18:01:15 PST 2018 * Sat Jan 27 18:02:54 PST 2018 * Sat Jan 27 18:11:18 PST 2018 * Sat Jan 27 18:11:55 PST 2018 * Sat Jan 27 18:14:09 PST 2018 * review * try out a ppo example * some tweaks to ppo example * add postprocess hook * Sun Jan 28 15:00:40 PST 2018 * clean up custom explore fn * Sun Jan 28 15:10:21 PST 2018 * Sun Jan 28 15:14:53 PST 2018 * Sun Jan 28 15:17:04 PST 2018 * Sun Jan 28 15:33:13 PST 2018 * Sun Jan 28 15:56:40 PST 2018 * Sun Jan 28 15:57:36 PST 2018 * Sun Jan 28 16:00:35 PST 2018 * Sun Jan 28 16:02:58 PST 2018 * Sun Jan 28 16:29:50 PST 2018 * Sun Jan 28 16:30:36 PST 2018 * Sun Jan 28 16:31:44 PST 2018 * improve tune doc * concepts * update humanoid * Fri Feb 2 18:03:33 PST 2018 * fix example * show error file
This commit is contained in:
parent
a936468f99
commit
b948405532
22 changed files with 698 additions and 288 deletions
11
.travis.yml
11
.travis.yml
|
@ -128,13 +128,18 @@ script:
|
|||
- python test/multi_node_test.py
|
||||
- python test/recursion_test.py
|
||||
- python test/monitor_test.py
|
||||
- python test/trial_runner_test.py
|
||||
- python test/trial_scheduler_test.py
|
||||
- python test/tune_server_test.py
|
||||
- python test/cython_test.py
|
||||
|
||||
# ray dataframe tests
|
||||
- python -m pytest python/ray/dataframe/test/test_dataframe.py
|
||||
- python -m pytest python/ray/dataframe/test/test_series.py
|
||||
|
||||
# ray tune tests
|
||||
- python -m pytest python/ray/tune/test/trial_runner_test.py
|
||||
- python -m pytest python/ray/tune/test/trial_scheduler_test.py
|
||||
- python -m pytest python/ray/tune/test/tune_server_test.py
|
||||
|
||||
# ray rllib tests
|
||||
- python -m pytest python/ray/rllib/test/test_catalog.py
|
||||
- python -m pytest python/ray/rllib/test/test_filters.py
|
||||
- python -m pytest python/ray/rllib/test/test_optimizers.py
|
||||
|
|
BIN
doc/source/pbt.png
Normal file
BIN
doc/source/pbt.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 32 KiB |
|
@ -290,6 +290,9 @@ in the ``config`` section of the experiments.
|
|||
ray.init()
|
||||
run_experiments(experiment)
|
||||
|
||||
For an advanced example of using Population Based Training (PBT) with RLlib,
|
||||
see the `PPO + PBT Walker2D training example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_ppo_example.py>`__.
|
||||
|
||||
Contributing to RLlib
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
Ray Tune: Hyperparameter Optimization Framework
|
||||
===============================================
|
||||
|
||||
This document describes Ray Tune, a hyperparameter tuning framework for long-running tasks such as RL and deep learning training. It has the following features:
|
||||
This document describes Ray Tune, a hyperparameter tuning framework for long-running tasks such as RL and deep learning training. Ray Tune makes it easy to go from running one or more experiments on a single machine to running on a large cluster with efficient search algorithms.
|
||||
|
||||
- Early stopping algorithms such as `Median Stopping Rule <https://research.google.com/pubs/pub46180.html>`__ and `HyperBand <https://arxiv.org/abs/1603.06560>`__.
|
||||
It has the following features:
|
||||
|
||||
- Scalable implementations of search algorithms such as `Population Based Training (PBT) <#population-based-training>`__, `Median Stopping Rule <https://research.google.com/pubs/pub46180.html>`__, and `HyperBand <https://arxiv.org/abs/1603.06560>`__.
|
||||
|
||||
- Integration with visualization tools such as `TensorBoard <https://www.tensorflow.org/get_started/summaries_and_tensorboard>`__, `rllab's VisKit <https://media.readthedocs.org/pdf/rllab/latest/rllab.pdf>`__, and a `parallel coordinates visualization <https://en.wikipedia.org/wiki/Parallel_coordinates>`__.
|
||||
|
||||
|
@ -11,8 +13,18 @@ This document describes Ray Tune, a hyperparameter tuning framework for long-run
|
|||
|
||||
- Resource-aware scheduling, including support for concurrent runs of algorithms that may themselves be parallel and distributed.
|
||||
|
||||
|
||||
You can find the code for Ray Tune `here on GitHub <https://github.com/ray-project/ray/tree/master/python/ray/tune>`__.
|
||||
|
||||
Concepts
|
||||
--------
|
||||
|
||||
Ray Tune schedules a number of *trials* in a cluster. Each trial runs a user-defined Python function or class and is parameterized by a json *config* variation passed to the user code.
|
||||
|
||||
Ray Tune provides a ``run_experiments(spec)`` function that generates and runs the trials described by the experiment specification. The trials are scheduled and managed by a *trial scheduler* that implements the search algorithm (default is FIFO).
|
||||
|
||||
Ray Tune can be used anywhere Ray can, e.g. on your laptop with ``ray.init()`` embedded in a Python script, or in an `auto-scaling cluster <autoscaling.html>`__ for massive parallelism.
|
||||
|
||||
Getting Started
|
||||
---------------
|
||||
|
||||
|
@ -133,7 +145,7 @@ To reduce costs, long-running trials can often be early stopped if their initial
|
|||
|
||||
An example of this can be found in `hyperband_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. The progress of one such HyperBand run is shown below.
|
||||
|
||||
Note that some trial schedulers such as HyperBand require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster.
|
||||
Note that some trial schedulers such as HyperBand and PBT require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster.
|
||||
|
||||
::
|
||||
|
||||
|
@ -172,10 +184,19 @@ Currently we support the following early stopping algorithms, or you can write y
|
|||
.. autoclass:: ray.tune.median_stopping_rule.MedianStoppingRule
|
||||
.. autoclass:: ray.tune.hyperband.HyperBandScheduler
|
||||
|
||||
Population Based Training
|
||||
-------------------------
|
||||
|
||||
Ray Tune includes a distributed implementation of `Population Based Training (PBT) <https://deepmind.com/blog/population-based-training-neural-networks>`__. PBT also requires your Trainable to support checkpointing. You can run this `toy PBT example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py>`__ to get an idea of how how PBT operates. When training in PBT mode, the set of trial variations is treated as the population, so a single trial may see many different hyperparameters over its lifetime, which is recorded in the ``result.json`` file. The following figure generated by the example shows PBT discovering new hyperparams over the course of a single experiment:
|
||||
|
||||
.. image:: pbt.png
|
||||
|
||||
.. autoclass:: ray.tune.pbt.PopulationBasedTraining
|
||||
|
||||
Trial Checkpointing
|
||||
-------------------
|
||||
|
||||
To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__: Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand.
|
||||
To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__: Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand and PBT.
|
||||
|
||||
.. autoclass:: ray.tune.trainable.Trainable
|
||||
|
||||
|
|
|
@ -76,13 +76,13 @@ class PPOEvaluator(Evaluator):
|
|||
# Value function predictions before the policy update.
|
||||
self.prev_vf_preds = tf.placeholder(tf.float32, shape=(None,))
|
||||
|
||||
assert config["sgd_batchsize"] % len(devices) == 0, \
|
||||
"Batch size must be evenly divisible by devices"
|
||||
if is_remote:
|
||||
self.batch_size = config["rollout_batchsize"]
|
||||
self.per_device_batch_size = config["rollout_batchsize"]
|
||||
else:
|
||||
self.batch_size = config["sgd_batchsize"]
|
||||
self.batch_size = int(
|
||||
config["sgd_batchsize"] / len(devices)) * len(devices)
|
||||
assert self.batch_size % len(devices) == 0
|
||||
self.per_device_batch_size = int(self.batch_size / len(devices))
|
||||
|
||||
def build_loss(obs, vtargets, advs, acts, plog, pvf_preds):
|
||||
|
|
0
python/ray/rllib/test/test_checkpoint_restore.py
Executable file → Normal file
0
python/ray/rllib/test/test_checkpoint_restore.py
Executable file → Normal file
|
@ -4,6 +4,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
@ -49,6 +50,10 @@ class MyTrainableClass(Trainable):
|
|||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
|
||||
# Hyperband early stopping, configured with `episode_reward_mean` as the
|
||||
|
@ -60,7 +65,8 @@ if __name__ == "__main__":
|
|||
run_experiments({
|
||||
"hyperband_test": {
|
||||
"run": "my_class",
|
||||
"repeat": 100,
|
||||
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
|
||||
"repeat": 20,
|
||||
"resources": {"cpu": 1, "gpu": 0},
|
||||
"config": {
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
|
|
88
python/ray/tune/examples/pbt_example.py
Executable file
88
python/ray/tune/examples/pbt_example.py
Executable file
|
@ -0,0 +1,88 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, TrainingResult, register_trainable, \
|
||||
run_experiments
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
"""Fake agent whose learning rate is determined by dummy factors."""
|
||||
|
||||
def _setup(self):
|
||||
self.timestep = 0
|
||||
self.current_value = 0.0
|
||||
|
||||
def _train(self):
|
||||
time.sleep(0.1)
|
||||
|
||||
# Reward increase is parabolic as a function of factor_2, with a
|
||||
# maxima around factor_1=10.0.
|
||||
self.current_value += max(
|
||||
0.0, random.gauss(5.0 - (self.config["factor_1"] - 10.0)**2, 2.0))
|
||||
|
||||
# Flat increase by factor_2
|
||||
self.current_value += random.gauss(self.config["factor_2"], 1.0)
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy (see tune/result.py).
|
||||
return TrainingResult(
|
||||
episode_reward_mean=self.current_value, timesteps_this_iter=1)
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps(
|
||||
{"timestep": self.timestep, "value": self.current_value}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
data = json.loads(f.read())
|
||||
self.timestep = data["timestep"]
|
||||
self.current_value = data["value"]
|
||||
|
||||
|
||||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration", reward_attr="episode_reward_mean",
|
||||
perturbation_interval=10,
|
||||
hyperparam_mutations={
|
||||
# Allow for scaling-based perturbations, with a uniform backing
|
||||
# distribution for resampling.
|
||||
"factor_1": lambda config: random.uniform(0.0, 20.0),
|
||||
# Only allows resampling from this list as a perturbation.
|
||||
"factor_2": [1, 2],
|
||||
})
|
||||
|
||||
# Try to find the best factor 1 and factor 2
|
||||
run_experiments({
|
||||
"pbt_test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 2 if args.smoke_test else 99999},
|
||||
"repeat": 10,
|
||||
"resources": {"cpu": 1, "gpu": 0},
|
||||
"config": {
|
||||
"factor_1": 4.0,
|
||||
"factor_2": 1.0,
|
||||
},
|
||||
}
|
||||
}, scheduler=pbt, verbose=False)
|
71
python/ray/tune/examples/pbt_ppo_example.py
Executable file
71
python/ray/tune/examples/pbt_ppo_example.py
Executable file
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""Example of using PBT with RLlib.
|
||||
|
||||
Note that this requires a cluster with at least 8 GPUs in order for all trials
|
||||
to run concurrently, otherwise PBT will round-robin train the trials which
|
||||
is less efficient (or you can set {"gpu": 0} to use CPUs for SGD instead).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Postprocess the perturbed config to ensure it's still valid
|
||||
def explore(config):
|
||||
# ensure we collect enough timesteps to do sgd
|
||||
if config["timesteps_per_batch"] < config["sgd_batchsize"] * 2:
|
||||
config["timesteps_per_batch"] = config["sgd_batchsize"] * 2
|
||||
# ensure we run at least one sgd iter
|
||||
if config["num_sgd_iter"] < 1:
|
||||
config["num_sgd_iter"] = 1
|
||||
return config
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
perturbation_interval=120,
|
||||
resample_probability=0.25,
|
||||
# Specifies the resampling distributions of these hyperparams
|
||||
hyperparam_mutations={
|
||||
"lambda": lambda config: random.uniform(0.9, 1.0),
|
||||
"clip_param": lambda config: random.uniform(0.01, 0.5),
|
||||
"sgd_stepsize": lambda config: random.uniform(.00001, .001),
|
||||
"num_sgd_iter": lambda config: random.randint(1, 30),
|
||||
"sgd_batchsize": lambda config: random.randint(128, 16384),
|
||||
"timesteps_per_batch":
|
||||
lambda config: random.randint(2000, 160000),
|
||||
},
|
||||
custom_explore_fn=explore)
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"pbt_humanoid_test": {
|
||||
"run": "PPO",
|
||||
"env": "Humanoid-v1",
|
||||
"repeat": 8,
|
||||
"resources": {"cpu": 4, "gpu": 1},
|
||||
"config": {
|
||||
"kl_coeff": 1.0,
|
||||
"num_workers": 8,
|
||||
"devices": ["/gpu:0"],
|
||||
"model": {"free_log_std": True},
|
||||
# These params are tuned from their starting value
|
||||
"lambda": 0.95,
|
||||
"clip_param": 0.2,
|
||||
# Start off with several random variations
|
||||
"sgd_stepsize": lambda spec: random.uniform(.00001, .001),
|
||||
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
|
||||
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
|
||||
"timesteps_per_batch":
|
||||
lambda spec: random.choice([10000, 20000, 40000])
|
||||
},
|
||||
},
|
||||
}, scheduler=pbt)
|
|
@ -205,7 +205,7 @@ def train(config={'activation': 'relu'}, reporter=None):
|
|||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--fast', action='store_true', help='Finish quickly for testing')
|
||||
'--smoke-test', action='store_true', help='Finish quickly for testing')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
register_trainable('train_mnist', train)
|
||||
|
@ -220,7 +220,7 @@ if __name__ == '__main__':
|
|||
},
|
||||
}
|
||||
|
||||
if args.fast:
|
||||
if args.smoke_test:
|
||||
mnist_spec['stop']['training_iteration'] = 2
|
||||
|
||||
ray.init()
|
||||
|
|
|
@ -207,7 +207,7 @@ class HyperBandScheduler(FIFOScheduler):
|
|||
"""Cleans up trial info from bracket if trial errored early."""
|
||||
self.on_trial_remove(trial_runner, trial)
|
||||
|
||||
def choose_trial_to_run(self, trial_runner, *args):
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
"""Fair scheduling within iteration by completion percentage.
|
||||
|
||||
List of trials not used since all trials are tracked as state
|
||||
|
|
|
@ -63,7 +63,6 @@ class UnifiedLogger(Logger):
|
|||
print("TF not installed - cannot log with {}...".format(cls))
|
||||
continue
|
||||
self._loggers.append(cls(self.config, self.logdir, self.uri))
|
||||
print("Unified logger created with logdir '{}'".format(self.logdir))
|
||||
|
||||
def on_result(self, result):
|
||||
for logger in self._loggers:
|
||||
|
|
|
@ -31,7 +31,7 @@ class MedianStoppingRule(FIFOScheduler):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='time_total_s', reward_attr='episode_reward_mean',
|
||||
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
grace_period=60.0, min_samples_required=3, hard_stop=True):
|
||||
FIFOScheduler.__init__(self)
|
||||
self._stopped_trials = set()
|
||||
|
|
|
@ -2,189 +2,269 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import random
|
||||
import math
|
||||
import copy
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.variant_generator import _format_vars
|
||||
|
||||
|
||||
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
|
||||
# the bottom PBT_QUANTILE fraction.
|
||||
PBT_QUANTILE = 0.25
|
||||
|
||||
|
||||
class PBTTrialState(object):
|
||||
"""Internal PBT state tracked per-trial."""
|
||||
|
||||
def __init__(self, trial):
|
||||
self.orig_tag = trial.experiment_tag
|
||||
self.last_score = None
|
||||
self.last_checkpoint = None
|
||||
self.last_perturbation_time = 0
|
||||
|
||||
def __repr__(self):
|
||||
return str((
|
||||
self.last_score, self.last_checkpoint,
|
||||
self.last_perturbation_time))
|
||||
|
||||
|
||||
def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
"""Return a config perturbed as specified.
|
||||
|
||||
Args:
|
||||
config (dict): Original hyperparameter configuration.
|
||||
mutations (dict): Specification of mutations to perform as documented
|
||||
in the PopulationBasedTraining scheduler.
|
||||
resample_probability (float): Probability of allowing resampling of a
|
||||
particular variable.
|
||||
custom_explore_fn (func): Custom explore fn applied after built-in
|
||||
config perturbations are.
|
||||
"""
|
||||
new_config = copy.deepcopy(config)
|
||||
for key, distribution in mutations.items():
|
||||
if isinstance(distribution, list):
|
||||
if random.random() < resample_probability:
|
||||
new_config[key] = random.choice(distribution)
|
||||
else:
|
||||
if random.random() < resample_probability:
|
||||
new_config[key] = distribution(config)
|
||||
elif random.random() > 0.5:
|
||||
new_config[key] = config[key] * 1.2
|
||||
else:
|
||||
new_config[key] = config[key] * 0.8
|
||||
if type(config[key]) is int:
|
||||
new_config[key] = int(new_config[key])
|
||||
if custom_explore_fn:
|
||||
new_config = custom_explore_fn(new_config)
|
||||
assert new_config is not None, \
|
||||
"Custom explore fn failed to return new config"
|
||||
print(
|
||||
"[explore] perturbed config from {} -> {}".format(config, new_config))
|
||||
return new_config
|
||||
|
||||
|
||||
def make_experiment_tag(orig_tag, config, mutations):
|
||||
"""Appends perturbed params to the trial name to show in the console."""
|
||||
|
||||
resolved_vars = {}
|
||||
for k in mutations.keys():
|
||||
resolved_vars[("config", k)] = config[k]
|
||||
return "{}@perturbed[{}]".format(orig_tag, _format_vars(resolved_vars))
|
||||
|
||||
|
||||
class PopulationBasedTraining(FIFOScheduler):
|
||||
"""Implements the Population Based Training algorithm as described in the
|
||||
PBT paper (https://arxiv.org/abs/1711.09846)(Experimental):
|
||||
"""Implements the Population Based Training (PBT) algorithm.
|
||||
|
||||
https://deepmind.com/blog/population-based-training-neural-networks
|
||||
|
||||
PBT trains a group of models (or agents) in parallel. Periodically, poorly
|
||||
performing models clone the state of the top performers, and a random
|
||||
mutation is applied to their hyperparameters in the hopes of
|
||||
outperforming the current top models.
|
||||
|
||||
Unlike other hyperparameter search algorithms, PBT mutates hyperparameters
|
||||
during training time. This enables very fast hyperparameter discovery and
|
||||
also automatically discovers good annealing schedules.
|
||||
|
||||
This Ray Tune PBT implementation considers all trials added as part of the
|
||||
PBT population. If the number of trials exceeds the cluster capacity,
|
||||
they will be time-multiplexed as to balance training progress across the
|
||||
population.
|
||||
|
||||
Args:
|
||||
time_attr (str): The TrainingResult attr to use for documenting length
|
||||
of time since last ready() call. Attribute only has to increase
|
||||
monotonically.
|
||||
time_attr (str): The TrainingResult attr to use for comparing time.
|
||||
Note that you can pass in something non-temporal such as
|
||||
`training_iteration` as a measure of progress, the only requirement
|
||||
is that the attribute should increase monotonically.
|
||||
reward_attr (str): The TrainingResult objective value attribute. As
|
||||
with 'time_attr'. this may refer to any objective value that
|
||||
is supposed to increase with time.
|
||||
grace_period (float): Period of time, in which algorithm will not
|
||||
compare model to other models.
|
||||
perturbation_interval (float): Used in the truncation ready function to
|
||||
determine if enough time has passed so that a agent can be tested
|
||||
for readiness.
|
||||
hyperparameter_mutations (dict); Possible values that each
|
||||
hyperparameter can mutate to, as certain hyperparameters
|
||||
only work with certain values.
|
||||
with `time_attr`, this may refer to any objective value. Stopping
|
||||
procedures will use this attribute.
|
||||
perturbation_interval (float): Models will be considered for
|
||||
perturbation at this interval of `time_attr`. Note that
|
||||
perturbation incurs checkpoint overhead, so you shouldn't set this
|
||||
to be too frequent.
|
||||
hyperparam_mutations (dict): Hyperparams to mutate. The format is
|
||||
as follows: for each key, either a list or function can be
|
||||
provided. A list specifies values for a discrete parameter.
|
||||
A function specifies the distribution of a continuous parameter.
|
||||
You must specify at least one of `hyperparam_mutations` or
|
||||
`custom_explore_fn`.
|
||||
resample_probability (float): The probability of resampling from the
|
||||
original distribution when applying `hyperparam_mutations`. If not
|
||||
resampled, the value will be perturbed by a factor of 1.2 or 0.8
|
||||
if continuous, or left unchanged if discrete.
|
||||
custom_explore_fn (func): You can also specify a custom exploration
|
||||
function. This function is invoked as `f(config)` after built-in
|
||||
perturbations from `hyperparam_mutations` are applied, and should
|
||||
return `config` updated as needed. You must specify at least one of
|
||||
`hyperparam_mutations` or `custom_explore_fn`.
|
||||
|
||||
Example:
|
||||
>>> pbt = PopulationBasedTraining(
|
||||
>>> time_attr="training_iteration",
|
||||
>>> reward_attr="episode_reward_mean",
|
||||
>>> perturbation_interval=10, # every 10 `time_attr` units
|
||||
>>> # (training_iterations in this case)
|
||||
>>> hyperparam_mutations={
|
||||
>>> # Allow for scaling-based perturbations, with a uniform
|
||||
>>> # backing distribution for resampling.
|
||||
>>> "factor_1": lambda config: random.uniform(0.0, 20.0),
|
||||
>>> # Only allows resampling from this list as a perturbation.
|
||||
>>> "factor_2": [1, 2],
|
||||
>>> })
|
||||
>>> run_experiments({...}, scheduler=pbt)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean',
|
||||
grace_period=10.0, perturbation_interval=6.0,
|
||||
hyperparameter_mutations=None):
|
||||
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
perturbation_interval=60.0, hyperparam_mutations={},
|
||||
resample_probability=0.25, custom_explore_fn=None):
|
||||
if not hyperparam_mutations and not custom_explore_fn:
|
||||
raise TuneError(
|
||||
"You must specify at least one of `hyperparam_mutations` or "
|
||||
"`custom_explore_fn` to use PBT.")
|
||||
FIFOScheduler.__init__(self)
|
||||
self._completed_trials = set()
|
||||
self._results = collections.defaultdict(list)
|
||||
self._last_perturbation_time = {}
|
||||
self._grace_period = grace_period
|
||||
self._reward_attr = reward_attr
|
||||
self._time_attr = time_attr
|
||||
|
||||
self._hyperparameter_mutations = hyperparameter_mutations
|
||||
self._perturbation_interval = perturbation_interval
|
||||
self._checkpoint_paths = {}
|
||||
self._hyperparam_mutations = hyperparam_mutations
|
||||
self._resample_probability = resample_probability
|
||||
self._trial_state = {}
|
||||
self._custom_explore_fn = custom_explore_fn
|
||||
|
||||
# Metrics
|
||||
self._num_checkpoints = 0
|
||||
self._num_perturbations = 0
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
self._trial_state[trial] = PBTTrialState(trial)
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
|
||||
self._results[trial].append(result)
|
||||
time = getattr(result, self._time_attr)
|
||||
# check model is ready to undergo mutation, based on user
|
||||
# function or default function
|
||||
self._checkpoint_paths[trial] = trial.checkpoint()
|
||||
if time > self._grace_period:
|
||||
ready = self._truncation_ready(result, trial, time)
|
||||
else:
|
||||
ready = False
|
||||
if ready:
|
||||
print("ready to undergo mutation")
|
||||
print("----")
|
||||
print("Current Trial is: {0}".format(trial))
|
||||
# get best trial for current time
|
||||
best_trial = self._get_best_trial(result, time)
|
||||
print("Best Trial is: {0}".format(best_trial))
|
||||
print(best_trial.config)
|
||||
state = self._trial_state[trial]
|
||||
|
||||
if time - state.last_perturbation_time < self._perturbation_interval:
|
||||
return TrialScheduler.CONTINUE # avoid checkpoint overhead
|
||||
|
||||
score = getattr(result, self._reward_attr)
|
||||
state.last_score = score
|
||||
state.last_perturbation_time = time
|
||||
lower_quantile, upper_quantile = self._quantiles()
|
||||
|
||||
if trial in upper_quantile:
|
||||
state.last_checkpoint = trial.checkpoint(to_object_store=True)
|
||||
self._num_checkpoints += 1
|
||||
else:
|
||||
state.last_checkpoint = None # not a top trial
|
||||
|
||||
if trial in lower_quantile:
|
||||
trial_to_clone = random.choice(upper_quantile)
|
||||
assert trial is not trial_to_clone
|
||||
self._exploit(trial, trial_to_clone)
|
||||
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED]:
|
||||
return TrialScheduler.PAUSE # yield time to other trials
|
||||
|
||||
# if current trial is the best trial (as in same hyperparameters),
|
||||
# do nothing
|
||||
if trial.config == best_trial.config:
|
||||
print("current trial is best trial")
|
||||
return TrialScheduler.CONTINUE
|
||||
else:
|
||||
self._exploit(self._hyperparameter_mutations, best_trial,
|
||||
trial, trial_runner, time)
|
||||
return TrialScheduler.CONTINUE
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
self._results[trial].append(result)
|
||||
self._completed_trials.add(trial)
|
||||
def _exploit(self, trial, trial_to_clone):
|
||||
"""Transfers perturbed state from trial_to_clone -> trial."""
|
||||
|
||||
def _exploit(self, hyperparameter_mutations, best_trial,
|
||||
trial, trial_runner, time):
|
||||
trial.stop()
|
||||
mutate_string = "_mutated@" + str(time)
|
||||
hyperparams = copy.deepcopy(best_trial.config)
|
||||
hyperparams = self._explore(hyperparams, hyperparameter_mutations,
|
||||
best_trial)
|
||||
print("new hyperparameter configuration: {0}".format(hyperparams))
|
||||
checkpoint = self._checkpoint_paths[best_trial]
|
||||
trial._checkpoint_path = checkpoint
|
||||
trial.config = hyperparams
|
||||
trial.experiment_tag = trial.experiment_tag + mutate_string
|
||||
trial.start()
|
||||
trial_state = self._trial_state[trial]
|
||||
new_state = self._trial_state[trial_to_clone]
|
||||
if not new_state.last_checkpoint:
|
||||
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
|
||||
return
|
||||
new_config = explore(
|
||||
trial_to_clone.config, self._hyperparam_mutations,
|
||||
self._resample_probability, self._custom_explore_fn)
|
||||
print(
|
||||
"[exploit] transferring weights from trial "
|
||||
"{} (score {}) -> {} (score {})".format(
|
||||
trial_to_clone, new_state.last_score, trial,
|
||||
trial_state.last_score))
|
||||
# TODO(ekl) restarting the trial is expensive. We should implement a
|
||||
# lighter way reset() method that can alter the trial config.
|
||||
trial.stop(stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = make_experiment_tag(
|
||||
trial_state.orig_tag, new_config, self._hyperparam_mutations)
|
||||
trial.start(new_state.last_checkpoint)
|
||||
self._num_perturbations += 1
|
||||
# Transfer over the last perturbation time as well
|
||||
trial_state.last_perturbation_time = new_state.last_perturbation_time
|
||||
|
||||
def _explore(self, hyperparams, hyperparameter_mutations, best_trial):
|
||||
if hyperparameter_mutations is not None:
|
||||
hyperparams = {
|
||||
param: random.choice(hyperparameter_mutations[param])
|
||||
for param in hyperparams
|
||||
if param != "env" and param in hyperparameter_mutations
|
||||
}
|
||||
for param in best_trial.config:
|
||||
if param not in hyperparameter_mutations and param != "env":
|
||||
hyperparams[param] = math.ceil(
|
||||
(best_trial.config[param]
|
||||
* random.choice([0.8, 1.2])/2.)) * 2
|
||||
def _quantiles(self):
|
||||
"""Returns trials in the lower and upper `quantile` of the population.
|
||||
|
||||
If there is not enough data to compute this, returns empty lists."""
|
||||
|
||||
trials = []
|
||||
for trial, state in self._trial_state.items():
|
||||
if state.last_score is not None and not trial.is_finished():
|
||||
trials.append(trial)
|
||||
trials.sort(key=lambda t: self._trial_state[t].last_score)
|
||||
|
||||
if len(trials) <= 1:
|
||||
return [], []
|
||||
else:
|
||||
hyperparams = {
|
||||
param: math.ceil(
|
||||
(random.choice([0.8, 1.2]) *
|
||||
hyperparams[param])/2.) * 2
|
||||
for param in hyperparams
|
||||
if param != "env"
|
||||
}
|
||||
hyperparams["env"] = best_trial.config["env"]
|
||||
return hyperparams
|
||||
return (
|
||||
trials[:int(math.ceil(len(trials)*PBT_QUANTILE))],
|
||||
trials[int(math.floor(-len(trials)*PBT_QUANTILE)):])
|
||||
|
||||
def _truncation_ready(self, result, trial, time):
|
||||
# function checks if appropriate time has passed
|
||||
# and trial is in the bottom 20% of all trials, and if so, is ready
|
||||
if trial not in self._last_perturbation_time:
|
||||
print("added trial to time tracker")
|
||||
self._last_perturbation_time[trial] = (time)
|
||||
else:
|
||||
time_since_last = time - self._last_perturbation_time[trial]
|
||||
if time_since_last >= self._perturbation_interval:
|
||||
self._last_perturbation_time[trial] = time
|
||||
sorted_result_keys = sorted(
|
||||
self._results, key=lambda x:
|
||||
max(self._results.get(x) if self._results.get(x) else [0])
|
||||
)
|
||||
max_index = int(round(len(sorted_result_keys) * 0.2))
|
||||
for i in range(0, max_index):
|
||||
if trial == sorted_result_keys[i]:
|
||||
print("{0} is in the bottomn 20 percent of {1}, \
|
||||
truncation is ready".format(
|
||||
trial,
|
||||
[x.experiment_tag for x in sorted_result_keys]
|
||||
))
|
||||
return True
|
||||
print("{0} is not in the bottomn 20 percent of {1}, \
|
||||
truncation is not ready".format(
|
||||
trial,
|
||||
[x.experiment_tag for x in sorted_result_keys]
|
||||
))
|
||||
else:
|
||||
print("not enough time has passed since last mutation")
|
||||
return False
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
"""Ensures all trials get fair share of time (as defined by time_attr).
|
||||
|
||||
def _get_best_trial(self, result, time):
|
||||
results_at_time = {}
|
||||
for trial in self._results:
|
||||
results_at_time[trial] = [
|
||||
getattr(r, self._reward_attr)
|
||||
for r in self._results[trial]
|
||||
if getattr(r, self._time_attr) <= time
|
||||
]
|
||||
print("Results at {0}: {1}".format(time, results_at_time))
|
||||
return max(results_at_time, key=lambda x:
|
||||
max(results_at_time.get(x)
|
||||
if results_at_time.get(x) else [0]))
|
||||
This enables the PBT scheduler to support a greater number of
|
||||
concurrent trials than can fit in the cluster at any given time.
|
||||
"""
|
||||
|
||||
def _is_empty(self, x):
|
||||
if x:
|
||||
return False
|
||||
return True
|
||||
candidates = []
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
|
||||
trial_runner.has_resources(trial.resources):
|
||||
candidates.append(trial)
|
||||
candidates.sort(
|
||||
key=lambda trial: self._trial_state[trial].last_perturbation_time)
|
||||
return candidates[0] if candidates else None
|
||||
|
||||
def reset_stats(self):
|
||||
self._num_perturbations = 0
|
||||
self._num_checkpoints = 0
|
||||
|
||||
def last_scores(self, trials):
|
||||
scores = []
|
||||
for trial in trials:
|
||||
state = self._trial_state[trial]
|
||||
if state.last_score is not None and not trial.is_finished():
|
||||
scores.append(state.last_score)
|
||||
return scores
|
||||
|
||||
def debug_string(self):
|
||||
|
||||
min_time = 0
|
||||
best_trial = None
|
||||
for trial in self._completed_trials:
|
||||
last_result = self._results[trial][-1]
|
||||
if (getattr(last_result, self._time_attr)
|
||||
< min_time or min_time == 0):
|
||||
min_time = getattr(last_result, self._time_attr)
|
||||
best_trial = trial
|
||||
if best_trial is not None:
|
||||
return ("The Best Trial is currently {0} finishing in {1} iterations, \
|
||||
with the hyperparameters of {2}".format(
|
||||
best_trial, min_time, best_trial.config
|
||||
)
|
||||
)
|
||||
else:
|
||||
return "PBT has started"
|
||||
return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
|
||||
self._num_checkpoints, self._num_perturbations)
|
||||
|
|
|
@ -84,10 +84,14 @@ TrainingResult = namedtuple("TrainingResult", [
|
|||
|
||||
# (Auto-filled) The hostname of the machine hosting the training process.
|
||||
"hostname",
|
||||
|
||||
# (Auto=filled) The current hyperparameter configuration.
|
||||
"config",
|
||||
])
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
result = result._replace(config=None) # drop config from pretty print
|
||||
out = {}
|
||||
for k, v in result._asdict().items():
|
||||
if v is not None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from ray.tune import Trainable, TuneError
|
|||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trial import Trial, Resources, MAX_LEN_IDENTIFIER
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.variant_generator import generate_trials, grid_search, \
|
||||
RecursiveDependencyError
|
||||
|
@ -364,8 +364,9 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
self.assertLessEqual(
|
||||
len(str(trial)), MAX_LEN_IDENTIFIER)
|
||||
trial.start()
|
||||
self.assertLessEqual(len(trial.logdir), 200)
|
||||
trial.stop()
|
||||
|
||||
def testTrialErrorOnStart(self):
|
||||
ray.init()
|
|
@ -5,15 +5,17 @@ from __future__ import print_function
|
|||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import TrialScheduler
|
||||
|
||||
from ray.rllib import _register_all
|
||||
_register_all()
|
||||
|
||||
|
||||
def result(t, rew):
|
||||
return TrainingResult(time_total_s=t,
|
||||
|
@ -145,6 +147,7 @@ class EarlyStoppingSuite(unittest.TestCase):
|
|||
class _MockTrialRunner():
|
||||
def __init__(self, scheduler):
|
||||
self._scheduler_alg = scheduler
|
||||
self.trials = []
|
||||
|
||||
def process_action(self, trial, action):
|
||||
if action == TrialScheduler.CONTINUE:
|
||||
|
@ -163,6 +166,13 @@ class _MockTrialRunner():
|
|||
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result(100, 10))
|
||||
|
||||
def add_trial(self, trial):
|
||||
self.trials.append(trial)
|
||||
self._scheduler_alg.on_trial_add(self, trial)
|
||||
|
||||
def get_trials(self):
|
||||
return self.trials
|
||||
|
||||
def has_resources(self, resources):
|
||||
return True
|
||||
|
||||
|
@ -511,109 +521,202 @@ class HyperbandSuite(unittest.TestCase):
|
|||
self.assertFalse(trial in bracket._live_trials)
|
||||
|
||||
|
||||
class _MockTrialRunnerPBT(_MockTrialRunner):
|
||||
|
||||
def __init__(self):
|
||||
self._trials = []
|
||||
|
||||
def _launch_trial(self, trial):
|
||||
trial.status = Trial.RUNNING
|
||||
self._trials.append(trial)
|
||||
|
||||
|
||||
class _MockTrialPBT(Trial):
|
||||
class _MockTrial(Trial):
|
||||
def __init__(self, i, config):
|
||||
self.trainable_name = "trial_{}".format(i)
|
||||
self.config = config
|
||||
self.experiment_tag = "tag"
|
||||
self.logger_running = False
|
||||
self.restored_checkpoint = None
|
||||
self.resources = Resources(1, 0)
|
||||
|
||||
def checkpoint(self, to_object_store=False):
|
||||
return 'checkpointed'
|
||||
return self.trainable_name
|
||||
|
||||
def start(self):
|
||||
return 'started'
|
||||
def start(self, checkpoint=None):
|
||||
self.logger_running = True
|
||||
self.restored_checkpoint = checkpoint
|
||||
|
||||
def stop(self):
|
||||
return 'stopped'
|
||||
def stop(self, stop_logger=False):
|
||||
if stop_logger:
|
||||
self.logger_running = False
|
||||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
def schedulerSetup(self, num_trials):
|
||||
sched = PopulationBasedTraining()
|
||||
runner = _MockTrialRunnerPBT()
|
||||
for i in range(num_trials):
|
||||
t = _MockTrialPBT("__parameter_tuning")
|
||||
t.config = {'test': 1, 'test1': 1, 'env': 'test'}
|
||||
t.experiment_tag = str(i)
|
||||
runner._launch_trial(t)
|
||||
return sched, runner
|
||||
def basicSetup(self, resample_prob=0.0, explore=None):
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
perturbation_interval=10,
|
||||
resample_probability=resample_prob,
|
||||
hyperparam_mutations={
|
||||
"id_factor": [100],
|
||||
"float_factor": lambda c: 100.0,
|
||||
"int_factor": lambda c: 10,
|
||||
},
|
||||
custom_explore_fn=explore)
|
||||
runner = _MockTrialRunner(pbt)
|
||||
for i in range(5):
|
||||
trial = _MockTrial(
|
||||
i,
|
||||
{"id_factor": i, "float_factor": 2.0, "const_factor": 3,
|
||||
"int_factor": 10})
|
||||
runner.add_trial(trial)
|
||||
trial.status = Trial.RUNNING
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
|
||||
TrialScheduler.CONTINUE)
|
||||
pbt.reset_stats()
|
||||
return pbt, runner
|
||||
|
||||
def testReadyFunction(self):
|
||||
sched, runner = self.schedulerSetup(5)
|
||||
# different time intervals to test at
|
||||
best_result_early = result(18, 100)
|
||||
best_result_late = result(25, 100)
|
||||
runner._trials[0].config = {'test': 10, 'test1': 10, 'env': 'test'}
|
||||
# setting up best trial so that it consistently is the best trial
|
||||
sched.on_trial_result(runner, runner._trials[0], result(11, 0))
|
||||
sched.on_trial_result(runner, runner._trials[0], result(14, 2))
|
||||
sched.on_trial_result(runner, runner._trials[0], best_result_early)
|
||||
sched.on_trial_result(runner, runner._trials[0], best_result_late)
|
||||
# testing that adding trials to time tracker works, and that
|
||||
# ready function knows when to start
|
||||
for trial in runner._trials[1:]:
|
||||
old_config = trial.config
|
||||
sched.on_trial_result(
|
||||
runner, trial, result(11, random.randint(0, 10)))
|
||||
self.assertTrue(old_config == trial.config)
|
||||
# making sure that the second trial in runner._trials
|
||||
# (not the best trial) is the worst trial
|
||||
for trial in runner._trials[2:]:
|
||||
# testing to see that ready function knows
|
||||
# that not enough time has passed
|
||||
sched.on_trial_result(
|
||||
runner, trial, result(16, random.randint(40, 50)))
|
||||
# testing to see if worst trial (aka bottom 20%)
|
||||
# has mutated (ready function initiated)
|
||||
old_config = runner._trials[1].config
|
||||
sched.on_trial_result(runner, runner._trials[1], result(26, 30))
|
||||
self.assertFalse(old_config == runner._trials[1].config)
|
||||
def testCheckpointsMostPromisingTrials(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
|
||||
def testExploitExploreFunction(self):
|
||||
sched, runner = self.schedulerSetup(5)
|
||||
# different time intervals to test at
|
||||
best_result_early = result(18, 100)
|
||||
best_result_late = result(25, 100)
|
||||
runner._trials[0].config = {'test': 10, 'test1': 10, 'env': 'test'}
|
||||
# setting up best trial so that it consistently is the best trial
|
||||
sched.on_trial_result(runner, runner._trials[0], best_result_early)
|
||||
sched.on_trial_result(runner, runner._trials[0], best_result_late)
|
||||
# testing that adding trials to time tracker works, and
|
||||
# that ready function knows when to start
|
||||
for trial in runner._trials[1:]:
|
||||
sched.on_trial_result(
|
||||
runner, trial, result(11, random.randint(0, 10)))
|
||||
# making sure that the second trial in runner._trials
|
||||
# (not the best trial) is the worst trial
|
||||
for trial in runner._trials[2:]:
|
||||
sched.on_trial_result(
|
||||
runner, trial, result(16, random.randint(40, 50)))
|
||||
sched.on_trial_result(runner, runner._trials[1], result(26, 30))
|
||||
# make sure mutated values are multiples of 0.8 and 1.2
|
||||
# (default explore values)
|
||||
for key in runner._trials[0].config:
|
||||
if key == 'env':
|
||||
continue
|
||||
else:
|
||||
if (
|
||||
runner._trials[1].config[key] == 0.8 *
|
||||
runner._trials[0].config[key] or
|
||||
runner._trials[1].config[key] == 1.2 *
|
||||
runner._trials[0].config[key]
|
||||
):
|
||||
continue
|
||||
else:
|
||||
raise ValueError('Trial not correctly explored (mutated)')
|
||||
# no checkpoint: haven't hit next perturbation interval yet
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(15, 200)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 0)
|
||||
|
||||
# checkpoint: both past interval and upper quantile
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, 200)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [200, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 1)
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[1], result(30, 201)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [200, 201, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 2)
|
||||
|
||||
# not upper quantile any more
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[4], result(30, 199)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(pbt._num_checkpoints, 2)
|
||||
self.assertEqual(pbt._num_perturbations, 0)
|
||||
|
||||
def testPerturbsLowPerformingTrials(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
|
||||
# no perturbation: haven't hit next perturbation interval
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(15, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertTrue("@perturbed" not in trials[0].experiment_tag)
|
||||
self.assertEqual(pbt._num_perturbations, 0)
|
||||
|
||||
# perturb since it's lower quantile
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [-100, 50, 100, 150, 200])
|
||||
self.assertTrue("@perturbed" in trials[0].experiment_tag)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertEqual(pbt._num_perturbations, 1)
|
||||
|
||||
# also perturbed
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[2], result(20, 40)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [-100, 50, 40, 150, 200])
|
||||
self.assertEqual(pbt._num_perturbations, 2)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertTrue("@perturbed" in trials[2].experiment_tag)
|
||||
|
||||
def testPerturbWithoutResample(self):
|
||||
pbt, runner = self.basicSetup(resample_prob=0.0)
|
||||
trials = runner.get_trials()
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertIn(trials[0].config["id_factor"], [3, 4])
|
||||
self.assertIn(trials[0].config["float_factor"], [2.4, 1.6])
|
||||
self.assertEqual(type(trials[0].config["float_factor"]), float)
|
||||
self.assertIn(trials[0].config["int_factor"], [8, 12])
|
||||
self.assertEqual(type(trials[0].config["int_factor"]), int)
|
||||
self.assertEqual(trials[0].config["const_factor"], 3)
|
||||
|
||||
def testPerturbWithResample(self):
|
||||
pbt, runner = self.basicSetup(resample_prob=1.0)
|
||||
trials = runner.get_trials()
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertEqual(trials[0].config["id_factor"], 100)
|
||||
self.assertEqual(trials[0].config["float_factor"], 100.0)
|
||||
self.assertEqual(type(trials[0].config["float_factor"]), float)
|
||||
self.assertEqual(trials[0].config["int_factor"], 10)
|
||||
self.assertEqual(type(trials[0].config["int_factor"]), int)
|
||||
self.assertEqual(trials[0].config["const_factor"], 3)
|
||||
|
||||
def testYieldsTimeToOtherTrials(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
trials[0].status = Trial.PENDING # simulate not enough resources
|
||||
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[1], result(20, 1000)),
|
||||
TrialScheduler.PAUSE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 1000, 100, 150, 200])
|
||||
self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])
|
||||
|
||||
def testSchedulesMostBehindTrialToRun(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
pbt.on_trial_result(runner, trials[0], result(800, 1000))
|
||||
pbt.on_trial_result(runner, trials[1], result(700, 1001))
|
||||
pbt.on_trial_result(runner, trials[2], result(600, 1002))
|
||||
pbt.on_trial_result(runner, trials[3], result(500, 1003))
|
||||
pbt.on_trial_result(runner, trials[4], result(700, 1004))
|
||||
self.assertEqual(pbt.choose_trial_to_run(runner), None)
|
||||
for i in range(5):
|
||||
trials[i].status = Trial.PENDING
|
||||
self.assertEqual(pbt.choose_trial_to_run(runner), trials[3])
|
||||
|
||||
def testPerturbationResetsLastPerturbTime(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
pbt.on_trial_result(runner, trials[0], result(10000, 1005))
|
||||
pbt.on_trial_result(runner, trials[1], result(10000, 1004))
|
||||
pbt.on_trial_result(runner, trials[2], result(600, 1003))
|
||||
self.assertEqual(pbt._num_perturbations, 0)
|
||||
pbt.on_trial_result(runner, trials[3], result(500, 1002))
|
||||
self.assertEqual(pbt._num_perturbations, 1)
|
||||
pbt.on_trial_result(runner, trials[3], result(600, 100))
|
||||
self.assertEqual(pbt._num_perturbations, 1)
|
||||
pbt.on_trial_result(runner, trials[3], result(11000, 100))
|
||||
self.assertEqual(pbt._num_perturbations, 2)
|
||||
|
||||
def testPostprocessingHook(self):
|
||||
def explore(new_config):
|
||||
new_config["id_factor"] = 42
|
||||
new_config["float_factor"] = 43
|
||||
return new_config
|
||||
pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
|
||||
trials = runner.get_trials()
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(trials[0].config["id_factor"], 42)
|
||||
self.assertEqual(trials[0].config["float_factor"], 43)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ray.rllib import _register_all
|
||||
_register_all()
|
||||
unittest.main(verbosity=2)
|
|
@ -135,7 +135,8 @@ class Trainable(object):
|
|||
time_total_s=self._time_total,
|
||||
neg_mean_loss=neg_loss,
|
||||
pid=os.getpid(),
|
||||
hostname=os.uname()[1])
|
||||
hostname=os.uname()[1],
|
||||
config=self.config)
|
||||
|
||||
self._result_logger.on_result(result)
|
||||
|
||||
|
@ -185,8 +186,8 @@ class Trainable(object):
|
|||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"data": data,
|
||||
})
|
||||
print("Saving checkpoint to object store, {} bytes".format(
|
||||
len(compressed)))
|
||||
if len(compressed) > 10e6: # getting pretty large
|
||||
print("Checkpoint size is {} bytes".format(len(compressed)))
|
||||
f.write(compressed)
|
||||
|
||||
shutil.rmtree(tmpdir)
|
||||
|
|
|
@ -96,7 +96,7 @@ class Trial(object):
|
|||
"Stopping condition key `{}` must be one of {}".format(
|
||||
k, TrainingResult._fields))
|
||||
|
||||
# Immutable config
|
||||
# Trial config
|
||||
self.trainable_name = trainable_name
|
||||
self.config = config or {}
|
||||
self.local_dir = local_dir
|
||||
|
@ -105,6 +105,7 @@ class Trial(object):
|
|||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.upload_dir = upload_dir
|
||||
self.verbose = True
|
||||
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = None
|
||||
|
@ -117,16 +118,22 @@ class Trial(object):
|
|||
self.result_logger = None
|
||||
self.last_debug = 0
|
||||
self.trial_id = binary_to_hex(random_string())[:8]
|
||||
self.error_file = None
|
||||
|
||||
def start(self):
|
||||
def start(self, checkpoint_obj=None):
|
||||
"""Starts this trial.
|
||||
|
||||
If an error is encountered when starting the trial, an exception will
|
||||
be thrown.
|
||||
|
||||
Args:
|
||||
checkpoint_obj (obj): Optional checkpoint to resume from.
|
||||
"""
|
||||
|
||||
self._setup_runner()
|
||||
if self._checkpoint_path:
|
||||
if checkpoint_obj:
|
||||
self.restore_from_obj(checkpoint_obj)
|
||||
elif self._checkpoint_path:
|
||||
self.restore_from_path(self._checkpoint_path)
|
||||
elif self._checkpoint_obj:
|
||||
self.restore_from_obj(self._checkpoint_obj)
|
||||
|
@ -155,6 +162,7 @@ class Trial(object):
|
|||
self.logdir, "error_{}.txt".format(date_str()))
|
||||
with open(error_file, "w") as f:
|
||||
f.write(error_msg)
|
||||
self.error_file = error_file
|
||||
if self.runner:
|
||||
stop_tasks = []
|
||||
stop_tasks.append(self.runner.stop.remote())
|
||||
|
@ -163,9 +171,6 @@ class Trial(object):
|
|||
# TODO(ekl) seems like wait hangs when killing actors
|
||||
_, unfinished = ray.wait(
|
||||
stop_tasks, num_returns=2, timeout=250)
|
||||
if unfinished:
|
||||
print(("Stopping %s Actor timed out, "
|
||||
"but moving on...") % self)
|
||||
except Exception:
|
||||
print("Error stopping runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
|
@ -230,7 +235,7 @@ class Trial(object):
|
|||
"""Returns a progress message for printing out to the console."""
|
||||
|
||||
if self.last_result is None:
|
||||
return self.status
|
||||
return self._status_string()
|
||||
|
||||
def location_string(hostname, pid):
|
||||
if hostname == os.uname()[1]:
|
||||
|
@ -240,7 +245,8 @@ class Trial(object):
|
|||
|
||||
pieces = [
|
||||
'{} [{}]'.format(
|
||||
self.status, location_string(
|
||||
self._status_string(),
|
||||
location_string(
|
||||
self.last_result.hostname, self.last_result.pid)),
|
||||
'{} s'.format(int(self.last_result.time_total_s)),
|
||||
'{} ts'.format(int(self.last_result.timesteps_total))]
|
||||
|
@ -259,6 +265,11 @@ class Trial(object):
|
|||
|
||||
return ', '.join(pieces)
|
||||
|
||||
def _status_string(self):
|
||||
return "{}{}".format(
|
||||
self.status,
|
||||
" => {}".format(self.error_file) if self.error_file else "")
|
||||
|
||||
def checkpoint(self, to_object_store=False):
|
||||
"""Checkpoints the state of this trial.
|
||||
|
||||
|
@ -276,7 +287,8 @@ class Trial(object):
|
|||
self._checkpoint_path = path
|
||||
self._checkpoint_obj = obj
|
||||
|
||||
print("Saved checkpoint to:", path or obj)
|
||||
if self.verbose:
|
||||
print("Saved checkpoint for {} to {}".format(self, path or obj))
|
||||
return path or obj
|
||||
|
||||
def restore_from_path(self, path):
|
||||
|
@ -310,7 +322,9 @@ class Trial(object):
|
|||
def update_last_result(self, result, terminate=False):
|
||||
if terminate:
|
||||
result = result._replace(done=True)
|
||||
if terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL:
|
||||
if terminate or (
|
||||
self.verbose and
|
||||
time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_debug = time.time()
|
||||
|
@ -348,12 +362,17 @@ class Trial(object):
|
|||
config=self.config, registry=get_registry(),
|
||||
logger_creator=logger_creator)
|
||||
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``.
|
||||
def set_verbose(self, verbose):
|
||||
self.verbose = verbose
|
||||
|
||||
Truncates to MAX_LEN_IDENTIFIER (default is 130) to avoid problems
|
||||
when creating logging directories.
|
||||
"""
|
||||
def is_finished(self):
|
||||
return self.status in [Trial.TERMINATED, Trial.ERROR]
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
|
||||
if "env" in self.config:
|
||||
identifier = "{}_{}".format(
|
||||
self.trainable_name, self.config["env"])
|
||||
|
@ -361,4 +380,4 @@ class Trial(object):
|
|||
identifier = self.trainable_name
|
||||
if self.experiment_tag:
|
||||
identifier += "_" + self.experiment_tag
|
||||
return identifier[:MAX_LEN_IDENTIFIER]
|
||||
return identifier
|
||||
|
|
|
@ -31,7 +31,7 @@ def _make_scheduler(args):
|
|||
|
||||
|
||||
def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT):
|
||||
server_port=TuneServer.DEFAULT_PORT, verbose=True):
|
||||
|
||||
# Make sure rllib agents are registered
|
||||
from ray import rllib # noqa # pylint: disable=unused-import
|
||||
|
@ -44,6 +44,7 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
|||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
trial.set_verbose(verbose)
|
||||
runner.add_trial(trial)
|
||||
print(runner.debug_string(max_debug=99999))
|
||||
|
||||
|
|
|
@ -164,7 +164,15 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
|
|||
|
||||
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/tune_mnist_ray.py \
|
||||
--fast
|
||||
--smoke-test
|
||||
|
||||
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/pbt_example.py \
|
||||
--smoke-test
|
||||
|
||||
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/hyperband_example.py \
|
||||
--smoke-test
|
||||
|
||||
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/examples/multiagent_mountaincar.py
|
||||
|
|
Loading…
Add table
Reference in a new issue