[tune] Document trainable attributes and enable user-checkpoint… (#4868)
parent e6a81d40a5
commit 691c9733f9
6 changed files with 90 additions and 20 deletions
@@ -6,6 +6,7 @@ import os
 import pickle
 
 import numpy as np
 
+from ray.tune import result as tune_result
 from ray.rllib.agents.trainer import Trainer, with_common_config
@@ -18,6 +19,7 @@ class _MockTrainer(Trainer):
         "persistent_error": False,
         "test_variable": 1,
         "num_workers": 0,
+        "user_checkpoint_freq": 0,
     })
 
     @classmethod
@@ -32,11 +34,15 @@ class _MockTrainer(Trainer):
         if self.config["mock_error"] and self.iteration == 1 \
                 and (self.config["persistent_error"] or not self.restored):
             raise Exception("mock error")
-        return dict(
+        result = dict(
             episode_reward_mean=10,
             episode_len_mean=10,
             timesteps_this_iter=10,
             info={})
+        if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
+            if self.iteration % self.config["user_checkpoint_freq"] == 0:
+                result.update({tune_result.SHOULD_CHECKPOINT: True})
+        return result
 
     def _save(self, checkpoint_dir):
         path = os.path.join(checkpoint_dir, "mock_agent.pkl")
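
The mock trainer above exercises the new user-checkpoint flag end to end. A user-defined Trainable would follow the same pattern; below is a minimal sketch (not code from this commit) that assumes the usual `_save`/`_restore` hook signatures and uses made-up class, metric, and file names:

import os
import pickle

from ray.tune import Trainable
from ray.tune import result as tune_result


class MyTrainable(Trainable):
    """Illustrative only; mirrors the _MockTrainer change in this diff."""

    def _train(self):
        result = dict(mean_accuracy=0.5, timesteps_this_iter=1)
        # Ask Tune to checkpoint this trial every 5th iteration. Manual
        # checkpointing like this only works for class-based Trainables.
        if self.iteration > 0 and self.iteration % 5 == 0:
            result[tune_result.SHOULD_CHECKPOINT] = True
        return result

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "state.pkl")
        with open(path, "wb") as f:
            pickle.dump({"iteration": self.iteration}, f)
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            self._restored_state = pickle.load(f)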
@@ -579,12 +579,6 @@ class Trainer(Trainable):
         else:
             return res[0]  # backwards compatibility
 
-    @property
-    def iteration(self):
-        """Current training iter, auto-incremented with each train() call."""
-        return self._iteration
-
     @property
     def _name(self):
         """Subclasses should override this to declare their name."""
@@ -9,6 +9,9 @@ import os
 # (Optional/Auto-filled) training is terminated. Filled only if not provided.
 DONE = "done"
 
+# (Optional) Enum for user controlled checkpoint
+SHOULD_CHECKPOINT = "should_checkpoint"
+
 # (Auto-filled) The hostname of the machine hosting the training process.
 HOSTNAME = "hostname"
@@ -2237,6 +2237,29 @@ class TrialRunnerTest(unittest.TestCase):
         self.assertEquals(count_checkpoints(tmpdir), 2)
         shutil.rmtree(tmpdir)
 
+    def testUserCheckpoint(self):
+        ray.init(num_cpus=3)
+        tmpdir = tempfile.mkdtemp()
+        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
+        runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2}))
+        trials = runner.get_trials()
+
+        runner.step()
+        self.assertEqual(trials[0].status, Trial.RUNNING)
+        self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
+        runner.step()  # 0
+        self.assertFalse(trials[0].has_checkpoint())
+        runner.step()  # 1
+        self.assertFalse(trials[0].has_checkpoint())
+        runner.step()  # 2
+        self.assertTrue(trials[0].has_checkpoint())
+
+        runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
+        runner2.step()
+        trials2 = runner2.get_trials()
+        self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1)
+        shutil.rmtree(tmpdir)
+
 
 class SearchAlgorithmTest(unittest.TestCase):
     def testNestedSuggestion(self):
@@ -46,6 +46,11 @@ class Trainable(object):
     just a ``my_train(config, reporter)`` function to the config.
     The function will be automatically converted to this interface
     (sans checkpoint functionality).
+
+    When using Tune, Tune will convert this class into a Ray actor, which
+    runs on a separate process. Tune will also change the current working
+    directory of this process to `self.logdir`.
+
     """
 
     def __init__(self, config=None, logger_creator=None):
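
The docstring added above notes that Tune runs each Trainable as a Ray actor in a separate process and changes that process's working directory to `self.logdir`. A small sketch of what that implies for relative paths; the class and file name are hypothetical, and the behavior only applies when the class is launched through Tune:

import os

from ray.tune import Trainable


class CwdDemo(Trainable):
    """Illustrative only; relies on the behavior documented above."""

    def _train(self):
        # When launched by Tune, the working directory is self.logdir, so a
        # relative path like this lands inside the trial's log directory.
        with open("notes.txt", "a") as f:
            f.write("cwd={} logdir={}\n".format(os.getcwd(), self.logdir))
        return {"timesteps_this_iter": 1}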
@@ -70,14 +75,15 @@ class Trainable(object):
 
         if logger_creator:
             self._result_logger = logger_creator(self.config)
-            self.logdir = self._result_logger.logdir
+            self._logdir = self._result_logger.logdir
         else:
             logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
             if not os.path.exists(DEFAULT_RESULTS_DIR):
                 os.makedirs(DEFAULT_RESULTS_DIR)
-            self.logdir = tempfile.mkdtemp(
+            self._logdir = tempfile.mkdtemp(
                 prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
-            self._result_logger = UnifiedLogger(self.config, self.logdir, None)
+            self._result_logger = UnifiedLogger(self.config, self._logdir,
+                                                None)
 
         self._iteration = 0
         self._time_total = 0.0
@@ -131,7 +137,8 @@ class Trainable(object):
             across checkpoint / restore calls.
 
             `training_iteration` (int): The index of this
-                training iteration, e.g. call to train().
+                training iteration, e.g. call to train(). This is incremented
+                after `_train()` is called.
 
             `pid` (str): The pid of the training process.
@@ -219,8 +226,8 @@ class Trainable(object):
 
     def delete_checkpoint(self, checkpoint_dir):
         """Removes subdirectory within checkpoint_folder
-        Parameters
-        ----------
+
+        Args:
             checkpoint_dir : path to checkpoint
         """
         if os.path.isfile(checkpoint_dir):
@@ -275,8 +282,9 @@ class Trainable(object):
         return checkpoint_path
 
     def save_to_object(self):
-        """Saves the current model state to a Python object. It also
-        saves to disk but does not return the checkpoint path.
+        """Saves the current model state to a Python object.
+
+        It also saves to disk but does not return the checkpoint path.
 
         Returns:
             Object holding checkpoint data.
@@ -394,11 +402,45 @@ class Trainable(object):
         self._result_logger.close()
         self._stop()
 
+    @property
+    def logdir(self):
+        """Directory of the results and checkpoints for this Trainable.
+
+        Tune will automatically sync this folder with the driver if execution
+        is distributed.
+
+        Note that the current working directory will also be changed to this.
+
+        """
+        return self._logdir
+
+    @property
+    def iteration(self):
+        """Current training iteration.
+
+        This value is automatically incremented every time `train()` is called
+        and is automatically inserted into the training result dict.
+
+        """
+        return self._iteration
+
+    def get_config(self):
+        """Returns configuration passed in by Tune."""
+        return self.config
+
     def _train(self):
         """Subclasses should override this to implement train().
 
+        The return value will be automatically passed to the loggers. Users
+        can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT`
+        to manually trigger termination of this trial or checkpointing of this
+        trial. Note that manual checkpointing only works when subclassing
+        Trainables.
+
         Returns:
-            A dict that describes training progress."""
+            A dict that describes training progress.
+
+        """
 
         raise NotImplementedError
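
The expanded `_train()` docstring above says the returned dict may carry `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT` to stop or checkpoint the trial (class-based Trainables only). A brief sketch of the early-termination side of that contract, using a placeholder metric and a hypothetical `max_iters` config key:

from ray.tune import Trainable
from ray.tune import result as tune_result


class EarlyStopper(Trainable):
    """Illustrative only; exercises the result keys documented above."""

    def _train(self):
        loss = 1.0 / (self.iteration + 1)  # placeholder metric
        out = {"mean_loss": loss}
        # Setting DONE in the result asks Tune to terminate this trial.
        if self.iteration + 1 >= self.get_config().get("max_iters", 10):
            out[tune_result.DONE] = True
        return out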
@@ -15,7 +15,8 @@ import traceback
 import ray.cloudpickle as cloudpickle
 from ray.tune import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
-from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
+from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
+                             SHOULD_CHECKPOINT)
 from ray.tune.syncer import get_syncer
 from ray.tune.trial import Trial, Checkpoint
 from ray.tune.sample import function
@@ -529,7 +530,8 @@ class TrialRunner(object):
                 # the scheduler decision is STOP or PAUSE. Note that
                 # PAUSE only checkpoints to memory and does not update
                 # the global checkpoint state.
-                self._checkpoint_trial_if_needed(trial)
+                self._checkpoint_trial_if_needed(
+                    trial, force=result.get(SHOULD_CHECKPOINT, False))
 
                 if decision == TrialScheduler.CONTINUE:
                     self.trial_executor.continue_training(trial)
@@ -554,9 +556,9 @@
             self.trial_executor.stop_trial(
                 trial, error=True, error_msg=error_msg)
 
-    def _checkpoint_trial_if_needed(self, trial):
+    def _checkpoint_trial_if_needed(self, trial, force=False):
         """Checkpoints trial based off trial.last_result."""
-        if trial.should_checkpoint():
+        if trial.should_checkpoint() or force:
             # Save trial runtime if possible
             if hasattr(trial, "runner") and trial.runner:
                 self.trial_executor.save(trial, storage=Checkpoint.DISK)
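
On the driver side, the new `force` argument lets a trial be checkpointed even when `trial.should_checkpoint()` alone would say no. A tiny standalone check of how the flag is read out of a result dict, mirroring the `result.get(...)` call above; `needs_forced_checkpoint` is an illustrative helper, not part of the Ray API:

from ray.tune.result import SHOULD_CHECKPOINT


def needs_forced_checkpoint(result):
    """Mirrors result.get(SHOULD_CHECKPOINT, False) in the hunk above."""
    return bool(result.get(SHOULD_CHECKPOINT, False))


assert needs_forced_checkpoint({"episode_reward_mean": 10, SHOULD_CHECKPOINT: True})
assert not needs_forced_checkpoint({"episode_reward_mean": 10})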