Mirror of https://github.com/vale981/ray (synced 2025-03-06 18:41:40 -05:00)
[tune] Document trainable attributes and enable user-checkpoint… (#4868)
parent e6a81d40a5
commit 691c9733f9
6 changed files with 90 additions and 20 deletions
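In short, the change introduces a SHOULD_CHECKPOINT result key that lets a Trainable subclass ask Tune to checkpoint it after a particular training iteration, and documents the logdir/iteration/get_config attributes. Below is a minimal sketch of how a user-defined trainable can use the new key; the class name, metric and the every-fifth-iteration policy are invented, while the _train/_save/_restore structure and the result key follow the hunks further down.

    import os
    import pickle

    from ray.tune import Trainable
    from ray.tune import result as tune_result


    class MyTrainable(Trainable):
        """Hypothetical subclass using user-triggered checkpoints."""

        def _train(self):
            # ... one iteration of real training would go here ...
            out = {"episode_reward_mean": 0.0}
            # Ask Tune to checkpoint this trial after every 5th iteration.
            if self.iteration > 0 and self.iteration % 5 == 0:
                out[tune_result.SHOULD_CHECKPOINT] = True
            return out

        def _save(self, checkpoint_dir):
            path = os.path.join(checkpoint_dir, "state.pkl")
            with open(path, "wb") as f:
                pickle.dump({"iteration": self.iteration}, f)
            return path

        def _restore(self, checkpoint_path):
            with open(checkpoint_path, "rb") as f:
                pickle.load(f)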
@@ -6,6 +6,7 @@ import os
 import pickle
 import numpy as np

+from ray.tune import result as tune_result
 from ray.rllib.agents.trainer import Trainer, with_common_config


@@ -18,6 +19,7 @@ class _MockTrainer(Trainer):
         "persistent_error": False,
         "test_variable": 1,
         "num_workers": 0,
+        "user_checkpoint_freq": 0,
     })

     @classmethod
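The new key defaults to 0, which, per the _train() guard in the next hunk, never requests a checkpoint. The test added further down turns it on by overriding the key in the trial config, along these lines (a sketch mirroring that test; "__fake" is the registered name of this mock trainer):

    from ray.tune.trial import Trial

    # A frequency of 2 requests a checkpoint on every second training iteration.
    trial = Trial("__fake", config={"user_checkpoint_freq": 2})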
@@ -32,11 +34,15 @@ class _MockTrainer(Trainer):
         if self.config["mock_error"] and self.iteration == 1 \
                 and (self.config["persistent_error"] or not self.restored):
             raise Exception("mock error")
-        return dict(
+        result = dict(
             episode_reward_mean=10,
             episode_len_mean=10,
             timesteps_this_iter=10,
             info={})
+        if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
+            if self.iteration % self.config["user_checkpoint_freq"] == 0:
+                result.update({tune_result.SHOULD_CHECKPOINT: True})
+        return result

     def _save(self, checkpoint_dir):
         path = os.path.join(checkpoint_dir, "mock_agent.pkl")
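To make the arithmetic concrete: with user_checkpoint_freq set to 2, the guard above fires on iterations 2, 4, 6, ... (iteration 0 is excluded), so every second result carries the flag. A standalone check of that behaviour:

    # Which iterations set the flag for user_checkpoint_freq == 2.
    freq = 2
    flagged = [i for i in range(0, 7) if freq > 0 and i > 0 and i % freq == 0]
    assert flagged == [2, 4, 6]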
@@ -579,12 +579,6 @@ class Trainer(Trainable):
         else:
             return res[0]  # backwards compatibility

-    @property
-    def iteration(self):
-        """Current training iter, auto-incremented with each train() call."""
-
-        return self._iteration
-
     @property
     def _name(self):
         """Subclasses should override this to declare their name."""
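This removal is not a loss of functionality: the same property is re-added on the base Trainable in the @@ -394,11 +402,45 @@ hunk below, so RLlib's Trainer now simply inherits it. A schematic of the relationship, reduced to the one property (the real classes live in ray.tune.trainable and ray.rllib.agents.trainer):

    class Trainable(object):
        def __init__(self):
            self._iteration = 0

        @property
        def iteration(self):
            """Current training iteration (auto-incremented by train())."""
            return self._iteration


    class Trainer(Trainable):
        # No iteration override needed anymore; the base property is inherited.
        pass


    assert Trainer().iteration == 0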
@@ -9,6 +9,9 @@ import os
 # (Optional/Auto-filled) training is terminated. Filled only if not provided.
 DONE = "done"

+# (Optional) Enum for user controlled checkpoint
+SHOULD_CHECKPOINT = "should_checkpoint"
+
 # (Auto-filled) The hostname of the machine hosting the training process.
 HOSTNAME = "hostname"

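The constant is just a well-known result-dict key; anything producing a training result can set it, as the mock trainer above does. A tiny sketch of a flagged result (import valid once this change is applied; the metric value is illustrative):

    from ray.tune.result import SHOULD_CHECKPOINT

    result = {"episode_reward_mean": 10, SHOULD_CHECKPOINT: True}
    assert SHOULD_CHECKPOINT == "should_checkpoint"
    assert result.get(SHOULD_CHECKPOINT, False)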
@@ -2237,6 +2237,29 @@ class TrialRunnerTest(unittest.TestCase):
         self.assertEquals(count_checkpoints(tmpdir), 2)
         shutil.rmtree(tmpdir)

+    def testUserCheckpoint(self):
+        ray.init(num_cpus=3)
+        tmpdir = tempfile.mkdtemp()
+        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
+        runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2}))
+        trials = runner.get_trials()
+
+        runner.step()
+        self.assertEqual(trials[0].status, Trial.RUNNING)
+        self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
+        runner.step()  # 0
+        self.assertFalse(trials[0].has_checkpoint())
+        runner.step()  # 1
+        self.assertFalse(trials[0].has_checkpoint())
+        runner.step()  # 2
+        self.assertTrue(trials[0].has_checkpoint())
+
+        runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
+        runner2.step()
+        trials2 = runner2.get_trials()
+        self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1)
+        shutil.rmtree(tmpdir)
+

 class SearchAlgorithmTest(unittest.TestCase):
     def testNestedSuggestion(self):
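The inline `# 0`, `# 1`, `# 2` comments track the mock trainer's iteration counter: with `user_checkpoint_freq=2` the first flagged result arrives on iteration 2, so the first two training steps leave no trial checkpoint and the third one does. `checkpoint_period=0` appears to make the runner persist its own experiment state on every step, so the second runner can `resume="LOCAL"`, restore the trial from that user-requested checkpoint, and still see the info value (1) that was set before the checkpoint was taken.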
@@ -46,6 +46,11 @@ class Trainable(object):
     just a ``my_train(config, reporter)`` function to the config.
     The function will be automatically converted to this interface
     (sans checkpoint functionality).
+
+    When using Tune, Tune will convert this class into a Ray actor, which
+    runs on a separate process. Tune will also change the current working
+    directory of this process to `self.logdir`.
+
     """

     def __init__(self, config=None, logger_creator=None):
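Because the working directory is switched to `self.logdir`, relative paths written from inside a Trainable end up next to the trial's result files. A minimal sketch, assuming the class is launched through Tune (the file name and metric are invented):

    from ray.tune import Trainable


    class CwdDemo(Trainable):
        """Hypothetical trainable: relative paths resolve inside self.logdir."""

        def _train(self):
            # Under Tune, this actor's working directory is self.logdir, so
            # the relative path below lands in the trial's result directory.
            with open("notes.txt", "a") as f:
                f.write("iteration {}\n".format(self.iteration))
            return {"lines_written": self.iteration + 1}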
@@ -70,14 +75,15 @@ class Trainable(object):

         if logger_creator:
             self._result_logger = logger_creator(self.config)
-            self.logdir = self._result_logger.logdir
+            self._logdir = self._result_logger.logdir
         else:
             logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
             if not os.path.exists(DEFAULT_RESULTS_DIR):
                 os.makedirs(DEFAULT_RESULTS_DIR)
-            self.logdir = tempfile.mkdtemp(
+            self._logdir = tempfile.mkdtemp(
                 prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
-            self._result_logger = UnifiedLogger(self.config, self.logdir, None)
+            self._result_logger = UnifiedLogger(self.config, self._logdir,
+                                                None)

         self._iteration = 0
         self._time_total = 0.0
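The rename to `_logdir` pairs with the read-only `logdir` property added in the last trainable.py hunk below: callers keep reading `trainable.logdir`, but can no longer rebind it. A standalone illustration of that pattern (not Tune code):

    class Demo(object):
        """Stand-in showing the read-only property pattern used by Trainable."""

        def __init__(self):
            self._logdir = "/tmp/example_logdir"

        @property
        def logdir(self):
            return self._logdir


    d = Demo()
    print(d.logdir)            # reads work exactly as before
    try:
        d.logdir = "/elsewhere"
    except AttributeError:
        print("logdir is now read-only")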
@@ -131,7 +137,8 @@ class Trainable(object):
         across checkpoint / restore calls.

             `training_iteration` (int): The index of this
-                training iteration, e.g. call to train().
+                training iteration, e.g. call to train(). This is incremented
+                after `_train()` is called.

             `pid` (str): The pid of the training process.

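In other words, because the counter is bumped only after `_train()` returns, the first `_train()` call observes `self.iteration == 0` while the result it produces is reported with `training_iteration == 1`. A sketch of that bookkeeping (simplified, not the real train() implementation):

    iteration = 0                     # what _train() sees as self.iteration
    result = {"score": 0.0}           # value returned by _train()
    iteration += 1                    # incremented after _train() returns
    result["training_iteration"] = iteration
    assert result["training_iteration"] == 1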
@@ -219,8 +226,8 @@ class Trainable(object):

     def delete_checkpoint(self, checkpoint_dir):
         """Removes subdirectory within checkpoint_folder
-        Parameters
-        ----------
+
+        Args:
             checkpoint_dir : path to checkpoint
         """
         if os.path.isfile(checkpoint_dir):
@@ -275,8 +282,9 @@ class Trainable(object):
         return checkpoint_path

     def save_to_object(self):
-        """Saves the current model state to a Python object. It also
-        saves to disk but does not return the checkpoint path.
+        """Saves the current model state to a Python object.
+
+        It also saves to disk but does not return the checkpoint path.

         Returns:
             Object holding checkpoint data.
@@ -394,11 +402,45 @@ class Trainable(object):
         self._result_logger.close()
         self._stop()

+    @property
+    def logdir(self):
+        """Directory of the results and checkpoints for this Trainable.
+
+        Tune will automatically sync this folder with the driver if execution
+        is distributed.
+
+        Note that the current working directory will also be changed to this.
+
+        """
+        return self._logdir
+
+    @property
+    def iteration(self):
+        """Current training iteration.
+
+        This value is automatically incremented every time `train()` is called
+        and is automatically inserted into the training result dict.
+
+        """
+        return self._iteration
+
+    def get_config(self):
+        """Returns configuration passed in by Tune."""
+        return self.config
+
     def _train(self):
         """Subclasses should override this to implement train().

+        The return value will be automatically passed to the loggers. Users
+        can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT`
+        to manually trigger termination of this trial or checkpointing of this
+        trial. Note that manual checkpointing only works when subclassing
+        Trainables.
+
         Returns:
-            A dict that describes training progress."""
+            A dict that describes training progress.
+
+        """

         raise NotImplementedError

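Taken together, a subclass can now rely on the documented `self.iteration`, `self.logdir` and `get_config()` accessors and be launched as usual. The sketch below is illustrative only: the class name, metric and stop condition are invented, `tune.run` is the standard entry point of this era of Tune, and `_save`/`_restore` are minimal since no checkpointing is requested here.

    import os

    import ray
    from ray import tune


    class LogdirDemo(tune.Trainable):
        """Hypothetical subclass exercising the documented attributes."""

        def _train(self):
            # All three documented accessors; no private attributes touched.
            with open(os.path.join(self.logdir, "progress.txt"), "a") as f:
                f.write("iter={} config={}\n".format(self.iteration,
                                                     self.get_config()))
            return {"score": self.iteration}

        def _save(self, checkpoint_dir):
            path = os.path.join(checkpoint_dir, "marker")
            open(path, "w").close()
            return path

        def _restore(self, checkpoint_path):
            pass


    if __name__ == "__main__":
        ray.init()
        tune.run(LogdirDemo, stop={"training_iteration": 5})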
@@ -15,7 +15,8 @@ import traceback
 import ray.cloudpickle as cloudpickle
 from ray.tune import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
-from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
+from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
+                             SHOULD_CHECKPOINT)
 from ray.tune.syncer import get_syncer
 from ray.tune.trial import Trial, Checkpoint
 from ray.tune.sample import function
@@ -529,7 +530,8 @@ class TrialRunner(object):
                 # the scheduler decision is STOP or PAUSE. Note that
                 # PAUSE only checkpoints to memory and does not update
                 # the global checkpoint state.
-                self._checkpoint_trial_if_needed(trial)
+                self._checkpoint_trial_if_needed(
+                    trial, force=result.get(SHOULD_CHECKPOINT, False))

                 if decision == TrialScheduler.CONTINUE:
                     self.trial_executor.continue_training(trial)
@@ -554,9 +556,9 @@ class TrialRunner(object):
             self.trial_executor.stop_trial(
                 trial, error=True, error_msg=error_msg)

-    def _checkpoint_trial_if_needed(self, trial):
+    def _checkpoint_trial_if_needed(self, trial, force=False):
         """Checkpoints trial based off trial.last_result."""
-        if trial.should_checkpoint():
+        if trial.should_checkpoint() or force:
             # Save trial runtime if possible
             if hasattr(trial, "runner") and trial.runner:
                 self.trial_executor.save(trial, storage=Checkpoint.DISK)
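Net effect on the runner: a trial is checkpointed to disk when either its own periodic policy (`Trial.should_checkpoint()`, driven by `checkpoint_freq`) fires or the latest result explicitly asked for it. A reduced, self-contained sketch of that condition:

    SHOULD_CHECKPOINT = "should_checkpoint"  # mirrors ray.tune.result


    def needs_disk_checkpoint(periodic_due, result):
        # periodic_due stands in for trial.should_checkpoint(); the second
        # term is the new force flag derived from the result dict.
        return periodic_due or result.get(SHOULD_CHECKPOINT, False)


    assert needs_disk_checkpoint(False, {SHOULD_CHECKPOINT: True})
    assert needs_disk_checkpoint(True, {})
    assert not needs_disk_checkpoint(False, {"episode_reward_mean": 10})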