mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Reset Config for Trainables (#2831)
Adds the ability for trainables to reset their configurations during experiments. In particular, these changes add the base functions to the trial_executor and trainable interfaces, and provide a basic implementation in the PopulationBasedTraining scheduler. Related issue number: #2741
This commit is contained in:
parent
5da6e78db1
commit
045861c9b0
7 changed files with 101 additions and 21 deletions
|
@ -53,6 +53,10 @@ class MyTrainableClass(Trainable):
|
|||
self.timestep = data["timestep"]
|
||||
self.current_value = data["value"]
|
||||
|
||||
def reset_config(self, new_config):
    """Adopt *new_config* in place rather than restarting the trainable.

    Returns:
        True, signalling the reset was handled successfully.
    """
    self.config = new_config
    return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
|
@ -157,6 +157,25 @@ class RayTrialExecutor(TrialExecutor):
|
|||
self._paused[trial_future[0]] = trial
|
||||
super(RayTrialExecutor, self).pause_trial(trial)
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
    """Tries to invoke `Trainable.reset_config()` to reset trial.

    Args:
        trial (Trial): Trial to be reset.
        new_config (dict): New configuration for Trial
            trainable.
        new_experiment_tag (str): New experiment name
            for trial.

    Returns:
        True if `reset_config` is successful else False.
    """
    # Ask the remote trainable to adopt the new config first, and only
    # mutate the Trial's metadata once the reset is known to have
    # succeeded. This way a failed reset leaves the trial's config and
    # experiment tag untouched for the caller's fallback path.
    trainable = trial.runner
    reset_val = ray.get(trainable.reset_config.remote(new_config))
    if reset_val:
        trial.experiment_tag = new_experiment_tag
        trial.config = new_config
    return reset_val
|
||||
|
||||
def get_running_trials(self):
|
||||
"""Returns the running trials."""
|
||||
|
||||
|
|
|
@ -221,12 +221,17 @@ class PopulationBasedTraining(FIFOScheduler):
|
|||
trial_state.last_score))
|
||||
# TODO(ekl) restarting the trial is expensive. We should implement a
|
||||
# lighter way reset() method that can alter the trial config.
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = make_experiment_tag(
|
||||
trial_state.orig_tag, new_config, self._hyperparam_mutations)
|
||||
trial_executor.start_trial(
|
||||
trial, Checkpoint.from_object(new_state.last_checkpoint))
|
||||
new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
|
||||
self._hyperparam_mutations)
|
||||
reset_successful = trial_executor.reset_trial(trial, new_config,
|
||||
new_tag)
|
||||
if not reset_successful:
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial_executor.start_trial(
|
||||
trial, Checkpoint.from_object(new_state.last_checkpoint))
|
||||
|
||||
self._num_perturbations += 1
|
||||
# Transfer over the last perturbation time as well
|
||||
trial_state.last_perturbation_time = new_state.last_perturbation_time
|
||||
|
|
|
@ -7,6 +7,7 @@ import unittest
|
|||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
|
@ -21,20 +22,6 @@ class RayTrialExecutorTest(unittest.TestCase):
|
|||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def _get_trials(self):
    """Build the grid-search trials used by the executor tests."""
    spec = {
        "run": "PPO",
        "config": {
            "bar": {
                "grid_search": [True, False]
            },
            "foo": {
                "grid_search": [1, 2, 3]
            },
        },
    }
    return list(self.generate_trials(spec, "grid_search"))
|
||||
|
||||
def testStartStop(self):
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
|
@ -76,8 +63,43 @@ class RayTrialExecutorTest(unittest.TestCase):
|
|||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testNoResetTrial(self):
    """Tests that reset handles NotImplemented properly."""
    trial = Trial("__fake")
    self.trial_executor.start_trial(trial)
    # The __fake trainable does not override reset_config, so the
    # executor must report failure while the trial keeps running.
    success = self.trial_executor.reset_trial(trial, {}, "modified_mock")
    self.assertEqual(success, False)
    self.assertEqual(trial.status, Trial.RUNNING)
|
||||
|
||||
def testResetTrial(self):
    """Tests that reset works as expected."""

    class B(Trainable):
        def _train(self):
            return dict(timesteps_this_iter=1, done=True)

        def reset_config(self, config):
            self.config = config
            return True

    spec = {
        "run": B,
        "config": {
            "foo": 0
        },
    }
    trial = self.generate_trials(spec, "grid_search")[0]
    self.trial_executor.start_trial(trial)
    # A trainable that implements reset_config should be updated
    # in place: new config, new tag, still RUNNING.
    success = self.trial_executor.reset_trial(trial, {"hi": 1},
                                              "modified_mock")
    self.assertEqual(success, True)
    self.assertEqual(trial.config.get("hi"), 1)
    self.assertEqual(trial.experiment_tag, "modified_mock")
    self.assertEqual(trial.status, Trial.RUNNING)
|
||||
|
||||
def generate_trials(self, spec, name):
    """Materialize trials for *spec* registered under experiment *name*.

    Args:
        spec (dict): Experiment specification (run/config/etc.).
        name (str): Experiment name key.

    Returns:
        List of Trial objects produced by the variant generator.
    """
    # The stale one-shot constructor call
    # `BasicVariantGenerator({name: spec})` was dead code (immediately
    # shadowed) left over from the previous API; configurations are now
    # added explicitly after construction.
    suggester = BasicVariantGenerator()
    suggester.add_configurations({name: spec})
    return suggester.next_trials()
|
||||
|
||||
|
||||
|
|
|
@ -168,6 +168,9 @@ class _MockTrialExecutor(TrialExecutor):
|
|||
def save(self, trial, type=Checkpoint.DISK):
|
||||
return trial.trainable_name
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
    """Mock reset hook: always report that the in-place reset failed."""
    return False
|
||||
|
||||
|
||||
class _MockTrialRunner():
|
||||
def __init__(self, scheduler):
|
||||
|
|
|
@ -273,6 +273,18 @@ class Trainable(object):
|
|||
self.restore(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def reset_config(self, new_config):
    """Resets configuration without restarting the trial.

    Subclasses may override this to support in-place configuration
    updates (e.g. for schedulers such as Population Based Training);
    the default implementation reports that resets are unsupported.

    Args:
        new_config (dict): Updated hyperparameter configuration
            for the trainable.

    Returns:
        True if configuration reset successfully else False.
    """
    return False
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this trainable."""
|
||||
|
||||
|
|
|
@ -106,6 +106,21 @@ class TrialExecutor(object):
|
|||
assert trial.status == Trial.PAUSED, trial.status
|
||||
self.start_trial(trial)
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
    """Resets a trial in place via `Trainable.reset_config()`.

    Concrete executors override this; the base class provides no
    implementation.

    Args:
        trial (Trial): Trial whose configuration should be replaced.
        new_config (dict): Replacement configuration for the trial's
            trainable.
        new_experiment_tag (str): Replacement experiment name for
            the trial.

    Returns:
        True if the underlying `reset_config` call succeeded,
        otherwise False.
    """
    raise NotImplementedError
|
||||
|
||||
def get_running_trials(self):
|
||||
"""Returns all running trials."""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
|
|
Loading…
Add table
Reference in a new issue