[tune] Reset Config for Trainables (#2831)

Adds the ability for trainables to reset their configuration during experiments. In particular, these changes add the base methods to the TrialExecutor and Trainable interfaces, and provide a basic implementation in the PopulationBasedTraining scheduler.

Related issue number: #2741
Author: Kaahan (2018-09-11 08:45:04 -07:00), committed by Richard Liaw
parent 5da6e78db1
commit 045861c9b0
7 changed files with 101 additions and 21 deletions
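
In short, a Trainable can now opt in to in-place resets by overriding reset_config(); schedulers that perturb hyperparameters (like PBT) try this fast path before falling back to a full stop/restart. A minimal sketch of the opt-in, modeled on the example changed below (the class name MyResettableTrainable is hypothetical):

from ray.tune import Trainable

class MyResettableTrainable(Trainable):
    def _train(self):
        # One unit of work per call; hyperparameters live on self.config.
        return dict(timesteps_this_iter=1)

    def reset_config(self, new_config):
        # Swap in the new hyperparameters without tearing down the actor.
        self.config = new_config
        return True  # False signals "could not reset; restart me instead"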

@@ -53,6 +53,10 @@ class MyTrainableClass(Trainable):
             self.timestep = data["timestep"]
             self.current_value = data["value"]
 
+    def reset_config(self, new_config):
+        self.config = new_config
+        return True
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

@@ -157,6 +157,25 @@ class RayTrialExecutor(TrialExecutor):
             self._paused[trial_future[0]] = trial
         super(RayTrialExecutor, self).pause_trial(trial)
 
+    def reset_trial(self, trial, new_config, new_experiment_tag):
+        """Tries to invoke `Trainable.reset_config()` to reset trial.
+
+        Args:
+            trial (Trial): Trial to be reset.
+            new_config (dict): New configuration for Trial
+                trainable.
+            new_experiment_tag (str): New experiment name
+                for trial.
+
+        Returns:
+            True if `reset_config` is successful else False.
+        """
+        trial.experiment_tag = new_experiment_tag
+        trial.config = new_config
+        trainable = trial.runner
+        reset_val = ray.get(trainable.reset_config.remote(new_config))
+        return reset_val
+
     def get_running_trials(self):
         """Returns the running trials."""

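Note that reset_trial mutates trial.experiment_tag and trial.config before the remote call resolves; the PBT fallback path below re-assigns them, so the trial state ends up consistent either way. The round trip itself is plain Ray actor machinery; a self-contained sketch of the same pattern (FakeTrainable is a hypothetical stand-in for trial.runner):

import ray

@ray.remote
class FakeTrainable(object):
    def __init__(self):
        self.config = {}

    def reset_config(self, new_config):
        self.config = new_config
        return True

ray.init()
runner = FakeTrainable.remote()
# The executor's blocking round trip, as in reset_trial():
reset_val = ray.get(runner.reset_config.remote({"lr": 0.1}))
assert reset_val is True
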
@@ -221,12 +221,17 @@ class PopulationBasedTraining(FIFOScheduler):
                 trial_state.last_score))
         # TODO(ekl) restarting the trial is expensive. We should implement a
         # lighter way reset() method that can alter the trial config.
-        trial_executor.stop_trial(trial, stop_logger=False)
-        trial.config = new_config
-        trial.experiment_tag = make_experiment_tag(
-            trial_state.orig_tag, new_config, self._hyperparam_mutations)
-        trial_executor.start_trial(
-            trial, Checkpoint.from_object(new_state.last_checkpoint))
+        new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
+                                      self._hyperparam_mutations)
+        reset_successful = trial_executor.reset_trial(trial, new_config,
+                                                      new_tag)
+        if not reset_successful:
+            trial_executor.stop_trial(trial, stop_logger=False)
+            trial.config = new_config
+            trial.experiment_tag = new_tag
+            trial_executor.start_trial(
+                trial, Checkpoint.from_object(new_state.last_checkpoint))
         self._num_perturbations += 1
         # Transfer over the last perturbation time as well
         trial_state.last_perturbation_time = new_state.last_perturbation_time

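The control flow added above is: try the cheap in-place reset first, and only on failure pay for a full stop/restart from the last checkpoint. Distilled to its skeleton (the helper name exploit is hypothetical, independent of Tune internals):

def exploit(trial_executor, trial, new_config, new_tag, checkpoint):
    # Fast path: reuse the running actor, just swapping its config.
    if trial_executor.reset_trial(trial, new_config, new_tag):
        return
    # Slow path: tear the actor down and restart from the checkpoint.
    trial_executor.stop_trial(trial, stop_logger=False)
    trial.config = new_config
    trial.experiment_tag = new_tag
    trial_executor.start_trial(trial, checkpoint)
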
@@ -7,6 +7,7 @@ import unittest
 import ray
 from ray.rllib import _register_all
 from ray.tune import Trainable
+from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.suggest import BasicVariantGenerator
 from ray.tune.trial import Trial, Checkpoint
 
@@ -21,20 +22,6 @@ class RayTrialExecutorTest(unittest.TestCase):
         ray.shutdown()
         _register_all()  # re-register the evicted objects
 
-    def _get_trials(self):
-        trials = self.generate_trials({
-            "run": "PPO",
-            "config": {
-                "bar": {
-                    "grid_search": [True, False]
-                },
-                "foo": {
-                    "grid_search": [1, 2, 3]
-                },
-            },
-        }, "grid_search")
-        return list(trials)
-
     def testStartStop(self):
         trial = Trial("__fake")
         self.trial_executor.start_trial(trial)
 
@@ -76,8 +63,43 @@ class RayTrialExecutorTest(unittest.TestCase):
         self.trial_executor.stop_trial(trial)
         self.assertEqual(Trial.TERMINATED, trial.status)
 
+    def testNoResetTrial(self):
+        """Tests that reset handles NotImplemented properly."""
+        trial = Trial("__fake")
+        self.trial_executor.start_trial(trial)
+        exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
+        self.assertEqual(exists, False)
+        self.assertEqual(Trial.RUNNING, trial.status)
+
+    def testResetTrial(self):
+        """Tests that reset works as expected."""
+
+        class B(Trainable):
+            def _train(self):
+                return dict(timesteps_this_iter=1, done=True)
+
+            def reset_config(self, config):
+                self.config = config
+                return True
+
+        trials = self.generate_trials({
+            "run": B,
+            "config": {
+                "foo": 0
+            },
+        }, "grid_search")
+        trial = trials[0]
+        self.trial_executor.start_trial(trial)
+        exists = self.trial_executor.reset_trial(trial, {"hi": 1},
+                                                 "modified_mock")
+        self.assertEqual(exists, True)
+        self.assertEqual(trial.config.get("hi"), 1)
+        self.assertEqual(trial.experiment_tag, "modified_mock")
+        self.assertEqual(Trial.RUNNING, trial.status)
+
     def generate_trials(self, spec, name):
-        suggester = BasicVariantGenerator({name: spec})
+        suggester = BasicVariantGenerator()
+        suggester.add_configurations({name: spec})
         return suggester.next_trials()

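The generate_trials helper at the bottom of the hunk also reflects an API change: BasicVariantGenerator now receives its experiment spec through add_configurations() rather than the constructor. A minimal usage sketch (the "__fake" trainable is the stub used throughout these tests; the spec values are illustrative):

from ray.tune.suggest import BasicVariantGenerator

suggester = BasicVariantGenerator()
suggester.add_configurations({
    "grid_search": {
        "run": "__fake",
        "config": {
            "foo": {"grid_search": [1, 2, 3]}
        },
    },
})
trials = list(suggester.next_trials())  # one Trial per grid point
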
@@ -168,6 +168,9 @@ class _MockTrialExecutor(TrialExecutor):
     def save(self, trial, type=Checkpoint.DISK):
         return trial.trainable_name
 
+    def reset_trial(self, trial, new_config, new_experiment_tag):
+        return False
+
 
 class _MockTrialRunner():
     def __init__(self, scheduler):

@@ -273,6 +273,18 @@ class Trainable(object):
         self.restore(checkpoint_path)
         shutil.rmtree(tmpdir)
 
+    def reset_config(self, new_config):
+        """Resets configuration without restarting the trial.
+
+        Args:
+            new_config (dict): Updated hyperparameter configuration
+                for the trainable.
+
+        Returns:
+            True if configuration reset successfully else False.
+        """
+        return False
+
     def stop(self):
         """Releases all resources used by this trainable."""

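Because the default returns False, existing Trainables are unaffected: schedulers simply take the stop/restart path. An override can also be selective; a hypothetical sketch of a Trainable that accepts only the resets it can apply in place:

from ray.tune import Trainable

class SelectiveTrainable(Trainable):
    def _train(self):
        return dict(timesteps_this_iter=1)

    def reset_config(self, new_config):
        # Accept in-place changes to the learning rate only; anything
        # else (e.g. a different model) needs a fresh actor.
        if set(new_config) - {"lr"}:
            return False
        self.config = new_config
        return True
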
@@ -106,6 +106,21 @@ class TrialExecutor(object):
         assert trial.status == Trial.PAUSED, trial.status
         self.start_trial(trial)
 
+    def reset_trial(self, trial, new_config, new_experiment_tag):
+        """Tries to invoke `Trainable.reset_config()` to reset trial.
+
+        Args:
+            trial (Trial): Trial to be reset.
+            new_config (dict): New configuration for Trial
+                trainable.
+            new_experiment_tag (str): New experiment name
+                for trial.
+
+        Returns:
+            True if `reset_config` is successful else False.
+        """
+        raise NotImplementedError
+
     def get_running_trials(self):
         """Returns all running trials."""
         raise NotImplementedError("Subclasses of TrialExecutor must provide "
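
Any concrete executor must override this hook (or inherit the NotImplementedError, as the mock above does by returning False). A hypothetical sketch of a non-Ray executor satisfying the contract, where the runner is a local object rather than a remote actor:

from ray.tune.trial_executor import TrialExecutor

class LocalTrialExecutor(TrialExecutor):
    """Hypothetical executor whose trainables run in-process."""

    def reset_trial(self, trial, new_config, new_experiment_tag):
        trial.experiment_tag = new_experiment_tag
        trial.config = new_config
        # trial.runner is a plain Trainable here, so the call is direct
        # rather than a .remote() invocation followed by ray.get().
        return trial.runner.reset_config(new_config)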