mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Tweak/allow nested pbt mutations (#3455)
* Fix warning text in pbt logger * Allow nested mutations in pbt by recursing explore function * Add test for nested pbt mutation * Update pbt explore to only call custom explore on top level * fix test
This commit is contained in:
parent
cd80891ddb
commit
747b117929
2 changed files with 112 additions and 4 deletions
|
@ -47,7 +47,12 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
|||
"""
|
||||
new_config = copy.deepcopy(config)
|
||||
for key, distribution in mutations.items():
|
||||
if isinstance(distribution, list):
|
||||
if isinstance(distribution, dict):
|
||||
new_config.update({
|
||||
key: explore(config[key], mutations[key], resample_probability,
|
||||
None)
|
||||
})
|
||||
elif isinstance(distribution, list):
|
||||
if random.random() < resample_probability or \
|
||||
config[key] not in distribution:
|
||||
new_config[key] = random.choice(distribution)
|
||||
|
@ -213,8 +218,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
|||
trial_state = self._trial_state[trial]
|
||||
new_state = self._trial_state[trial_to_clone]
|
||||
if not new_state.last_checkpoint:
|
||||
logger.warning("[pbt]: no checkpoint for trial"
|
||||
"skip exploit for Trial {}".format(trial))
|
||||
logger.warning("[pbt]: no checkpoint for trial."
|
||||
" Skip exploit for Trial {}".format(trial))
|
||||
return
|
||||
new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
|
||||
self._resample_probability,
|
||||
|
|
|
@ -5,7 +5,7 @@ from __future__ import print_function
|
|||
import random
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import sys
|
||||
import ray
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
PopulationBasedTraining, MedianStoppingRule,
|
||||
|
@ -17,6 +17,11 @@ from ray.tune.trial_executor import TrialExecutor
|
|||
from ray.rllib import _register_all
|
||||
_register_all()
|
||||
|
||||
if sys.version_info >= (3, 3):
|
||||
from unittest.mock import MagicMock
|
||||
else:
|
||||
from mock import MagicMock
|
||||
|
||||
|
||||
def result(t, rew):
|
||||
return dict(
|
||||
|
@ -748,6 +753,104 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
lambda x: x),
|
||||
{10.0, 100.0})
|
||||
|
||||
def deep_add(seen, new_values):
|
||||
for k, new_value in new_values.items():
|
||||
if isinstance(new_value, dict):
|
||||
if k not in seen:
|
||||
seen[k] = {}
|
||||
seen[k].update(deep_add(seen[k], new_value))
|
||||
else:
|
||||
if k not in seen:
|
||||
seen[k] = set()
|
||||
seen[k].add(new_value)
|
||||
|
||||
return seen
|
||||
|
||||
def assertNestedProduces(fn, values):
|
||||
random.seed(0)
|
||||
seen = {}
|
||||
for _ in range(100):
|
||||
new_config = fn()
|
||||
seen = deep_add(seen, new_config)
|
||||
self.assertEqual(seen, values)
|
||||
|
||||
# Nested mutation and spec
|
||||
assertNestedProduces(
|
||||
lambda: explore(
|
||||
{
|
||||
"a": {
|
||||
"b": 4
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": 100
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"a": {
|
||||
"b": [3, 4, 8, 10]
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": lambda: random.choice([10, 100])
|
||||
}
|
||||
},
|
||||
},
|
||||
0.0,
|
||||
lambda x: x),
|
||||
{
|
||||
"a": {
|
||||
"b": {3, 8}
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": {80, 120}
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
custom_explore_fn = MagicMock(side_effect=lambda x: x)
|
||||
|
||||
# Nested mutation and spec
|
||||
assertNestedProduces(
|
||||
lambda: explore(
|
||||
{
|
||||
"a": {
|
||||
"b": 4
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": 100
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
"a": {
|
||||
"b": [3, 4, 8, 10]
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": lambda: random.choice([10, 100])
|
||||
}
|
||||
},
|
||||
},
|
||||
0.0,
|
||||
custom_explore_fn),
|
||||
{
|
||||
"a": {
|
||||
"b": {3, 8}
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": {80, 120}
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
# Expect call count to be 100 because we call explore 100 times
|
||||
self.assertEqual(custom_explore_fn.call_count, 100)
|
||||
|
||||
def testYieldsTimeToOtherTrials(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
|
|
Loading…
Add table
Reference in a new issue