From 747b117929d5c5a0f01d5ce789544b64dc091a00 Mon Sep 17 00:00:00 2001
From: Kristian Hartikainen
Date: Fri, 4 Jan 2019 13:51:11 -0800
Subject: [PATCH] [tune] Tweak/allow nested pbt mutations (#3455)

* Fix warning text in pbt logger

* Allow nested mutations in pbt by recursing explore function

* Add test for nested pbt mutation

* Update pbt explore to only call custom explore on top level

* fix test
---
Reviewer notes (this section is not part of the commit message):
- Line structure and indentation of this patch were restored from a
  whitespace-collapsed copy; hunk line counts were re-checked against the
  @@ headers, but leading whitespace is inferred -- verify against the
  upstream commit before applying.
- explore() now recurses into dict-valued entries of the mutation spec and
  passes None as custom_explore_fn for the nested calls; the new test
  asserts that a custom explore function is called 100 times for 100
  top-level explore() calls.

 python/ray/tune/schedulers/pbt.py            |  11 +-
 python/ray/tune/test/trial_scheduler_test.py | 105 ++++++++++++++++++-
 2 files changed, 112 insertions(+), 4 deletions(-)

diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py
index 64dad97ca..d21ab5044 100644
--- a/python/ray/tune/schedulers/pbt.py
+++ b/python/ray/tune/schedulers/pbt.py
@@ -47,7 +47,12 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
     """
     new_config = copy.deepcopy(config)
     for key, distribution in mutations.items():
-        if isinstance(distribution, list):
+        if isinstance(distribution, dict):
+            new_config.update({
+                key: explore(config[key], mutations[key], resample_probability,
+                             None)
+            })
+        elif isinstance(distribution, list):
             if random.random() < resample_probability or \
                     config[key] not in distribution:
                 new_config[key] = random.choice(distribution)
@@ -213,8 +218,8 @@ class PopulationBasedTraining(FIFOScheduler):
         trial_state = self._trial_state[trial]
         new_state = self._trial_state[trial_to_clone]
         if not new_state.last_checkpoint:
-            logger.warning("[pbt]: no checkpoint for trial"
-                           "skip exploit for Trial {}".format(trial))
+            logger.warning("[pbt]: no checkpoint for trial."
+                           " Skip exploit for Trial {}".format(trial))
             return
         new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
                              self._resample_probability,
diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py
index 9d463850f..b5426bb3d 100644
--- a/python/ray/tune/test/trial_scheduler_test.py
+++ b/python/ray/tune/test/trial_scheduler_test.py
@@ -5,7 +5,7 @@ from __future__ import print_function
 import random
 import unittest
 import numpy as np
-
+import sys
 import ray
 from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
                                  PopulationBasedTraining, MedianStoppingRule,
@@ -17,6 +17,11 @@ from ray.tune.trial_executor import TrialExecutor
 from ray.rllib import _register_all
 _register_all()
 
+if sys.version_info >= (3, 3):
+    from unittest.mock import MagicMock
+else:
+    from mock import MagicMock
+
 
 def result(t, rew):
     return dict(
@@ -748,6 +753,104 @@ class PopulationBasedTestingSuite(unittest.TestCase):
                 lambda x: x),
             {10.0, 100.0})
 
+        def deep_add(seen, new_values):
+            for k, new_value in new_values.items():
+                if isinstance(new_value, dict):
+                    if k not in seen:
+                        seen[k] = {}
+                    seen[k].update(deep_add(seen[k], new_value))
+                else:
+                    if k not in seen:
+                        seen[k] = set()
+                    seen[k].add(new_value)
+
+            return seen
+
+        def assertNestedProduces(fn, values):
+            random.seed(0)
+            seen = {}
+            for _ in range(100):
+                new_config = fn()
+                seen = deep_add(seen, new_config)
+            self.assertEqual(seen, values)
+
+        # Nested mutation and spec
+        assertNestedProduces(
+            lambda: explore(
+                {
+                    "a": {
+                        "b": 4
+                    },
+                    "1": {
+                        "2": {
+                            "3": 100
+                        }
+                    },
+                },
+                {
+                    "a": {
+                        "b": [3, 4, 8, 10]
+                    },
+                    "1": {
+                        "2": {
+                            "3": lambda: random.choice([10, 100])
+                        }
+                    },
+                },
+                0.0,
+                lambda x: x),
+            {
+                "a": {
+                    "b": {3, 8}
+                },
+                "1": {
+                    "2": {
+                        "3": {80, 120}
+                    }
+                },
+            })
+
+        custom_explore_fn = MagicMock(side_effect=lambda x: x)
+
+        # Nested mutation and spec
+        assertNestedProduces(
+            lambda: explore(
+                {
+                    "a": {
+                        "b": 4
+                    },
+                    "1": {
+                        "2": {
+                            "3": 100
+                        }
+                    },
+                },
+                {
+                    "a": {
+                        "b": [3, 4, 8, 10]
+                    },
+                    "1": {
+                        "2": {
+                            "3": lambda: random.choice([10, 100])
+                        }
+                    },
+                },
+                0.0,
+                custom_explore_fn),
+            {
+                "a": {
+                    "b": {3, 8}
+                },
+                "1": {
+                    "2": {
+                        "3": {80, 120}
+                    }
+                },
+            })
+
+        # Expect call count to be 100 because we call explore 100 times
+        self.assertEqual(custom_explore_fn.call_count, 100)
+
     def testYieldsTimeToOtherTrials(self):
         pbt, runner = self.basicSetup()
         trials = runner.get_trials()