[tune] Tweak/allow nested pbt mutations (#3455)

* Fix warning text in pbt logger

* Allow nested mutations in pbt by recursing explore function

* Add test for nested pbt mutation

* Update pbt explore to only call custom explore on top level

* fix test
This commit is contained in:
Kristian Hartikainen 2019-01-04 13:51:11 -08:00 committed by Richard Liaw
parent cd80891ddb
commit 747b117929
2 changed files with 112 additions and 4 deletions

View file

@ -47,7 +47,12 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
"""
new_config = copy.deepcopy(config)
for key, distribution in mutations.items():
if isinstance(distribution, list):
if isinstance(distribution, dict):
new_config.update({
key: explore(config[key], mutations[key], resample_probability,
None)
})
elif isinstance(distribution, list):
if random.random() < resample_probability or \
config[key] not in distribution:
new_config[key] = random.choice(distribution)
@ -213,8 +218,8 @@ class PopulationBasedTraining(FIFOScheduler):
trial_state = self._trial_state[trial]
new_state = self._trial_state[trial_to_clone]
if not new_state.last_checkpoint:
logger.warning("[pbt]: no checkpoint for trial"
"skip exploit for Trial {}".format(trial))
logger.warning("[pbt]: no checkpoint for trial."
" Skip exploit for Trial {}".format(trial))
return
new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
self._resample_probability,

View file

@ -5,7 +5,7 @@ from __future__ import print_function
import random
import unittest
import numpy as np
import sys
import ray
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
PopulationBasedTraining, MedianStoppingRule,
@ -17,6 +17,11 @@ from ray.tune.trial_executor import TrialExecutor
from ray.rllib import _register_all
_register_all()
if sys.version_info >= (3, 3):
from unittest.mock import MagicMock
else:
from mock import MagicMock
def result(t, rew):
return dict(
@ -748,6 +753,104 @@ class PopulationBasedTestingSuite(unittest.TestCase):
lambda x: x),
{10.0, 100.0})
def deep_add(seen, new_values):
for k, new_value in new_values.items():
if isinstance(new_value, dict):
if k not in seen:
seen[k] = {}
seen[k].update(deep_add(seen[k], new_value))
else:
if k not in seen:
seen[k] = set()
seen[k].add(new_value)
return seen
def assertNestedProduces(fn, values):
random.seed(0)
seen = {}
for _ in range(100):
new_config = fn()
seen = deep_add(seen, new_config)
self.assertEqual(seen, values)
# Nested mutation and spec
assertNestedProduces(
lambda: explore(
{
"a": {
"b": 4
},
"1": {
"2": {
"3": 100
}
},
},
{
"a": {
"b": [3, 4, 8, 10]
},
"1": {
"2": {
"3": lambda: random.choice([10, 100])
}
},
},
0.0,
lambda x: x),
{
"a": {
"b": {3, 8}
},
"1": {
"2": {
"3": {80, 120}
}
},
})
custom_explore_fn = MagicMock(side_effect=lambda x: x)
# Nested mutation and spec
assertNestedProduces(
lambda: explore(
{
"a": {
"b": 4
},
"1": {
"2": {
"3": 100
}
},
},
{
"a": {
"b": [3, 4, 8, 10]
},
"1": {
"2": {
"3": lambda: random.choice([10, 100])
}
},
},
0.0,
custom_explore_fn),
{
"a": {
"b": {3, 8}
},
"1": {
"2": {
"3": {80, 120}
}
},
})
# Expect call count to be 100 because we call explore 100 times
self.assertEqual(custom_explore_fn.call_count, 100)
def testYieldsTimeToOtherTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()