[tune] Added Population Based Training (#1355)

Adds a Population Based Training scheduler (as described in https://arxiv.org/abs/1711.09846) to Ray Tune. It currently mutates hyperparameters either by sampling from a user-defined list of allowed values (necessary when a hyperparameter can only take certain values, e.g. sgd_batch_size), or by scaling the current value by a factor of 0.8 or 1.2.
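For reference, a minimal usage sketch. This is illustrative only: the experiment spec, the PPO/CartPole choice, and the run_experiments call are not part of this diff; the constructor arguments are the ones defined in pbt.py below.

    import ray
    from ray.tune import run_experiments
    from ray.tune.pbt import PopulationBasedTraining

    # Restrict mutations for parameters that only accept certain values
    # (e.g. sgd_batch_size); parameters not listed here fall back to the
    # default 0.8x / 1.2x perturbation.
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        reward_attr="episode_reward_mean",
        grace_period=10.0,
        perturbation_interval=6.0,
        hyperparameter_mutations={"sgd_batch_size": [128, 256, 512]})

    ray.init()
    run_experiments(
        {"pbt_test": {
            "run": "PPO",
            "env": "CartPole-v0",
            "config": {"sgd_batch_size": 128},
        }},
        scheduler=pbt)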
Kaahan 2018-01-25 21:38:37 -08:00 committed by Eric Liang
parent e5c4d9ea0c
commit 7aa979a024
5 changed files with 320 additions and 4 deletions


@@ -8,8 +8,8 @@ from ray.tune.registry import register_trainable


 def _register_all():
-    for key in [
-            "PPO", "ES", "DQN", "A3C", "BC", "__fake", "__sigmoid_fake_data"]:
+    for key in ["PPO", "ES", "DQN", "A3C", "BC", "__fake",
+                "__sigmoid_fake_data", "__parameter_tuning"]:
         try:
             from ray.rllib.agent import get_agent_class
             register_trainable(key, get_agent_class(key))


@@ -190,6 +190,25 @@ class _SigmoidFakeData(_MockAgent):
             time_this_iter_s=self.config["iter_time"], info={})


+class _ParameterTuningAgent(_MockAgent):
+
+    _agent_name = "ParameterTuningAgent"
+    _default_config = {
+        "reward_amt": 10,
+        "dummy_param": 10,
+        "dummy_param2": 15,
+        "iter_time": 10,
+        "iter_timesteps": 1
+    }
+
+    def _train(self):
+        return TrainingResult(
+            episode_reward_mean=self.config["reward_amt"] * self.iteration,
+            episode_len_mean=self.config["reward_amt"],
+            timesteps_this_iter=self.config["iter_timesteps"],
+            time_this_iter_s=self.config["iter_time"], info={})
+
+
 def get_agent_class(alg):
     """Returns the class of a known agent given its name."""

@@ -215,6 +234,8 @@ def get_agent_class(alg):
         return _MockAgent
     elif alg == "__sigmoid_fake_data":
         return _SigmoidFakeData
+    elif alg == "__parameter_tuning":
+        return _ParameterTuningAgent
     else:
         raise Exception(
             ("Unknown algorithm {}.").format(alg))

python/ray/tune/pbt.py Normal file

@@ -0,0 +1,190 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import random
import math
import copy

from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
class PopulationBasedTraining(FIFOScheduler):
    """Implements the Population Based Training algorithm as described in
    the PBT paper (https://arxiv.org/abs/1711.09846). (Experimental)

    Args:
        time_attr (str): The TrainingResult attr to use for documenting the
            length of time since the last ready() call. The attribute only
            has to increase monotonically.
        reward_attr (str): The TrainingResult objective value attribute. As
            with `time_attr`, this may refer to any objective value that is
            supposed to increase with time.
        grace_period (float): Period of time during which the algorithm will
            not compare the model to other models.
        perturbation_interval (float): Used in the truncation ready function
            to determine whether enough time has passed for an agent to be
            tested for readiness.
        hyperparameter_mutations (dict): Possible values that each
            hyperparameter can mutate to, as certain hyperparameters only
            work with certain values.
    """
    def __init__(
            self, time_attr='training_iteration',
            reward_attr='episode_reward_mean',
            grace_period=10.0, perturbation_interval=6.0,
            hyperparameter_mutations=None):
        FIFOScheduler.__init__(self)
        self._completed_trials = set()
        self._results = collections.defaultdict(list)
        self._last_perturbation_time = {}
        self._grace_period = grace_period
        self._reward_attr = reward_attr
        self._time_attr = time_attr
        self._hyperparameter_mutations = hyperparameter_mutations
        self._perturbation_interval = perturbation_interval
        self._checkpoint_paths = {}
    def on_trial_result(self, trial_runner, trial, result):
        self._results[trial].append(result)
        time = getattr(result, self._time_attr)
        # Check whether the model is ready to undergo mutation, based on
        # the user-defined or default ready function.
        self._checkpoint_paths[trial] = trial.checkpoint()
        if time > self._grace_period:
            ready = self._truncation_ready(result, trial, time)
        else:
            ready = False
        if ready:
            print("ready to undergo mutation")
            print("----")
            print("Current Trial is: {0}".format(trial))
            # Get the best trial for the current time.
            best_trial = self._get_best_trial(result, time)
            print("Best Trial is: {0}".format(best_trial))
            print(best_trial.config)
            # If the current trial is the best trial (i.e. it has the same
            # hyperparameters), do nothing.
            if trial.config == best_trial.config:
                print("current trial is best trial")
                return TrialScheduler.CONTINUE
            else:
                self._exploit(self._hyperparameter_mutations, best_trial,
                              trial, trial_runner, time)
                return TrialScheduler.CONTINUE
        return TrialScheduler.CONTINUE
    def on_trial_complete(self, trial_runner, trial, result):
        self._results[trial].append(result)
        self._completed_trials.add(trial)

    def _exploit(self, hyperparameter_mutations, best_trial,
                 trial, trial_runner, time):
        trial.stop()
        mutate_string = "_mutated@" + str(time)
        hyperparams = copy.deepcopy(best_trial.config)
        hyperparams = self._explore(hyperparams, hyperparameter_mutations,
                                    best_trial)
        print("new hyperparameter configuration: {0}".format(hyperparams))
        checkpoint = self._checkpoint_paths[best_trial]
        trial._checkpoint_path = checkpoint
        trial.config = hyperparams
        trial.experiment_tag = trial.experiment_tag + mutate_string
        trial.start()
    def _explore(self, hyperparams, hyperparameter_mutations, best_trial):
        if hyperparameter_mutations is not None:
            hyperparams = {
                param: random.choice(hyperparameter_mutations[param])
                for param in hyperparams
                if param != "env" and param in hyperparameter_mutations
            }
            for param in best_trial.config:
                if param not in hyperparameter_mutations and param != "env":
                    hyperparams[param] = math.ceil(
                        (best_trial.config[param]
                         * random.choice([0.8, 1.2])) / 2.) * 2
        else:
            hyperparams = {
                param: math.ceil(
                    (random.choice([0.8, 1.2]) *
                     hyperparams[param]) / 2.) * 2
                for param in hyperparams
                if param != "env"
            }
        hyperparams["env"] = best_trial.config["env"]
        return hyperparams
    def _truncation_ready(self, result, trial, time):
        # Checks whether enough time has passed since the last perturbation
        # and whether the trial is in the bottom 20% of all trials; if both
        # hold, the trial is ready.
        if trial not in self._last_perturbation_time:
            print("added trial to time tracker")
            self._last_perturbation_time[trial] = time
        else:
            time_since_last = time - self._last_perturbation_time[trial]
            if time_since_last >= self._perturbation_interval:
                self._last_perturbation_time[trial] = time
                sorted_result_keys = sorted(
                    self._results, key=lambda x:
                    max(self._results.get(x) if self._results.get(x)
                        else [0]))
                max_index = int(round(len(sorted_result_keys) * 0.2))
                for i in range(0, max_index):
                    if trial == sorted_result_keys[i]:
                        print("{0} is in the bottom 20 percent of {1}, "
                              "truncation is ready".format(
                                  trial,
                                  [x.experiment_tag
                                   for x in sorted_result_keys]))
                        return True
                print("{0} is not in the bottom 20 percent of {1}, "
                      "truncation is not ready".format(
                          trial,
                          [x.experiment_tag for x in sorted_result_keys]))
            else:
                print("not enough time has passed since last mutation")
        return False
    def _get_best_trial(self, result, time):
        results_at_time = {}
        for trial in self._results:
            results_at_time[trial] = [
                getattr(r, self._reward_attr)
                for r in self._results[trial]
                if getattr(r, self._time_attr) <= time
            ]
        print("Results at {0}: {1}".format(time, results_at_time))
        return max(results_at_time, key=lambda x:
                   max(results_at_time.get(x)
                       if results_at_time.get(x) else [0]))

    def _is_empty(self, x):
        if x:
            return False
        return True
    def debug_string(self):
        min_time = 0
        best_trial = None
        for trial in self._completed_trials:
            last_result = self._results[trial][-1]
            if (getattr(last_result, self._time_attr) < min_time
                    or min_time == 0):
                min_time = getattr(last_result, self._time_attr)
                best_trial = trial
        if best_trial is not None:
            return ("The best trial is currently {0}, finishing in {1} "
                    "iterations, with hyperparameters {2}".format(
                        best_trial, min_time, best_trial.config))
        else:
            return "PBT has started"


@@ -188,8 +188,8 @@ class TrialRunner(object):
         trial = self._get_runnable()
         return trial is not None

-    def _launch_trial(self):
-        trial = self._get_runnable()
+    def _launch_trial(self, custom_trial=None):
+        trial = custom_trial or self._get_runnable()
         self._commit_resources(trial.resources)
         try:
             trial.start()

@@ -1,11 +1,14 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import unittest
 import numpy as np
+import random

 from ray.tune.hyperband import HyperBandScheduler
+from ray.tune.pbt import PopulationBasedTraining
 from ray.tune.median_stopping_rule import MedianStoppingRule
 from ray.tune.result import TrainingResult
 from ray.tune.trial import Trial
@@ -508,6 +511,108 @@ class HyperbandSuite(unittest.TestCase):
         self.assertFalse(trial in bracket._live_trials)


+class _MockTrialRunnerPBT(_MockTrialRunner):
+    def __init__(self):
+        self._trials = []
+
+    def _launch_trial(self, trial):
+        trial.status = Trial.RUNNING
+        self._trials.append(trial)
+
+
+class _MockTrialPBT(Trial):
+    def checkpoint(self, to_object_store=False):
+        return 'checkpointed'
+
+    def start(self):
+        return 'started'
+
+    def stop(self):
+        return 'stopped'
+
+
+class PopulationBasedTestingSuite(unittest.TestCase):
+    def schedulerSetup(self, num_trials):
+        sched = PopulationBasedTraining()
+        runner = _MockTrialRunnerPBT()
+        for i in range(num_trials):
+            t = _MockTrialPBT("__parameter_tuning")
+            t.config = {'test': 1, 'test1': 1, 'env': 'test'}
+            t.experiment_tag = str(i)
+            runner._launch_trial(t)
+        return sched, runner
+    def testReadyFunction(self):
+        sched, runner = self.schedulerSetup(5)
+        # Results at two different times to test against.
+        best_result_early = result(18, 100)
+        best_result_late = result(25, 100)
+        runner._trials[0].config = {'test': 10, 'test1': 10, 'env': 'test'}
+        # Set up the best trial so that it is consistently the best trial.
+        sched.on_trial_result(runner, runner._trials[0], result(11, 0))
+        sched.on_trial_result(runner, runner._trials[0], result(14, 2))
+        sched.on_trial_result(runner, runner._trials[0], best_result_early)
+        sched.on_trial_result(runner, runner._trials[0], best_result_late)
+        # Check that adding trials to the time tracker works, and that
+        # the ready function knows when to start.
+        for trial in runner._trials[1:]:
+            old_config = trial.config
+            sched.on_trial_result(
+                runner, trial, result(11, random.randint(0, 10)))
+            self.assertTrue(old_config == trial.config)
+        # Make sure that the second trial in runner._trials
+        # (not the best trial) is the worst trial.
+        for trial in runner._trials[2:]:
+            # Check that the ready function knows that not enough
+            # time has passed.
+            sched.on_trial_result(
+                runner, trial, result(16, random.randint(40, 50)))
+        # Check that the worst trial (i.e. in the bottom 20%) has
+        # mutated, meaning the ready function fired.
+        old_config = runner._trials[1].config
+        sched.on_trial_result(runner, runner._trials[1], result(26, 30))
+        self.assertFalse(old_config == runner._trials[1].config)
+    def testExploitExploreFunction(self):
+        sched, runner = self.schedulerSetup(5)
+        # Results at two different times to test against.
+        best_result_early = result(18, 100)
+        best_result_late = result(25, 100)
+        runner._trials[0].config = {'test': 10, 'test1': 10, 'env': 'test'}
+        # Set up the best trial so that it is consistently the best trial.
+        sched.on_trial_result(runner, runner._trials[0], best_result_early)
+        sched.on_trial_result(runner, runner._trials[0], best_result_late)
+        # Check that adding trials to the time tracker works, and that
+        # the ready function knows when to start.
+        for trial in runner._trials[1:]:
+            sched.on_trial_result(
+                runner, trial, result(11, random.randint(0, 10)))
+        # Make sure that the second trial in runner._trials
+        # (not the best trial) is the worst trial.
+        for trial in runner._trials[2:]:
+            sched.on_trial_result(
+                runner, trial, result(16, random.randint(40, 50)))
+        sched.on_trial_result(runner, runner._trials[1], result(26, 30))
+        # Make sure mutated values are the best trial's values scaled by
+        # 0.8 or 1.2 (the default explore factors).
+        for key in runner._trials[0].config:
+            if key == 'env':
+                continue
+            if (runner._trials[1].config[key] == 0.8 *
+                    runner._trials[0].config[key] or
+                    runner._trials[1].config[key] == 1.2 *
+                    runner._trials[0].config[key]):
+                continue
+            raise ValueError('Trial not correctly explored (mutated)')
 if __name__ == "__main__":
     from ray.rllib import _register_all
     _register_all()
     unittest.main(verbosity=2)