[tune] Document trainable attributes and enable user-checkpoint… (#4868)

Richard Liaw 2019-07-10 18:51:11 -07:00 committed by GitHub
parent e6a81d40a5
commit 691c9733f9
6 changed files with 90 additions and 20 deletions


@@ -6,6 +6,7 @@ import os
 import pickle
 import numpy as np

+from ray.tune import result as tune_result
 from ray.rllib.agents.trainer import Trainer, with_common_config
@@ -18,6 +19,7 @@ class _MockTrainer(Trainer):
         "persistent_error": False,
         "test_variable": 1,
         "num_workers": 0,
+        "user_checkpoint_freq": 0,
     })

     @classmethod
@@ -32,11 +34,15 @@ class _MockTrainer(Trainer):
         if self.config["mock_error"] and self.iteration == 1 \
                 and (self.config["persistent_error"] or not self.restored):
             raise Exception("mock error")
-        return dict(
+        result = dict(
             episode_reward_mean=10,
             episode_len_mean=10,
             timesteps_this_iter=10,
             info={})
+        if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
+            if self.iteration % self.config["user_checkpoint_freq"] == 0:
+                result.update({tune_result.SHOULD_CHECKPOINT: True})
+        return result

     def _save(self, checkpoint_dir):
         path = os.path.join(checkpoint_dir, "mock_agent.pkl")
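Note: the `_MockTrainer` change above shows the pattern end to end: a trainer builds its result dict and, when it wants a checkpoint, sets `tune_result.SHOULD_CHECKPOINT` in it. A minimal sketch of the same pattern in a user-defined Trainable follows; the class name `MyTrainable`, the `save_every` config key, and the file name `state.pkl` are illustrative, not part of this PR.

import os
import pickle

from ray.tune import Trainable
from ray.tune import result as tune_result


class MyTrainable(Trainable):
    """Illustrative Trainable that requests its own checkpoints."""

    def _setup(self, config):
        self.score = 0

    def _train(self):
        self.score += 1
        result = {"episode_reward_mean": self.score}
        save_every = self.config.get("save_every", 0)
        # Mirror _MockTrainer: flag the result so Tune checkpoints this trial.
        if save_every and self.iteration > 0 and self.iteration % save_every == 0:
            result[tune_result.SHOULD_CHECKPOINT] = True
        return result

    def _save(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "state.pkl")
        with open(path, "wb") as f:
            pickle.dump(self.score, f)
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            self.score = pickle.load(f)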


@@ -579,12 +579,6 @@ class Trainer(Trainable):
         else:
             return res[0]  # backwards compatibility

-    @property
-    def iteration(self):
-        """Current training iter, auto-incremented with each train() call."""
-        return self._iteration
-
     @property
     def _name(self):
         """Subclasses should override this to declare their name."""


@@ -9,6 +9,9 @@ import os
 # (Optional/Auto-filled) training is terminated. Filled only if not provided.
 DONE = "done"

+# (Optional) Enum for user controlled checkpoint
+SHOULD_CHECKPOINT = "should_checkpoint"
+
 # (Auto-filled) The hostname of the machine hosting the training process.
 HOSTNAME = "hostname"


@@ -2237,6 +2237,29 @@ class TrialRunnerTest(unittest.TestCase):
         self.assertEquals(count_checkpoints(tmpdir), 2)
         shutil.rmtree(tmpdir)

+    def testUserCheckpoint(self):
+        ray.init(num_cpus=3)
+        tmpdir = tempfile.mkdtemp()
+        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
+        runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2}))
+        trials = runner.get_trials()
+
+        runner.step()
+        self.assertEqual(trials[0].status, Trial.RUNNING)
+        self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
+        runner.step()  # 0
+        self.assertFalse(trials[0].has_checkpoint())
+        runner.step()  # 1
+        self.assertFalse(trials[0].has_checkpoint())
+        runner.step()  # 2
+        self.assertTrue(trials[0].has_checkpoint())
+
+        runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
+        runner2.step()
+        trials2 = runner2.get_trials()
+        self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1)
+        shutil.rmtree(tmpdir)
+

 class SearchAlgorithmTest(unittest.TestCase):
     def testNestedSuggestion(self):


@@ -46,6 +46,11 @@ class Trainable(object):
     just a ``my_train(config, reporter)`` function to the config.
     The function will be automatically converted to this interface
     (sans checkpoint functionality).
+
+    When using Tune, Tune will convert this class into a Ray actor, which
+    runs on a separate process. Tune will also change the current working
+    directory of this process to `self.logdir`.
     """

     def __init__(self, config=None, logger_creator=None):
@@ -70,14 +75,15 @@
         if logger_creator:
             self._result_logger = logger_creator(self.config)
-            self.logdir = self._result_logger.logdir
+            self._logdir = self._result_logger.logdir
         else:
             logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
             if not os.path.exists(DEFAULT_RESULTS_DIR):
                 os.makedirs(DEFAULT_RESULTS_DIR)
-            self.logdir = tempfile.mkdtemp(
+            self._logdir = tempfile.mkdtemp(
                 prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
-            self._result_logger = UnifiedLogger(self.config, self.logdir, None)
+            self._result_logger = UnifiedLogger(self.config, self._logdir,
+                                                None)

         self._iteration = 0
         self._time_total = 0.0
@@ -131,7 +137,8 @@
                 across checkpoint / restore calls.

             `training_iteration` (int): The index of this
-                training iteration, e.g. call to train().
+                training iteration, e.g. call to train(). This is incremented
+                after `_train()` is called.

             `pid` (str): The pid of the training process.
@@ -219,8 +226,8 @@
     def delete_checkpoint(self, checkpoint_dir):
         """Removes subdirectory within checkpoint_folder

-        Parameters
-        ----------
+        Args:
             checkpoint_dir : path to checkpoint
         """
         if os.path.isfile(checkpoint_dir):
@@ -275,8 +282,9 @@
         return checkpoint_path

     def save_to_object(self):
-        """Saves the current model state to a Python object. It also
-        saves to disk but does not return the checkpoint path.
+        """Saves the current model state to a Python object.
+
+        It also saves to disk but does not return the checkpoint path.

         Returns:
             Object holding checkpoint data.
@@ -394,11 +402,45 @@
         self._result_logger.close()
         self._stop()

+    @property
+    def logdir(self):
+        """Directory of the results and checkpoints for this Trainable.
+
+        Tune will automatically sync this folder with the driver if execution
+        is distributed.
+
+        Note that the current working directory will also be changed to this.
+        """
+        return self._logdir
+
+    @property
+    def iteration(self):
+        """Current training iteration.
+
+        This value is automatically incremented every time `train()` is called
+        and is automatically inserted into the training result dict.
+        """
+        return self._iteration
+
+    def get_config(self):
+        """Returns configuration passed in by Tune."""
+        return self.config
+
     def _train(self):
         """Subclasses should override this to implement train().

+        The return value will be automatically passed to the loggers. Users
+        can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT`
+        to manually trigger termination of this trial or checkpointing of this
+        trial. Note that manual checkpointing only works when subclassing
+        Trainables.
+
         Returns:
-            A dict that describes training progress."""
+            A dict that describes training progress.
+        """

         raise NotImplementedError
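Note: to make the newly documented attributes concrete, here is a hedged sketch of a subclass that reads `self.get_config()` and `self.iteration` in `_train()` and writes per-trial output under `self.logdir`; the class name `LogdirExample` and the file name `progress_note.txt` are illustrative only, and checkpointing methods are omitted for brevity.

import os

from ray.tune import Trainable


class LogdirExample(Trainable):
    def _train(self):
        lr = self.get_config().get("lr", 0.01)
        # Tune changes the working directory to `self.logdir`; writing under
        # it keeps per-trial output in the folder Tune syncs to the driver.
        with open(os.path.join(self.logdir, "progress_note.txt"), "a") as f:
            f.write("finished iteration %d\n" % self.iteration)
        return {"mean_loss": lr / (self.iteration + 1)}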


@@ -15,7 +15,8 @@ import traceback
 import ray.cloudpickle as cloudpickle
 from ray.tune import TuneError
 from ray.tune.ray_trial_executor import RayTrialExecutor
-from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
+from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
+                             SHOULD_CHECKPOINT)
 from ray.tune.syncer import get_syncer
 from ray.tune.trial import Trial, Checkpoint
 from ray.tune.sample import function
@@ -529,7 +530,8 @@ class TrialRunner(object):
                 # the scheduler decision is STOP or PAUSE. Note that
                 # PAUSE only checkpoints to memory and does not update
                 # the global checkpoint state.
-                self._checkpoint_trial_if_needed(trial)
+                self._checkpoint_trial_if_needed(
+                    trial, force=result.get(SHOULD_CHECKPOINT, False))

                 if decision == TrialScheduler.CONTINUE:
                     self.trial_executor.continue_training(trial)
@@ -554,9 +556,9 @@
             self.trial_executor.stop_trial(
                 trial, error=True, error_msg=error_msg)

-    def _checkpoint_trial_if_needed(self, trial):
+    def _checkpoint_trial_if_needed(self, trial, force=False):
         """Checkpoints trial based off trial.last_result."""
-        if trial.should_checkpoint():
+        if trial.should_checkpoint() or force:
             # Save trial runtime if possible
             if hasattr(trial, "runner") and trial.runner:
                 self.trial_executor.save(trial, storage=Checkpoint.DISK)
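Note: with `force` wired through `_checkpoint_trial_if_needed`, a checkpoint is saved whenever a result carries `SHOULD_CHECKPOINT`, even when no periodic checkpointing is configured. A hedged usage sketch, assuming the illustrative `MyTrainable` from the first example above and that `tune.run` accepts `checkpoint_freq` and `stop` as it did in this release:

import ray
from ray import tune

ray.init()

# checkpoint_freq=0 disables periodic checkpoints, so any checkpoint written
# here was requested by the trainable via SHOULD_CHECKPOINT in its result.
tune.run(
    MyTrainable,
    config={"save_every": 2},
    checkpoint_freq=0,
    stop={"training_iteration": 10},
)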