[tune] clean up population based training prototype (#1478)

* patch up pbt

* add pbt

* clean up test

* review

* try out a ppo example

* some tweaks to ppo example

* add postprocess hook

* clean up custom explore fn

* improve tune doc

* concepts

* update humanoid

* fix example

* show error file
Eric Liang 2018-02-02 23:03:12 -08:00 committed by GitHub
parent a936468f99
commit b948405532
22 changed files with 698 additions and 288 deletions


@ -128,13 +128,18 @@ script:
- python test/multi_node_test.py
- python test/recursion_test.py
- python test/monitor_test.py
- python test/trial_runner_test.py
- python test/trial_scheduler_test.py
- python test/tune_server_test.py
- python test/cython_test.py
# ray dataframe tests
- python -m pytest python/ray/dataframe/test/test_dataframe.py
- python -m pytest python/ray/dataframe/test/test_series.py
# ray tune tests
- python -m pytest python/ray/tune/test/trial_runner_test.py
- python -m pytest python/ray/tune/test/trial_scheduler_test.py
- python -m pytest python/ray/tune/test/tune_server_test.py
# ray rllib tests
- python -m pytest python/ray/rllib/test/test_catalog.py
- python -m pytest python/ray/rllib/test/test_filters.py
- python -m pytest python/ray/rllib/test/test_optimizers.py

BIN doc/source/pbt.png (new binary image, 32 KiB)

@ -290,6 +290,9 @@ in the ``config`` section of the experiments.
ray.init()
run_experiments(experiment)
For an advanced example of using Population Based Training (PBT) with RLlib,
see the `PPO + PBT Walker2D training example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_ppo_example.py>`__.
Contributing to RLlib
---------------------


@ -1,9 +1,11 @@
Ray Tune: Hyperparameter Optimization Framework
===============================================
This document describes Ray Tune, a hyperparameter tuning framework for long-running tasks such as RL and deep learning training. It has the following features:
This document describes Ray Tune, a hyperparameter tuning framework for long-running tasks such as RL and deep learning training. Ray Tune makes it easy to go from running one or more experiments on a single machine to running on a large cluster with efficient search algorithms.
- Early stopping algorithms such as `Median Stopping Rule <https://research.google.com/pubs/pub46180.html>`__ and `HyperBand <https://arxiv.org/abs/1603.06560>`__.
It has the following features:
- Scalable implementations of search algorithms such as `Population Based Training (PBT) <#population-based-training>`__, `Median Stopping Rule <https://research.google.com/pubs/pub46180.html>`__, and `HyperBand <https://arxiv.org/abs/1603.06560>`__.
- Integration with visualization tools such as `TensorBoard <https://www.tensorflow.org/get_started/summaries_and_tensorboard>`__, `rllab's VisKit <https://media.readthedocs.org/pdf/rllab/latest/rllab.pdf>`__, and a `parallel coordinates visualization <https://en.wikipedia.org/wiki/Parallel_coordinates>`__.
@ -11,8 +13,18 @@ This document describes Ray Tune, a hyperparameter tuning framework for long-run
- Resource-aware scheduling, including support for concurrent runs of algorithms that may themselves be parallel and distributed.
You can find the code for Ray Tune `here on GitHub <https://github.com/ray-project/ray/tree/master/python/ray/tune>`__.
Concepts
--------
Ray Tune schedules a number of *trials* in a cluster. Each trial runs a user-defined Python function or class and is parameterized by a json *config* variation passed to the user code.
Ray Tune provides a ``run_experiments(spec)`` function that generates and runs the trials described by the experiment specification. The trials are scheduled and managed by a *trial scheduler* that implements the search algorithm (default is FIFO).
Ray Tune can be used anywhere Ray can, e.g. on your laptop with ``ray.init()`` embedded in a Python script, or in an `auto-scaling cluster <autoscaling.html>`__ for massive parallelism.
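As a rough sketch of these concepts, an experiment spec ties a registered trainable to a config with variations and a stopping condition (the ``my_func`` trainable and its values below are illustrative only)::

    import random

    import ray
    from ray.tune import register_trainable, run_experiments
    from ray.tune.variant_generator import grid_search

    def my_func(config, reporter):
        # Each trial runs this function with one sampled ``config`` variation
        # and reports results back to the trial scheduler via ``reporter``.
        value = 0.0
        for i in range(100):
            value += config["alpha"] * random.random()
            reporter(timesteps_total=i, episode_reward_mean=value)

    register_trainable("my_func", my_func)

    ray.init()
    run_experiments({
        "my_experiment": {
            "run": "my_func",
            "resources": {"cpu": 1, "gpu": 0},
            "stop": {"training_iteration": 100},
            "config": {"alpha": grid_search([0.2, 0.4, 0.6])},
        }
    })  # trials are managed by the default FIFO trial scheduler
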
Getting Started
---------------
@ -133,7 +145,7 @@ To reduce costs, long-running trials can often be early stopped if their initial
An example of this can be found in `hyperband_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__. The progress of one such HyperBand run is shown below.
Note that some trial schedulers such as HyperBand require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster.
Note that some trial schedulers such as HyperBand and PBT require your Trainable to support checkpointing, which is described in the next section. Checkpointing enables the scheduler to multiplex many concurrent trials onto a limited size cluster.
::
@ -172,10 +184,19 @@ Currently we support the following early stopping algorithms, or you can write y
.. autoclass:: ray.tune.median_stopping_rule.MedianStoppingRule
.. autoclass:: ray.tune.hyperband.HyperBandScheduler
Population Based Training
-------------------------
Ray Tune includes a distributed implementation of `Population Based Training (PBT) <https://deepmind.com/blog/population-based-training-neural-networks>`__. PBT also requires your Trainable to support checkpointing. You can run this `toy PBT example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py>`__ to get an idea of how PBT operates. When training in PBT mode, the set of trial variations is treated as the population, so a single trial may see many different hyperparameters over its lifetime, which are recorded in the ``result.json`` file. The following figure generated by the example shows PBT discovering new hyperparams over the course of a single experiment:
.. image:: pbt.png
.. autoclass:: ray.tune.pbt.PopulationBasedTraining
Trial Checkpointing
-------------------
To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__: Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand.
To enable checkpoint / resume, you must subclass ``Trainable`` and implement its ``_train``, ``_save``, and ``_restore`` abstract methods `(example) <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperband_example.py>`__: Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand and PBT.
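A minimal checkpointable ``Trainable``, closely following the toy PBT example added in this change (the JSON-on-disk format is just one possible choice)::

    import json
    import os

    from ray.tune import Trainable, TrainingResult

    class MyCheckpointableClass(Trainable):
        def _setup(self):
            self.current_value = 0.0

        def _train(self):
            self.current_value += 1.0
            return TrainingResult(
                episode_reward_mean=self.current_value, timesteps_this_iter=1)

        def _save(self, checkpoint_dir):
            # Persist enough state to fully reconstruct this trainable.
            path = os.path.join(checkpoint_dir, "checkpoint")
            with open(path, "w") as f:
                json.dump({"value": self.current_value}, f)
            return path

        def _restore(self, checkpoint_path):
            with open(checkpoint_path) as f:
                self.current_value = json.load(f)["value"]
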
.. autoclass:: ray.tune.trainable.Trainable


@ -76,13 +76,13 @@ class PPOEvaluator(Evaluator):
# Value function predictions before the policy update.
self.prev_vf_preds = tf.placeholder(tf.float32, shape=(None,))
assert config["sgd_batchsize"] % len(devices) == 0, \
"Batch size must be evenly divisible by devices"
if is_remote:
self.batch_size = config["rollout_batchsize"]
self.per_device_batch_size = config["rollout_batchsize"]
else:
self.batch_size = config["sgd_batchsize"]
self.batch_size = int(
config["sgd_batchsize"] / len(devices)) * len(devices)
assert self.batch_size % len(devices) == 0
self.per_device_batch_size = int(self.batch_size / len(devices))
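# Sketch of the arithmetic above with illustrative values (not from the
# original hunk): the SGD batch is rounded down to a multiple of the device
# count instead of asserting exact divisibility, e.g. with 3 devices and
# sgd_batchsize=128:
#   batch_size = int(128 / 3) * 3        # = 126, divisible by 3
#   per_device_batch_size = 126 // 3     # = 42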
def build_loss(obs, vtargets, advs, acts, plog, pvf_preds):

python/ray/rllib/test/test_checkpoint_restore.py Executable file → Normal file (mode change only)

@ -4,6 +4,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import random
@ -49,6 +50,10 @@ class MyTrainableClass(Trainable):
register_trainable("my_class", MyTrainableClass)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
# Hyperband early stopping, configured with `episode_reward_mean` as the
@ -60,7 +65,8 @@ if __name__ == "__main__":
run_experiments({
"hyperband_test": {
"run": "my_class",
"repeat": 100,
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
"repeat": 20,
"resources": {"cpu": 1, "gpu": 0},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),


@ -0,0 +1,88 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import random
import time
import ray
from ray.tune import Trainable, TrainingResult, register_trainable, \
run_experiments
from ray.tune.pbt import PopulationBasedTraining
class MyTrainableClass(Trainable):
"""Fake agent whose learning rate is determined by dummy factors."""
def _setup(self):
self.timestep = 0
self.current_value = 0.0
def _train(self):
time.sleep(0.1)
# Reward increase is parabolic as a function of factor_1, with a
# maximum around factor_1=10.0.
self.current_value += max(
0.0, random.gauss(5.0 - (self.config["factor_1"] - 10.0)**2, 2.0))
# Flat increase by factor_2
self.current_value += random.gauss(self.config["factor_2"], 1.0)
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy (see tune/result.py).
return TrainingResult(
episode_reward_mean=self.current_value, timesteps_this_iter=1)
def _save(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps(
{"timestep": self.timestep, "value": self.current_value}))
return path
def _restore(self, checkpoint_path):
with open(checkpoint_path) as f:
data = json.loads(f.read())
self.timestep = data["timestep"]
self.current_value = data["value"]
register_trainable("my_class", MyTrainableClass)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
pbt = PopulationBasedTraining(
time_attr="training_iteration", reward_attr="episode_reward_mean",
perturbation_interval=10,
hyperparam_mutations={
# Allow for scaling-based perturbations, with a uniform backing
# distribution for resampling.
"factor_1": lambda config: random.uniform(0.0, 20.0),
# Only allows resampling from this list as a perturbation.
"factor_2": [1, 2],
})
# Try to find the best factor 1 and factor 2
run_experiments({
"pbt_test": {
"run": "my_class",
"stop": {"training_iteration": 2 if args.smoke_test else 99999},
"repeat": 10,
"resources": {"cpu": 1, "gpu": 0},
"config": {
"factor_1": 4.0,
"factor_2": 1.0,
},
}
}, scheduler=pbt, verbose=False)
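Assuming the example lives at ``python/ray/tune/examples/pbt_example.py`` (the path referenced by the docs and the CI changes below), a quick smoke-test run looks like:

    python python/ray/tune/examples/pbt_example.py --smoke-test

which stops every trial after two training iterations instead of running the full experiment.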


@ -0,0 +1,71 @@
#!/usr/bin/env python
"""Example of using PBT with RLlib.
Note that this requires a cluster with at least 8 GPUs in order for all trials
to run concurrently; otherwise PBT will round-robin train the trials, which
is less efficient (or you can set {"gpu": 0} to use CPUs for SGD instead).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import ray
from ray.tune import run_experiments
from ray.tune.pbt import PopulationBasedTraining
if __name__ == "__main__":
# Postprocess the perturbed config to ensure it's still valid
def explore(config):
# ensure we collect enough timesteps to do sgd
if config["timesteps_per_batch"] < config["sgd_batchsize"] * 2:
config["timesteps_per_batch"] = config["sgd_batchsize"] * 2
# ensure we run at least one sgd iter
if config["num_sgd_iter"] < 1:
config["num_sgd_iter"] = 1
return config
pbt = PopulationBasedTraining(
time_attr="time_total_s", reward_attr="episode_reward_mean",
perturbation_interval=120,
resample_probability=0.25,
# Specifies the resampling distributions of these hyperparams
hyperparam_mutations={
"lambda": lambda config: random.uniform(0.9, 1.0),
"clip_param": lambda config: random.uniform(0.01, 0.5),
"sgd_stepsize": lambda config: random.uniform(.00001, .001),
"num_sgd_iter": lambda config: random.randint(1, 30),
"sgd_batchsize": lambda config: random.randint(128, 16384),
"timesteps_per_batch":
lambda config: random.randint(2000, 160000),
},
custom_explore_fn=explore)
ray.init()
run_experiments({
"pbt_humanoid_test": {
"run": "PPO",
"env": "Humanoid-v1",
"repeat": 8,
"resources": {"cpu": 4, "gpu": 1},
"config": {
"kl_coeff": 1.0,
"num_workers": 8,
"devices": ["/gpu:0"],
"model": {"free_log_std": True},
# These params are tuned from their starting value
"lambda": 0.95,
"clip_param": 0.2,
# Start off with several random variations
"sgd_stepsize": lambda spec: random.uniform(.00001, .001),
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
"timesteps_per_batch":
lambda spec: random.choice([10000, 20000, 40000])
},
},
}, scheduler=pbt)


@ -205,7 +205,7 @@ def train(config={'activation': 'relu'}, reporter=None):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--fast', action='store_true', help='Finish quickly for testing')
'--smoke-test', action='store_true', help='Finish quickly for testing')
args, _ = parser.parse_known_args()
register_trainable('train_mnist', train)
@ -220,7 +220,7 @@ if __name__ == '__main__':
},
}
if args.fast:
if args.smoke_test:
mnist_spec['stop']['training_iteration'] = 2
ray.init()


@ -207,7 +207,7 @@ class HyperBandScheduler(FIFOScheduler):
"""Cleans up trial info from bracket if trial errored early."""
self.on_trial_remove(trial_runner, trial)
def choose_trial_to_run(self, trial_runner, *args):
def choose_trial_to_run(self, trial_runner):
"""Fair scheduling within iteration by completion percentage.
List of trials not used since all trials are tracked as state


@ -63,7 +63,6 @@ class UnifiedLogger(Logger):
print("TF not installed - cannot log with {}...".format(cls))
continue
self._loggers.append(cls(self.config, self.logdir, self.uri))
print("Unified logger created with logdir '{}'".format(self.logdir))
def on_result(self, result):
for logger in self._loggers:


@ -31,7 +31,7 @@ class MedianStoppingRule(FIFOScheduler):
"""
def __init__(
self, time_attr='time_total_s', reward_attr='episode_reward_mean',
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
grace_period=60.0, min_samples_required=3, hard_stop=True):
FIFOScheduler.__init__(self)
self._stopped_trials = set()


@ -2,189 +2,269 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import random
import math
import copy
from ray.tune.error import TuneError
from ray.tune.trial import Trial
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.variant_generator import _format_vars
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
# the bottom PBT_QUANTILE fraction.
PBT_QUANTILE = 0.25
class PBTTrialState(object):
"""Internal PBT state tracked per-trial."""
def __init__(self, trial):
self.orig_tag = trial.experiment_tag
self.last_score = None
self.last_checkpoint = None
self.last_perturbation_time = 0
def __repr__(self):
return str((
self.last_score, self.last_checkpoint,
self.last_perturbation_time))
def explore(config, mutations, resample_probability, custom_explore_fn):
"""Return a config perturbed as specified.
Args:
config (dict): Original hyperparameter configuration.
mutations (dict): Specification of mutations to perform as documented
in the PopulationBasedTraining scheduler.
resample_probability (float): Probability of allowing resampling of a
particular variable.
custom_explore_fn (func): Custom explore function applied after built-in
config perturbations are applied.
"""
new_config = copy.deepcopy(config)
for key, distribution in mutations.items():
if isinstance(distribution, list):
if random.random() < resample_probability:
new_config[key] = random.choice(distribution)
else:
if random.random() < resample_probability:
new_config[key] = distribution(config)
elif random.random() > 0.5:
new_config[key] = config[key] * 1.2
else:
new_config[key] = config[key] * 0.8
if type(config[key]) is int:
new_config[key] = int(new_config[key])
if custom_explore_fn:
new_config = custom_explore_fn(new_config)
assert new_config is not None, \
"Custom explore fn failed to return new config"
print(
"[explore] perturbed config from {} -> {}".format(config, new_config))
return new_config
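# Usage sketch of explore() with illustrative values: with
# resample_probability=0.0, continuous keys are scaled by 1.2 or 0.8,
# list-valued keys are left as-is, and keys absent from `mutations` are
# copied through untouched:
#
#     mutations = {
#         "factor_1": lambda config: random.uniform(0.0, 20.0),
#         "factor_2": [1, 2],
#     }
#     explore({"factor_1": 4.0, "factor_2": 1, "other": 3},
#             mutations, 0.0, None)
#     # -> {"factor_1": 4.8 or 3.2, "factor_2": 1, "other": 3}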
def make_experiment_tag(orig_tag, config, mutations):
"""Appends perturbed params to the trial name to show in the console."""
resolved_vars = {}
for k in mutations.keys():
resolved_vars[("config", k)] = config[k]
return "{}@perturbed[{}]".format(orig_tag, _format_vars(resolved_vars))
class PopulationBasedTraining(FIFOScheduler):
"""Implements the Population Based Training algorithm as described in the
PBT paper (https://arxiv.org/abs/1711.09846)(Experimental):
"""Implements the Population Based Training (PBT) algorithm.
https://deepmind.com/blog/population-based-training-neural-networks
PBT trains a group of models (or agents) in parallel. Periodically, poorly
performing models clone the state of the top performers, and a random
mutation is applied to their hyperparameters in the hopes of
outperforming the current top models.
Unlike other hyperparameter search algorithms, PBT mutates hyperparameters
during training time. This enables very fast hyperparameter discovery and
also automatically discovers good annealing schedules.
This Ray Tune PBT implementation considers all trials added as part of the
PBT population. If the number of trials exceeds the cluster capacity,
they will be time-multiplexed so as to balance training progress across the
population.
Args:
time_attr (str): The TrainingResult attr to use for documenting length
of time since last ready() call. Attribute only has to increase
monotonically.
time_attr (str): The TrainingResult attr to use for comparing time.
Note that you can pass in something non-temporal such as
`training_iteration` as a measure of progress; the only requirement
is that the attribute should increase monotonically.
reward_attr (str): The TrainingResult objective value attribute. As
with 'time_attr'. this may refer to any objective value that
is supposed to increase with time.
grace_period (float): Period of time, in which algorithm will not
compare model to other models.
perturbation_interval (float): Used in the truncation ready function to
determine if enough time has passed so that a agent can be tested
for readiness.
hyperparameter_mutations (dict); Possible values that each
hyperparameter can mutate to, as certain hyperparameters
only work with certain values.
with `time_attr`, this may refer to any objective value. Stopping
procedures will use this attribute.
perturbation_interval (float): Models will be considered for
perturbation at this interval of `time_attr`. Note that
perturbation incurs checkpoint overhead, so you shouldn't set this
to be too frequent.
hyperparam_mutations (dict): Hyperparams to mutate. The format is
as follows: for each key, either a list or function can be
provided. A list specifies values for a discrete parameter.
A function specifies the distribution of a continuous parameter.
You must specify at least one of `hyperparam_mutations` or
`custom_explore_fn`.
resample_probability (float): The probability of resampling from the
original distribution when applying `hyperparam_mutations`. If not
resampled, the value will be perturbed by a factor of 1.2 or 0.8
if continuous, or left unchanged if discrete.
custom_explore_fn (func): You can also specify a custom exploration
function. This function is invoked as `f(config)` after built-in
perturbations from `hyperparam_mutations` are applied, and should
return `config` updated as needed. You must specify at least one of
`hyperparam_mutations` or `custom_explore_fn`.
Example:
>>> pbt = PopulationBasedTraining(
>>> time_attr="training_iteration",
>>> reward_attr="episode_reward_mean",
>>> perturbation_interval=10, # every 10 `time_attr` units
>>> # (training_iterations in this case)
>>> hyperparam_mutations={
>>> # Allow for scaling-based perturbations, with a uniform
>>> # backing distribution for resampling.
>>> "factor_1": lambda config: random.uniform(0.0, 20.0),
>>> # Only allows resampling from this list as a perturbation.
>>> "factor_2": [1, 2],
>>> })
>>> run_experiments({...}, scheduler=pbt)
"""
def __init__(
self, time_attr='training_iteration',
reward_attr='episode_reward_mean',
grace_period=10.0, perturbation_interval=6.0,
hyperparameter_mutations=None):
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
perturbation_interval=60.0, hyperparam_mutations={},
resample_probability=0.25, custom_explore_fn=None):
if not hyperparam_mutations and not custom_explore_fn:
raise TuneError(
"You must specify at least one of `hyperparam_mutations` or "
"`custom_explore_fn` to use PBT.")
FIFOScheduler.__init__(self)
self._completed_trials = set()
self._results = collections.defaultdict(list)
self._last_perturbation_time = {}
self._grace_period = grace_period
self._reward_attr = reward_attr
self._time_attr = time_attr
self._hyperparameter_mutations = hyperparameter_mutations
self._perturbation_interval = perturbation_interval
self._checkpoint_paths = {}
self._hyperparam_mutations = hyperparam_mutations
self._resample_probability = resample_probability
self._trial_state = {}
self._custom_explore_fn = custom_explore_fn
# Metrics
self._num_checkpoints = 0
self._num_perturbations = 0
def on_trial_add(self, trial_runner, trial):
self._trial_state[trial] = PBTTrialState(trial)
def on_trial_result(self, trial_runner, trial, result):
self._results[trial].append(result)
time = getattr(result, self._time_attr)
# check model is ready to undergo mutation, based on user
# function or default function
self._checkpoint_paths[trial] = trial.checkpoint()
if time > self._grace_period:
ready = self._truncation_ready(result, trial, time)
else:
ready = False
if ready:
print("ready to undergo mutation")
print("----")
print("Current Trial is: {0}".format(trial))
# get best trial for current time
best_trial = self._get_best_trial(result, time)
print("Best Trial is: {0}".format(best_trial))
print(best_trial.config)
state = self._trial_state[trial]
if time - state.last_perturbation_time < self._perturbation_interval:
return TrialScheduler.CONTINUE # avoid checkpoint overhead
score = getattr(result, self._reward_attr)
state.last_score = score
state.last_perturbation_time = time
lower_quantile, upper_quantile = self._quantiles()
if trial in upper_quantile:
state.last_checkpoint = trial.checkpoint(to_object_store=True)
self._num_checkpoints += 1
else:
state.last_checkpoint = None # not a top trial
if trial in lower_quantile:
trial_to_clone = random.choice(upper_quantile)
assert trial is not trial_to_clone
self._exploit(trial, trial_to_clone)
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED]:
return TrialScheduler.PAUSE # yield time to other trials
# if current trial is the best trial (as in same hyperparameters),
# do nothing
if trial.config == best_trial.config:
print("current trial is best trial")
return TrialScheduler.CONTINUE
else:
self._exploit(self._hyperparameter_mutations, best_trial,
trial, trial_runner, time)
return TrialScheduler.CONTINUE
return TrialScheduler.CONTINUE
def on_trial_complete(self, trial_runner, trial, result):
self._results[trial].append(result)
self._completed_trials.add(trial)
def _exploit(self, trial, trial_to_clone):
"""Transfers perturbed state from trial_to_clone -> trial."""
def _exploit(self, hyperparameter_mutations, best_trial,
trial, trial_runner, time):
trial.stop()
mutate_string = "_mutated@" + str(time)
hyperparams = copy.deepcopy(best_trial.config)
hyperparams = self._explore(hyperparams, hyperparameter_mutations,
best_trial)
print("new hyperparameter configuration: {0}".format(hyperparams))
checkpoint = self._checkpoint_paths[best_trial]
trial._checkpoint_path = checkpoint
trial.config = hyperparams
trial.experiment_tag = trial.experiment_tag + mutate_string
trial.start()
trial_state = self._trial_state[trial]
new_state = self._trial_state[trial_to_clone]
if not new_state.last_checkpoint:
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
return
new_config = explore(
trial_to_clone.config, self._hyperparam_mutations,
self._resample_probability, self._custom_explore_fn)
print(
"[exploit] transferring weights from trial "
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
# TODO(ekl) restarting the trial is expensive. We should implement a
# lighter-weight reset() method that can alter the trial config.
trial.stop(stop_logger=False)
trial.config = new_config
trial.experiment_tag = make_experiment_tag(
trial_state.orig_tag, new_config, self._hyperparam_mutations)
trial.start(new_state.last_checkpoint)
self._num_perturbations += 1
# Transfer over the last perturbation time as well
trial_state.last_perturbation_time = new_state.last_perturbation_time
def _explore(self, hyperparams, hyperparameter_mutations, best_trial):
if hyperparameter_mutations is not None:
hyperparams = {
param: random.choice(hyperparameter_mutations[param])
for param in hyperparams
if param != "env" and param in hyperparameter_mutations
}
for param in best_trial.config:
if param not in hyperparameter_mutations and param != "env":
hyperparams[param] = math.ceil(
(best_trial.config[param]
* random.choice([0.8, 1.2])/2.)) * 2
def _quantiles(self):
"""Returns trials in the lower and upper `quantile` of the population.
If there is not enough data to compute this, returns empty lists."""
trials = []
for trial, state in self._trial_state.items():
if state.last_score is not None and not trial.is_finished():
trials.append(trial)
trials.sort(key=lambda t: self._trial_state[t].last_score)
if len(trials) <= 1:
return [], []
else:
hyperparams = {
param: math.ceil(
(random.choice([0.8, 1.2]) *
hyperparams[param])/2.) * 2
for param in hyperparams
if param != "env"
}
hyperparams["env"] = best_trial.config["env"]
return hyperparams
return (
trials[:int(math.ceil(len(trials)*PBT_QUANTILE))],
trials[int(math.floor(-len(trials)*PBT_QUANTILE)):])
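# Worked example of the quantile split above: with the default
# PBT_QUANTILE = 0.25 and five running trials sorted by last_score,
#
#     scores = [0, 50, 100, 150, 200]
#     lower  = trials[:int(math.ceil(5 * 0.25))]    # bottom 2 (scores 0, 50)
#     upper  = trials[int(math.floor(-5 * 0.25)):]  # top 2 (scores 150, 200)
#
# The lower quantile is perturbed and the upper quantile is checkpointed and
# cloned, as exercised by the updated trial_scheduler_test below.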
def _truncation_ready(self, result, trial, time):
# function checks if appropriate time has passed
# and trial is in the bottom 20% of all trials, and if so, is ready
if trial not in self._last_perturbation_time:
print("added trial to time tracker")
self._last_perturbation_time[trial] = (time)
else:
time_since_last = time - self._last_perturbation_time[trial]
if time_since_last >= self._perturbation_interval:
self._last_perturbation_time[trial] = time
sorted_result_keys = sorted(
self._results, key=lambda x:
max(self._results.get(x) if self._results.get(x) else [0])
)
max_index = int(round(len(sorted_result_keys) * 0.2))
for i in range(0, max_index):
if trial == sorted_result_keys[i]:
print("{0} is in the bottomn 20 percent of {1}, \
truncation is ready".format(
trial,
[x.experiment_tag for x in sorted_result_keys]
))
return True
print("{0} is not in the bottomn 20 percent of {1}, \
truncation is not ready".format(
trial,
[x.experiment_tag for x in sorted_result_keys]
))
else:
print("not enough time has passed since last mutation")
return False
def choose_trial_to_run(self, trial_runner):
"""Ensures all trials get fair share of time (as defined by time_attr).
def _get_best_trial(self, result, time):
results_at_time = {}
for trial in self._results:
results_at_time[trial] = [
getattr(r, self._reward_attr)
for r in self._results[trial]
if getattr(r, self._time_attr) <= time
]
print("Results at {0}: {1}".format(time, results_at_time))
return max(results_at_time, key=lambda x:
max(results_at_time.get(x)
if results_at_time.get(x) else [0]))
This enables the PBT scheduler to support a greater number of
concurrent trials than can fit in the cluster at any given time.
"""
def _is_empty(self, x):
if x:
return False
return True
candidates = []
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
trial_runner.has_resources(trial.resources):
candidates.append(trial)
candidates.sort(
key=lambda trial: self._trial_state[trial].last_perturbation_time)
return candidates[0] if candidates else None
def reset_stats(self):
self._num_perturbations = 0
self._num_checkpoints = 0
def last_scores(self, trials):
scores = []
for trial in trials:
state = self._trial_state[trial]
if state.last_score is not None and not trial.is_finished():
scores.append(state.last_score)
return scores
def debug_string(self):
min_time = 0
best_trial = None
for trial in self._completed_trials:
last_result = self._results[trial][-1]
if (getattr(last_result, self._time_attr)
< min_time or min_time == 0):
min_time = getattr(last_result, self._time_attr)
best_trial = trial
if best_trial is not None:
return ("The Best Trial is currently {0} finishing in {1} iterations, \
with the hyperparameters of {2}".format(
best_trial, min_time, best_trial.config
)
)
else:
return "PBT has started"
return "PopulationBasedTraining: {} checkpoints, {} perturbs".format(
self._num_checkpoints, self._num_perturbations)


@ -84,10 +84,14 @@ TrainingResult = namedtuple("TrainingResult", [
# (Auto-filled) The hostname of the machine hosting the training process.
"hostname",
# (Auto-filled) The current hyperparameter configuration.
"config",
])
def pretty_print(result):
result = result._replace(config=None) # drop config from pretty print
out = {}
for k, v in result._asdict().items():
if v is not None:


@ -13,7 +13,7 @@ from ray.tune import Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import Trial, Resources, MAX_LEN_IDENTIFIER
from ray.tune.trial import Trial, Resources
from ray.tune.trial_runner import TrialRunner
from ray.tune.variant_generator import generate_trials, grid_search, \
RecursiveDependencyError
@ -364,8 +364,9 @@ class TrialRunnerTest(unittest.TestCase):
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
self.assertLessEqual(
len(str(trial)), MAX_LEN_IDENTIFIER)
trial.start()
self.assertLessEqual(len(trial.logdir), 200)
trial.stop()
def testTrialErrorOnStart(self):
ray.init()


@ -5,15 +5,17 @@ from __future__ import print_function
import unittest
import numpy as np
import random
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.pbt import PopulationBasedTraining
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.result import TrainingResult
from ray.tune.trial import Trial
from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import TrialScheduler
from ray.rllib import _register_all
_register_all()
def result(t, rew):
return TrainingResult(time_total_s=t,
@ -145,6 +147,7 @@ class EarlyStoppingSuite(unittest.TestCase):
class _MockTrialRunner():
def __init__(self, scheduler):
self._scheduler_alg = scheduler
self.trials = []
def process_action(self, trial, action):
if action == TrialScheduler.CONTINUE:
@ -163,6 +166,13 @@ class _MockTrialRunner():
self._scheduler_alg.on_trial_complete(self, trial, result(100, 10))
def add_trial(self, trial):
self.trials.append(trial)
self._scheduler_alg.on_trial_add(self, trial)
def get_trials(self):
return self.trials
def has_resources(self, resources):
return True
@ -511,109 +521,202 @@ class HyperbandSuite(unittest.TestCase):
self.assertFalse(trial in bracket._live_trials)
class _MockTrialRunnerPBT(_MockTrialRunner):
def __init__(self):
self._trials = []
def _launch_trial(self, trial):
trial.status = Trial.RUNNING
self._trials.append(trial)
class _MockTrialPBT(Trial):
class _MockTrial(Trial):
def __init__(self, i, config):
self.trainable_name = "trial_{}".format(i)
self.config = config
self.experiment_tag = "tag"
self.logger_running = False
self.restored_checkpoint = None
self.resources = Resources(1, 0)
def checkpoint(self, to_object_store=False):
return 'checkpointed'
return self.trainable_name
def start(self):
return 'started'
def start(self, checkpoint=None):
self.logger_running = True
self.restored_checkpoint = checkpoint
def stop(self):
return 'stopped'
def stop(self, stop_logger=False):
if stop_logger:
self.logger_running = False
class PopulationBasedTestingSuite(unittest.TestCase):
def schedulerSetup(self, num_trials):
sched = PopulationBasedTraining()
runner = _MockTrialRunnerPBT()
for i in range(num_trials):
t = _MockTrialPBT("__parameter_tuning")
t.config = {'test': 1, 'test1': 1, 'env': 'test'}
t.experiment_tag = str(i)
runner._launch_trial(t)
return sched, runner
def basicSetup(self, resample_prob=0.0, explore=None):
pbt = PopulationBasedTraining(
time_attr="training_iteration",
perturbation_interval=10,
resample_probability=resample_prob,
hyperparam_mutations={
"id_factor": [100],
"float_factor": lambda c: 100.0,
"int_factor": lambda c: 10,
},
custom_explore_fn=explore)
runner = _MockTrialRunner(pbt)
for i in range(5):
trial = _MockTrial(
i,
{"id_factor": i, "float_factor": 2.0, "const_factor": 3,
"int_factor": 10})
runner.add_trial(trial)
trial.status = Trial.RUNNING
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.CONTINUE)
pbt.reset_stats()
return pbt, runner
def testReadyFunction(self):
sched, runner = self.schedulerSetup(5)
# different time intervals to test at
best_result_early = result(18, 100)
best_result_late = result(25, 100)
runner._trials[0].config = {'test': 10, 'test1': 10, 'env': 'test'}
# setting up best trial so that it consistently is the best trial
sched.on_trial_result(runner, runner._trials[0], result(11, 0))
sched.on_trial_result(runner, runner._trials[0], result(14, 2))
sched.on_trial_result(runner, runner._trials[0], best_result_early)
sched.on_trial_result(runner, runner._trials[0], best_result_late)
# testing that adding trials to time tracker works, and that
# ready function knows when to start
for trial in runner._trials[1:]:
old_config = trial.config
sched.on_trial_result(
runner, trial, result(11, random.randint(0, 10)))
self.assertTrue(old_config == trial.config)
# making sure that the second trial in runner._trials
# (not the best trial) is the worst trial
for trial in runner._trials[2:]:
# testing to see that ready function knows
# that not enough time has passed
sched.on_trial_result(
runner, trial, result(16, random.randint(40, 50)))
# testing to see if worst trial (aka bottom 20%)
# has mutated (ready function initiated)
old_config = runner._trials[1].config
sched.on_trial_result(runner, runner._trials[1], result(26, 30))
self.assertFalse(old_config == runner._trials[1].config)
def testCheckpointsMostPromisingTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
def testExploitExploreFunction(self):
sched, runner = self.schedulerSetup(5)
# different time intervals to test at
best_result_early = result(18, 100)
best_result_late = result(25, 100)
runner._trials[0].config = {'test': 10, 'test1': 10, 'env': 'test'}
# setting up best trial so that it consistently is the best trial
sched.on_trial_result(runner, runner._trials[0], best_result_early)
sched.on_trial_result(runner, runner._trials[0], best_result_late)
# testing that adding trials to time tracker works, and
# that ready function knows when to start
for trial in runner._trials[1:]:
sched.on_trial_result(
runner, trial, result(11, random.randint(0, 10)))
# making sure that the second trial in runner._trials
# (not the best trial) is the worst trial
for trial in runner._trials[2:]:
sched.on_trial_result(
runner, trial, result(16, random.randint(40, 50)))
sched.on_trial_result(runner, runner._trials[1], result(26, 30))
# make sure mutated values are multiples of 0.8 and 1.2
# (default explore values)
for key in runner._trials[0].config:
if key == 'env':
continue
else:
if (
runner._trials[1].config[key] == 0.8 *
runner._trials[0].config[key] or
runner._trials[1].config[key] == 1.2 *
runner._trials[0].config[key]
):
continue
else:
raise ValueError('Trial not correctly explored (mutated)')
# no checkpoint: haven't hit next perturbation interval yet
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 0)
# checkpoint: both past interval and upper quantile
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [200, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 1)
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(30, 201)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [200, 201, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 2)
# not upper quantile any more
self.assertEqual(
pbt.on_trial_result(runner, trials[4], result(30, 199)),
TrialScheduler.CONTINUE)
self.assertEqual(pbt._num_checkpoints, 2)
self.assertEqual(pbt._num_perturbations, 0)
def testPerturbsLowPerformingTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
# no perturbation: haven't hit next perturbation interval
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertTrue("@perturbed" not in trials[0].experiment_tag)
self.assertEqual(pbt._num_perturbations, 0)
# perturb since it's lower quantile
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [-100, 50, 100, 150, 200])
self.assertTrue("@perturbed" in trials[0].experiment_tag)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(pbt._num_perturbations, 1)
# also perturbed
self.assertEqual(
pbt.on_trial_result(runner, trials[2], result(20, 40)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [-100, 50, 40, 150, 200])
self.assertEqual(pbt._num_perturbations, 2)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertTrue("@perturbed" in trials[2].experiment_tag)
def testPerturbWithoutResample(self):
pbt, runner = self.basicSetup(resample_prob=0.0)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertIn(trials[0].config["id_factor"], [3, 4])
self.assertIn(trials[0].config["float_factor"], [2.4, 1.6])
self.assertEqual(type(trials[0].config["float_factor"]), float)
self.assertIn(trials[0].config["int_factor"], [8, 12])
self.assertEqual(type(trials[0].config["int_factor"]), int)
self.assertEqual(trials[0].config["const_factor"], 3)
def testPerturbWithResample(self):
pbt, runner = self.basicSetup(resample_prob=1.0)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(trials[0].config["id_factor"], 100)
self.assertEqual(trials[0].config["float_factor"], 100.0)
self.assertEqual(type(trials[0].config["float_factor"]), float)
self.assertEqual(trials[0].config["int_factor"], 10)
self.assertEqual(type(trials[0].config["int_factor"]), int)
self.assertEqual(trials[0].config["const_factor"], 3)
def testYieldsTimeToOtherTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
trials[0].status = Trial.PENDING # simulate not enough resources
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(20, 1000)),
TrialScheduler.PAUSE)
self.assertEqual(
pbt.last_scores(trials), [0, 1000, 100, 150, 200])
self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])
def testSchedulesMostBehindTrialToRun(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
pbt.on_trial_result(runner, trials[0], result(800, 1000))
pbt.on_trial_result(runner, trials[1], result(700, 1001))
pbt.on_trial_result(runner, trials[2], result(600, 1002))
pbt.on_trial_result(runner, trials[3], result(500, 1003))
pbt.on_trial_result(runner, trials[4], result(700, 1004))
self.assertEqual(pbt.choose_trial_to_run(runner), None)
for i in range(5):
trials[i].status = Trial.PENDING
self.assertEqual(pbt.choose_trial_to_run(runner), trials[3])
def testPerturbationResetsLastPerturbTime(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
pbt.on_trial_result(runner, trials[0], result(10000, 1005))
pbt.on_trial_result(runner, trials[1], result(10000, 1004))
pbt.on_trial_result(runner, trials[2], result(600, 1003))
self.assertEqual(pbt._num_perturbations, 0)
pbt.on_trial_result(runner, trials[3], result(500, 1002))
self.assertEqual(pbt._num_perturbations, 1)
pbt.on_trial_result(runner, trials[3], result(600, 100))
self.assertEqual(pbt._num_perturbations, 1)
pbt.on_trial_result(runner, trials[3], result(11000, 100))
self.assertEqual(pbt._num_perturbations, 2)
def testPostprocessingHook(self):
def explore(new_config):
new_config["id_factor"] = 42
new_config["float_factor"] = 43
return new_config
pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
trials = runner.get_trials()
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(trials[0].config["id_factor"], 42)
self.assertEqual(trials[0].config["float_factor"], 43)
if __name__ == "__main__":
from ray.rllib import _register_all
_register_all()
unittest.main(verbosity=2)


@ -135,7 +135,8 @@ class Trainable(object):
time_total_s=self._time_total,
neg_mean_loss=neg_loss,
pid=os.getpid(),
hostname=os.uname()[1])
hostname=os.uname()[1],
config=self.config)
self._result_logger.on_result(result)
@ -185,8 +186,8 @@ class Trainable(object):
"checkpoint_name": os.path.basename(checkpoint_prefix),
"data": data,
})
print("Saving checkpoint to object store, {} bytes".format(
len(compressed)))
if len(compressed) > 10e6: # getting pretty large
print("Checkpoint size is {} bytes".format(len(compressed)))
f.write(compressed)
shutil.rmtree(tmpdir)


@ -96,7 +96,7 @@ class Trial(object):
"Stopping condition key `{}` must be one of {}".format(
k, TrainingResult._fields))
# Immutable config
# Trial config
self.trainable_name = trainable_name
self.config = config or {}
self.local_dir = local_dir
@ -105,6 +105,7 @@ class Trial(object):
self.stopping_criterion = stopping_criterion or {}
self.checkpoint_freq = checkpoint_freq
self.upload_dir = upload_dir
self.verbose = True
# Local trial state that is updated during the run
self.last_result = None
@ -117,16 +118,22 @@ class Trial(object):
self.result_logger = None
self.last_debug = 0
self.trial_id = binary_to_hex(random_string())[:8]
self.error_file = None
def start(self):
def start(self, checkpoint_obj=None):
"""Starts this trial.
If an error is encountered when starting the trial, an exception will
be thrown.
Args:
checkpoint_obj (obj): Optional checkpoint to resume from.
"""
self._setup_runner()
if self._checkpoint_path:
if checkpoint_obj:
self.restore_from_obj(checkpoint_obj)
elif self._checkpoint_path:
self.restore_from_path(self._checkpoint_path)
elif self._checkpoint_obj:
self.restore_from_obj(self._checkpoint_obj)
@ -155,6 +162,7 @@ class Trial(object):
self.logdir, "error_{}.txt".format(date_str()))
with open(error_file, "w") as f:
f.write(error_msg)
self.error_file = error_file
if self.runner:
stop_tasks = []
stop_tasks.append(self.runner.stop.remote())
@ -163,9 +171,6 @@ class Trial(object):
# TODO(ekl) seems like wait hangs when killing actors
_, unfinished = ray.wait(
stop_tasks, num_returns=2, timeout=250)
if unfinished:
print(("Stopping %s Actor timed out, "
"but moving on...") % self)
except Exception:
print("Error stopping runner:", traceback.format_exc())
self.status = Trial.ERROR
@ -230,7 +235,7 @@ class Trial(object):
"""Returns a progress message for printing out to the console."""
if self.last_result is None:
return self.status
return self._status_string()
def location_string(hostname, pid):
if hostname == os.uname()[1]:
@ -240,7 +245,8 @@ class Trial(object):
pieces = [
'{} [{}]'.format(
self.status, location_string(
self._status_string(),
location_string(
self.last_result.hostname, self.last_result.pid)),
'{} s'.format(int(self.last_result.time_total_s)),
'{} ts'.format(int(self.last_result.timesteps_total))]
@ -259,6 +265,11 @@ class Trial(object):
return ', '.join(pieces)
def _status_string(self):
return "{}{}".format(
self.status,
" => {}".format(self.error_file) if self.error_file else "")
def checkpoint(self, to_object_store=False):
"""Checkpoints the state of this trial.
@ -276,7 +287,8 @@ class Trial(object):
self._checkpoint_path = path
self._checkpoint_obj = obj
print("Saved checkpoint to:", path or obj)
if self.verbose:
print("Saved checkpoint for {} to {}".format(self, path or obj))
return path or obj
def restore_from_path(self, path):
@ -310,7 +322,9 @@ class Trial(object):
def update_last_result(self, result, terminate=False):
if terminate:
result = result._replace(done=True)
if terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL:
if terminate or (
self.verbose and
time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
print("TrainingResult for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
@ -348,12 +362,17 @@ class Trial(object):
config=self.config, registry=get_registry(),
logger_creator=logger_creator)
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``.
def set_verbose(self, verbose):
self.verbose = verbose
Truncates to MAX_LEN_IDENTIFIER (default is 130) to avoid problems
when creating logging directories.
"""
def is_finished(self):
return self.status in [Trial.TERMINATED, Trial.ERROR]
def __repr__(self):
return str(self)
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
if "env" in self.config:
identifier = "{}_{}".format(
self.trainable_name, self.config["env"])
@ -361,4 +380,4 @@ class Trial(object):
identifier = self.trainable_name
if self.experiment_tag:
identifier += "_" + self.experiment_tag
return identifier[:MAX_LEN_IDENTIFIER]
return identifier


@ -31,7 +31,7 @@ def _make_scheduler(args):
def run_experiments(experiments, scheduler=None, with_server=False,
server_port=TuneServer.DEFAULT_PORT):
server_port=TuneServer.DEFAULT_PORT, verbose=True):
# Make sure rllib agents are registered
from ray import rllib # noqa # pylint: disable=unused-import
@ -44,6 +44,7 @@ def run_experiments(experiments, scheduler=None, with_server=False,
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
trial.set_verbose(verbose)
runner.add_trial(trial)
print(runner.debug_string(max_debug=99999))


@ -164,7 +164,15 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/tune/examples/tune_mnist_ray.py \
--fast
--smoke-test
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/tune/examples/pbt_example.py \
--smoke-test
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/tune/examples/hyperband_example.py \
--smoke-test
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/examples/multiagent_mountaincar.py