mirror of
https://github.com/vale981/ray
synced 2025-04-23 06:25:52 -04:00
[tune] Allow perturbations of categorical variables (#1507)
* categorical perturb * Sat Feb 3 00:28:13 PST 2018 * explicitly clean up nested actors * Sat Feb 10 02:21:57 PST 2018 * Sat Feb 10 02:22:29 PST 2018
This commit is contained in:
parent
639df85fda
commit
b6a06b81ed
4 changed files with 78 additions and 26 deletions
|
@ -68,8 +68,8 @@ if __name__ == "__main__":
|
|||
hyperparam_mutations={
|
||||
# Allow for scaling-based perturbations, with a uniform backing
|
||||
# distribution for resampling.
|
||||
"factor_1": lambda config: random.uniform(0.0, 20.0),
|
||||
# Only allows resampling from this list as a perturbation.
|
||||
"factor_1": lambda: random.uniform(0.0, 20.0),
|
||||
# Allow perturbations within this set of categorical values.
|
||||
"factor_2": [1, 2],
|
||||
})
|
||||
|
||||
|
|
|
@ -33,15 +33,14 @@ if __name__ == "__main__":
|
|||
time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
perturbation_interval=120,
|
||||
resample_probability=0.25,
|
||||
# Specifies the resampling distributions of these hyperparams
|
||||
# Specifies the mutations of these hyperparams
|
||||
hyperparam_mutations={
|
||||
"lambda": lambda config: random.uniform(0.9, 1.0),
|
||||
"clip_param": lambda config: random.uniform(0.01, 0.5),
|
||||
"sgd_stepsize": lambda config: random.uniform(.00001, .001),
|
||||
"num_sgd_iter": lambda config: random.randint(1, 30),
|
||||
"sgd_batchsize": lambda config: random.randint(128, 16384),
|
||||
"timesteps_per_batch":
|
||||
lambda config: random.randint(2000, 160000),
|
||||
"lambda": lambda: random.uniform(0.9, 1.0),
|
||||
"clip_param": lambda: random.uniform(0.01, 0.5),
|
||||
"sgd_stepsize": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
|
||||
"num_sgd_iter": lambda: random.randint(1, 30),
|
||||
"sgd_batchsize": lambda: random.randint(128, 16384),
|
||||
"timesteps_per_batch": lambda: random.randint(2000, 160000),
|
||||
},
|
||||
custom_explore_fn=explore)
|
||||
|
||||
|
@ -57,11 +56,11 @@ if __name__ == "__main__":
|
|||
"num_workers": 8,
|
||||
"devices": ["/gpu:0"],
|
||||
"model": {"free_log_std": True},
|
||||
# These params are tuned from their starting value
|
||||
# These params are tuned from a fixed starting value.
|
||||
"lambda": 0.95,
|
||||
"clip_param": 0.2,
|
||||
# Start off with several random variations
|
||||
"sgd_stepsize": lambda spec: random.uniform(.00001, .001),
|
||||
"sgd_stepsize": 1e-4,
|
||||
# These params start off randomly drawn from a set.
|
||||
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
|
||||
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
|
||||
"timesteps_per_batch":
|
||||
|
|
|
@ -47,11 +47,19 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
|||
new_config = copy.deepcopy(config)
|
||||
for key, distribution in mutations.items():
|
||||
if isinstance(distribution, list):
|
||||
if random.random() < resample_probability:
|
||||
if random.random() < resample_probability or \
|
||||
config[key] not in distribution:
|
||||
new_config[key] = random.choice(distribution)
|
||||
elif random.random() > 0.5:
|
||||
new_config[key] = distribution[
|
||||
max(0, distribution.index(config[key]) - 1)]
|
||||
else:
|
||||
new_config[key] = distribution[
|
||||
min(len(distribution) - 1,
|
||||
distribution.index(config[key]) + 1)]
|
||||
else:
|
||||
if random.random() < resample_probability:
|
||||
new_config[key] = distribution(config)
|
||||
new_config[key] = distribution()
|
||||
elif random.random() > 0.5:
|
||||
new_config[key] = config[key] * 1.2
|
||||
else:
|
||||
|
@ -109,14 +117,14 @@ class PopulationBasedTraining(FIFOScheduler):
|
|||
to be too frequent.
|
||||
hyperparam_mutations (dict): Hyperparams to mutate. The format is
|
||||
as follows: for each key, either a list or function can be
|
||||
provided. A list specifies values for a discrete parameter.
|
||||
provided. A list specifies an allowed set of categorical values.
|
||||
A function specifies the distribution of a continuous parameter.
|
||||
You must specify at least one of `hyperparam_mutations` or
|
||||
`custom_explore_fn`.
|
||||
resample_probability (float): The probability of resampling from the
|
||||
original distribution when applying `hyperparam_mutations`. If not
|
||||
resampled, the value will be perturbed by a factor of 1.2 or 0.8
|
||||
if continuous, or left unchanged if discrete.
|
||||
if continuous, or changed to an adjacent value if discrete.
|
||||
custom_explore_fn (func): You can also specify a custom exploration
|
||||
function. This function is invoked as `f(config)` after built-in
|
||||
perturbations from `hyperparam_mutations` are applied, and should
|
||||
|
@ -130,11 +138,12 @@ class PopulationBasedTraining(FIFOScheduler):
|
|||
>>> perturbation_interval=10, # every 10 `time_attr` units
|
||||
>>> # (training_iterations in this case)
|
||||
>>> hyperparam_mutations={
|
||||
>>> # Allow for scaling-based perturbations, with a uniform
|
||||
>>> # backing distribution for resampling.
|
||||
>>> "factor_1": lambda config: random.uniform(0.0, 20.0),
|
||||
>>> # Only allows resampling from this list as a perturbation.
|
||||
>>> "factor_2": [1, 2],
|
||||
>>> # Perturb factor1 by scaling it by 0.8 or 1.2. Resampling
|
||||
>>> # resets it to a value sampled from the lambda function.
|
||||
>>> "factor_1": lambda: random.uniform(0.0, 20.0),
|
||||
>>> # Perturb factor2 by changing it to an adjacent value, e.g.
|
||||
>>> # 10 -> 1 or 10 -> 100. Resampling will choose at random.
|
||||
>>> "factor_2": [1, 10, 100, 1000, 10000],
|
||||
>>> })
|
||||
>>> run_experiments({...}, scheduler=pbt)
|
||||
"""
|
||||
|
|
|
@ -3,11 +3,12 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
from ray.tune.pbt import PopulationBasedTraining, explore
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trial import Trial, Resources
|
||||
|
@ -551,8 +552,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
resample_probability=resample_prob,
|
||||
hyperparam_mutations={
|
||||
"id_factor": [100],
|
||||
"float_factor": lambda c: 100.0,
|
||||
"int_factor": lambda c: 10,
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": lambda: 10,
|
||||
},
|
||||
custom_explore_fn=explore)
|
||||
runner = _MockTrialRunner(pbt)
|
||||
|
@ -644,7 +645,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertIn(trials[0].config["id_factor"], [3, 4])
|
||||
self.assertIn(trials[0].config["id_factor"], [100])
|
||||
self.assertIn(trials[0].config["float_factor"], [2.4, 1.6])
|
||||
self.assertEqual(type(trials[0].config["float_factor"]), float)
|
||||
self.assertIn(trials[0].config["int_factor"], [8, 12])
|
||||
|
@ -665,6 +666,49 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
self.assertEqual(type(trials[0].config["int_factor"]), int)
|
||||
self.assertEqual(trials[0].config["const_factor"], 3)
|
||||
|
||||
def testPerturbationValues(self):
|
||||
|
||||
def assertProduces(fn, values):
|
||||
random.seed(0)
|
||||
seen = set()
|
||||
for _ in range(100):
|
||||
seen.add(fn()["v"])
|
||||
self.assertEqual(seen, values)
|
||||
|
||||
# Categorical case
|
||||
assertProduces(
|
||||
lambda: explore({"v": 4}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x),
|
||||
set([3, 8]))
|
||||
assertProduces(
|
||||
lambda: explore({"v": 3}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x),
|
||||
set([3, 4]))
|
||||
assertProduces(
|
||||
lambda: explore({"v": 10}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x),
|
||||
set([8, 10]))
|
||||
assertProduces(
|
||||
lambda: explore({"v": 7}, {"v": [3, 4, 8, 10]}, 0.0, lambda x: x),
|
||||
set([3, 4, 8, 10]))
|
||||
assertProduces(
|
||||
lambda: explore({"v": 4}, {"v": [3, 4, 8, 10]}, 1.0, lambda x: x),
|
||||
set([3, 4, 8, 10]))
|
||||
|
||||
# Continuous case
|
||||
assertProduces(
|
||||
lambda: explore(
|
||||
{"v": 100}, {"v": lambda: random.choice([10, 100])}, 0.0,
|
||||
lambda x: x),
|
||||
set([80, 120]))
|
||||
assertProduces(
|
||||
lambda: explore(
|
||||
{"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 0.0,
|
||||
lambda x: x),
|
||||
set([80.0, 120.0]))
|
||||
assertProduces(
|
||||
lambda: explore(
|
||||
{"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 1.0,
|
||||
lambda x: x),
|
||||
set([10.0, 100.0]))
|
||||
|
||||
def testYieldsTimeToOtherTrials(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
|
|
Loading…
Add table
Reference in a new issue