mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
[tune] Reset Config for Trainables (#2831)
Adds the ability for trainables to reset their configurations during experiments. These changes in particular add the base functions to the trial_executor and trainable interfaces as well as giving the basic implementation on the PopulationBasedTraining scheduler. Related issue number: #2741
This commit is contained in:
parent
5da6e78db1
commit
045861c9b0
7 changed files with 101 additions and 21 deletions
|
@ -53,6 +53,10 @@ class MyTrainableClass(Trainable):
|
||||||
self.timestep = data["timestep"]
|
self.timestep = data["timestep"]
|
||||||
self.current_value = data["value"]
|
self.current_value = data["value"]
|
||||||
|
|
||||||
|
def reset_config(self, new_config):
|
||||||
|
self.config = new_config
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
|
@ -157,6 +157,25 @@ class RayTrialExecutor(TrialExecutor):
|
||||||
self._paused[trial_future[0]] = trial
|
self._paused[trial_future[0]] = trial
|
||||||
super(RayTrialExecutor, self).pause_trial(trial)
|
super(RayTrialExecutor, self).pause_trial(trial)
|
||||||
|
|
||||||
|
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||||
|
"""Tries to invoke `Trainable.reset_config()` to reset trial.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trial (Trial): Trial to be reset.
|
||||||
|
new_config (dict): New configuration for Trial
|
||||||
|
trainable.
|
||||||
|
new_experiment_tag (str): New experiment name
|
||||||
|
for trial.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if `reset_config` is successful else False.
|
||||||
|
"""
|
||||||
|
trial.experiment_tag = new_experiment_tag
|
||||||
|
trial.config = new_config
|
||||||
|
trainable = trial.runner
|
||||||
|
reset_val = ray.get(trainable.reset_config.remote(new_config))
|
||||||
|
return reset_val
|
||||||
|
|
||||||
def get_running_trials(self):
|
def get_running_trials(self):
|
||||||
"""Returns the running trials."""
|
"""Returns the running trials."""
|
||||||
|
|
||||||
|
|
|
@ -221,12 +221,17 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||||
trial_state.last_score))
|
trial_state.last_score))
|
||||||
# TODO(ekl) restarting the trial is expensive. We should implement a
|
# TODO(ekl) restarting the trial is expensive. We should implement a
|
||||||
# lighter way reset() method that can alter the trial config.
|
# lighter way reset() method that can alter the trial config.
|
||||||
trial_executor.stop_trial(trial, stop_logger=False)
|
new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
|
||||||
trial.config = new_config
|
self._hyperparam_mutations)
|
||||||
trial.experiment_tag = make_experiment_tag(
|
reset_successful = trial_executor.reset_trial(trial, new_config,
|
||||||
trial_state.orig_tag, new_config, self._hyperparam_mutations)
|
new_tag)
|
||||||
trial_executor.start_trial(
|
if not reset_successful:
|
||||||
trial, Checkpoint.from_object(new_state.last_checkpoint))
|
trial_executor.stop_trial(trial, stop_logger=False)
|
||||||
|
trial.config = new_config
|
||||||
|
trial.experiment_tag = new_tag
|
||||||
|
trial_executor.start_trial(
|
||||||
|
trial, Checkpoint.from_object(new_state.last_checkpoint))
|
||||||
|
|
||||||
self._num_perturbations += 1
|
self._num_perturbations += 1
|
||||||
# Transfer over the last perturbation time as well
|
# Transfer over the last perturbation time as well
|
||||||
trial_state.last_perturbation_time = new_state.last_perturbation_time
|
trial_state.last_perturbation_time = new_state.last_perturbation_time
|
||||||
|
|
|
@ -7,6 +7,7 @@ import unittest
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib import _register_all
|
from ray.rllib import _register_all
|
||||||
|
from ray.tune import Trainable
|
||||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||||
from ray.tune.suggest import BasicVariantGenerator
|
from ray.tune.suggest import BasicVariantGenerator
|
||||||
from ray.tune.trial import Trial, Checkpoint
|
from ray.tune.trial import Trial, Checkpoint
|
||||||
|
@ -21,20 +22,6 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
_register_all() # re-register the evicted objects
|
_register_all() # re-register the evicted objects
|
||||||
|
|
||||||
def _get_trials(self):
|
|
||||||
trials = self.generate_trials({
|
|
||||||
"run": "PPO",
|
|
||||||
"config": {
|
|
||||||
"bar": {
|
|
||||||
"grid_search": [True, False]
|
|
||||||
},
|
|
||||||
"foo": {
|
|
||||||
"grid_search": [1, 2, 3]
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, "grid_search")
|
|
||||||
return list(trials)
|
|
||||||
|
|
||||||
def testStartStop(self):
|
def testStartStop(self):
|
||||||
trial = Trial("__fake")
|
trial = Trial("__fake")
|
||||||
self.trial_executor.start_trial(trial)
|
self.trial_executor.start_trial(trial)
|
||||||
|
@ -76,8 +63,43 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||||
self.trial_executor.stop_trial(trial)
|
self.trial_executor.stop_trial(trial)
|
||||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||||
|
|
||||||
|
def testNoResetTrial(self):
|
||||||
|
"""Tests that reset handles NotImplemented properly."""
|
||||||
|
trial = Trial("__fake")
|
||||||
|
self.trial_executor.start_trial(trial)
|
||||||
|
exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
|
||||||
|
self.assertEqual(exists, False)
|
||||||
|
self.assertEqual(Trial.RUNNING, trial.status)
|
||||||
|
|
||||||
|
def testResetTrial(self):
|
||||||
|
"""Tests that reset works as expected."""
|
||||||
|
|
||||||
|
class B(Trainable):
|
||||||
|
def _train(self):
|
||||||
|
return dict(timesteps_this_iter=1, done=True)
|
||||||
|
|
||||||
|
def reset_config(self, config):
|
||||||
|
self.config = config
|
||||||
|
return True
|
||||||
|
|
||||||
|
trials = self.generate_trials({
|
||||||
|
"run": B,
|
||||||
|
"config": {
|
||||||
|
"foo": 0
|
||||||
|
},
|
||||||
|
}, "grid_search")
|
||||||
|
trial = trials[0]
|
||||||
|
self.trial_executor.start_trial(trial)
|
||||||
|
exists = self.trial_executor.reset_trial(trial, {"hi": 1},
|
||||||
|
"modified_mock")
|
||||||
|
self.assertEqual(exists, True)
|
||||||
|
self.assertEqual(trial.config.get("hi"), 1)
|
||||||
|
self.assertEqual(trial.experiment_tag, "modified_mock")
|
||||||
|
self.assertEqual(Trial.RUNNING, trial.status)
|
||||||
|
|
||||||
def generate_trials(self, spec, name):
|
def generate_trials(self, spec, name):
|
||||||
suggester = BasicVariantGenerator({name: spec})
|
suggester = BasicVariantGenerator()
|
||||||
|
suggester.add_configurations({name: spec})
|
||||||
return suggester.next_trials()
|
return suggester.next_trials()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -168,6 +168,9 @@ class _MockTrialExecutor(TrialExecutor):
|
||||||
def save(self, trial, type=Checkpoint.DISK):
|
def save(self, trial, type=Checkpoint.DISK):
|
||||||
return trial.trainable_name
|
return trial.trainable_name
|
||||||
|
|
||||||
|
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class _MockTrialRunner():
|
class _MockTrialRunner():
|
||||||
def __init__(self, scheduler):
|
def __init__(self, scheduler):
|
||||||
|
|
|
@ -273,6 +273,18 @@ class Trainable(object):
|
||||||
self.restore(checkpoint_path)
|
self.restore(checkpoint_path)
|
||||||
shutil.rmtree(tmpdir)
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
def reset_config(self, new_config):
|
||||||
|
"""Resets configuration without restarting the trial.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_config (dir): Updated hyperparameter configuration
|
||||||
|
for the trainable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if configuration reset successfully else False.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Releases all resources used by this trainable."""
|
"""Releases all resources used by this trainable."""
|
||||||
|
|
||||||
|
|
|
@ -106,6 +106,21 @@ class TrialExecutor(object):
|
||||||
assert trial.status == Trial.PAUSED, trial.status
|
assert trial.status == Trial.PAUSED, trial.status
|
||||||
self.start_trial(trial)
|
self.start_trial(trial)
|
||||||
|
|
||||||
|
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||||
|
"""Tries to invoke `Trainable.reset_config()` to reset trial.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trial (Trial): Trial to be reset.
|
||||||
|
new_config (dict): New configuration for Trial
|
||||||
|
trainable.
|
||||||
|
new_experiment_tag (str): New experiment name
|
||||||
|
for trial.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if `reset_config` is successful else False.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_running_trials(self):
|
def get_running_trials(self):
|
||||||
"""Returns all running trials."""
|
"""Returns all running trials."""
|
||||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||||
|
|
Loading…
Add table
Reference in a new issue