diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index b6532144f..c958d2512 100755 --- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -53,6 +53,10 @@ class MyTrainableClass(Trainable): self.timestep = data["timestep"] self.current_value = data["value"] + def reset_config(self, new_config): + self.config = new_config + return True + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 1d25fc2b9..196240ae2 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -157,6 +157,25 @@ class RayTrialExecutor(TrialExecutor): self._paused[trial_future[0]] = trial super(RayTrialExecutor, self).pause_trial(trial) + def reset_trial(self, trial, new_config, new_experiment_tag): + """Tries to invoke `Trainable.reset_config()` to reset trial. + + Args: + trial (Trial): Trial to be reset. + new_config (dict): New configuration for Trial + trainable. + new_experiment_tag (str): New experiment name + for trial. + + Returns: + True if `reset_config` is successful else False. + """ + trial.experiment_tag = new_experiment_tag + trial.config = new_config + trainable = trial.runner + reset_val = ray.get(trainable.reset_config.remote(new_config)) + return reset_val + def get_running_trials(self): """Returns the running trials.""" diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 2e8d97112..4fcce0885 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -221,12 +221,17 @@ class PopulationBasedTraining(FIFOScheduler): trial_state.last_score)) # TODO(ekl) restarting the trial is expensive. We should implement a # lighter way reset() method that can alter the trial config. - trial_executor.stop_trial(trial, stop_logger=False) - trial.config = new_config - trial.experiment_tag = make_experiment_tag( - trial_state.orig_tag, new_config, self._hyperparam_mutations) - trial_executor.start_trial( - trial, Checkpoint.from_object(new_state.last_checkpoint)) + new_tag = make_experiment_tag(trial_state.orig_tag, new_config, + self._hyperparam_mutations) + reset_successful = trial_executor.reset_trial(trial, new_config, + new_tag) + if not reset_successful: + trial_executor.stop_trial(trial, stop_logger=False) + trial.config = new_config + trial.experiment_tag = new_tag + trial_executor.start_trial( + trial, Checkpoint.from_object(new_state.last_checkpoint)) + self._num_perturbations += 1 # Transfer over the last perturbation time as well trial_state.last_perturbation_time = new_state.last_perturbation_time diff --git a/python/ray/tune/test/ray_trial_executor_test.py b/python/ray/tune/test/ray_trial_executor_test.py index ddd5995ea..35c413e71 100644 --- a/python/ray/tune/test/ray_trial_executor_test.py +++ b/python/ray/tune/test/ray_trial_executor_test.py @@ -7,6 +7,7 @@ import unittest import ray from ray.rllib import _register_all +from ray.tune import Trainable from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, Checkpoint @@ -21,20 +22,6 @@ class RayTrialExecutorTest(unittest.TestCase): ray.shutdown() _register_all() # re-register the evicted objects - def _get_trials(self): - trials = self.generate_trials({ - "run": "PPO", - "config": { - "bar": { - "grid_search": [True, False] - }, - "foo": { - "grid_search": [1, 2, 3] - }, - }, - }, "grid_search") - return list(trials) - def testStartStop(self): trial = Trial("__fake") self.trial_executor.start_trial(trial) @@ -76,8 +63,43 @@ class RayTrialExecutorTest(unittest.TestCase): self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) + def testNoResetTrial(self): + """Tests that reset handles NotImplemented properly.""" + trial = Trial("__fake") + self.trial_executor.start_trial(trial) + exists = self.trial_executor.reset_trial(trial, {}, "modified_mock") + self.assertEqual(exists, False) + self.assertEqual(Trial.RUNNING, trial.status) + + def testResetTrial(self): + """Tests that reset works as expected.""" + + class B(Trainable): + def _train(self): + return dict(timesteps_this_iter=1, done=True) + + def reset_config(self, config): + self.config = config + return True + + trials = self.generate_trials({ + "run": B, + "config": { + "foo": 0 + }, + }, "grid_search") + trial = trials[0] + self.trial_executor.start_trial(trial) + exists = self.trial_executor.reset_trial(trial, {"hi": 1}, + "modified_mock") + self.assertEqual(exists, True) + self.assertEqual(trial.config.get("hi"), 1) + self.assertEqual(trial.experiment_tag, "modified_mock") + self.assertEqual(Trial.RUNNING, trial.status) + def generate_trials(self, spec, name): - suggester = BasicVariantGenerator({name: spec}) + suggester = BasicVariantGenerator() + suggester.add_configurations({name: spec}) return suggester.next_trials() diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index c67219992..21aabec81 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -168,6 +168,9 @@ class _MockTrialExecutor(TrialExecutor): def save(self, trial, type=Checkpoint.DISK): return trial.trainable_name + def reset_trial(self, trial, new_config, new_experiment_tag): + return False + class _MockTrialRunner(): def __init__(self, scheduler): diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index bbdfb69ae..815888295 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -273,6 +273,18 @@ class Trainable(object): self.restore(checkpoint_path) shutil.rmtree(tmpdir) + def reset_config(self, new_config): + """Resets configuration without restarting the trial. + + Args: + new_config (dir): Updated hyperparameter configuration + for the trainable. + + Returns: + True if configuration reset successfully else False. + """ + return False + def stop(self): """Releases all resources used by this trainable.""" diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 61bdf45a8..89a56d0f7 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -106,6 +106,21 @@ class TrialExecutor(object): assert trial.status == Trial.PAUSED, trial.status self.start_trial(trial) + def reset_trial(self, trial, new_config, new_experiment_tag): + """Tries to invoke `Trainable.reset_config()` to reset trial. + + Args: + trial (Trial): Trial to be reset. + new_config (dict): New configuration for Trial + trainable. + new_experiment_tag (str): New experiment name + for trial. + + Returns: + True if `reset_config` is successful else False. + """ + raise NotImplementedError + def get_running_trials(self): """Returns all running trials.""" raise NotImplementedError("Subclasses of TrialExecutor must provide "