[tune] custom trial directory name (#10214)

Richard Liaw 2020-08-25 12:52:54 -07:00 committed by GitHub
parent 24a7a8a04d
commit 146d91385c
8 changed files with 91 additions and 21 deletions
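
In short: tune.run (and Experiment) now accept a trial_dirname_creator callable that maps a Trial to the name of its log directory under local_dir. A minimal usage sketch, adapted from the test added in this commit (the train function and the "t1" search space are illustrative, not part of the API):

    from ray import tune

    def train(config):
        for i in range(10):
            tune.report(test=i)

    def trial_dirname_creator(trial):
        # Must return a directory *name*, not a path (os.path.sep is rejected).
        return f"trial_{trial.trial_id}"

    tune.run(
        train,
        config={"t1": tune.grid_search([1, 2, 3])},
        trial_dirname_creator=trial_dirname_creator)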

View file

@@ -189,6 +189,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # str(None) doesn't create None
         restore_path=spec.get("restore"),
         trial_name_creator=spec.get("trial_name_creator"),
+        trial_dirname_creator=spec.get("trial_dirname_creator"),
         loggers=spec.get("loggers"),
         log_to_file=spec.get("log_to_file"),
         # str(None) doesn't create None

View file

@@ -108,6 +108,7 @@ class Experiment:
                  local_dir=None,
                  upload_dir=None,
                  trial_name_creator=None,
+                 trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=False,
                  sync_to_driver=None,
@@ -173,6 +174,7 @@ class Experiment:
             "upload_dir": upload_dir,
             "remote_checkpoint_dir": self.remote_checkpoint_dir,
             "trial_name_creator": trial_name_creator,
+            "trial_dirname_creator": trial_dirname_creator,
             "loggers": loggers,
             "log_to_file": (stdout_file, stderr_file),
             "sync_to_driver": sync_to_driver,

View file

@@ -1,9 +1,9 @@
-import pickle
 import os
 import copy
 import logging
 import glob

+import ray.cloudpickle as cloudpickle
 from ray.tune.error import TuneError
 from ray.tune.experiment import convert_to_experiment_list
 from ray.tune.config_parser import make_parser, create_trial_from_spec
@@ -29,7 +29,7 @@ def _atomic_save(state, checkpoint_dir, file_name):
     tmp_search_ckpt_path = os.path.join(checkpoint_dir,
                                         ".tmp_search_generator_ckpt")
     with open(tmp_search_ckpt_path, "wb") as f:
-        pickle.dump(state, f)
+        cloudpickle.dump(state, f)

     os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
@@ -41,7 +41,7 @@ def _find_newest_ckpt(dirpath, pattern):
         return
     most_recent_checkpoint = max(full_paths)
     with open(most_recent_checkpoint, "rb") as f:
-        search_alg_state = pickle.load(f)
+        search_alg_state = cloudpickle.load(f)
     return search_alg_state

View file

@@ -33,10 +33,12 @@ from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
 class TrainableFunctionApiTest(unittest.TestCase):
     def setUp(self):
         ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
+        self.tmpdir = tempfile.mkdtemp()

     def tearDown(self):
         ray.shutdown()
         _register_all()  # re-register the evicted objects
+        shutil.rmtree(self.tmpdir)

     def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None):
         """Checks logging is the same between APIs.
@@ -547,6 +549,44 @@ class TrainableFunctionApiTest(unittest.TestCase):
         with self.assertRaises(TuneError):
             tune.run(train, stop=stop)

+    def testCustomTrialDir(self):
+        def train(config):
+            for i in range(10):
+                tune.report(test=i)
+
+        custom_name = "TRAIL_TRIAL"
+
+        def custom_trial_dir(trial):
+            return custom_name
+
+        trials = tune.run(
+            train,
+            config={
+                "t1": tune.grid_search([1, 2, 3])
+            },
+            trial_dirname_creator=custom_trial_dir,
+            local_dir=self.tmpdir).trials
+        logdirs = {t.logdir for t in trials}
+        assert len(logdirs) == 3
+        assert all(custom_name in dirpath for dirpath in logdirs)
+
+    def testTrialDirRegression(self):
+        def train(config, reporter):
+            for i in range(10):
+                reporter(test=i)
+
+        trials = tune.run(
+            train,
+            config={
+                "t1": tune.grid_search([1, 2, 3])
+            },
+            local_dir=self.tmpdir).trials
+        logdirs = {t.logdir for t in trials}
+        for i in [1, 2, 3]:
+            assert any(f"t1={i}" in dirpath for dirpath in logdirs)
+        for t in trials:
+            assert any(t.trainable_name in dirpath for dirpath in logdirs)
+
     def testEarlyReturn(self):
         def train(config, reporter):
             reporter(timesteps_total=100, done=True)

View file

@@ -700,6 +700,7 @@ class _MockTrial(Trial):
         self.restored_checkpoint = None
         self.resources = Resources(1, 0)
         self.custom_trial_name = None
+        self.custom_dirname = None


 class PopulationBasedTestingSuite(unittest.TestCase):

View file

@@ -9,7 +9,6 @@ import platform
 import shutil
 import uuid
 import time
-import tempfile
 import os
 from numbers import Number

 from ray.tune import TuneError
@@ -129,6 +128,19 @@ class TrialInfo:
         return self._trial_id


+def create_logdir(dirname, local_dir):
+    local_dir = os.path.expanduser(local_dir)
+    logdir = os.path.join(local_dir, dirname)
+    if os.path.exists(logdir):
+        old_dirname = dirname
+        dirname += "_" + uuid.uuid4().hex[:4]
+        logger.info(f"Creating a new dirname {dirname} because "
+                    f"trial dirname '{old_dirname}' already exists.")
+        logdir = os.path.join(local_dir, dirname)
+    os.makedirs(logdir, exist_ok=True)
+    return logdir
+
+
 class Trial:
     """A trial object holds the state for one model training run.
@@ -176,6 +188,7 @@ class Trial:
                  export_formats=None,
                  restore_path=None,
                  trial_name_creator=None,
+                 trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=None,
                  sync_to_driver_fn=None,
@@ -245,6 +258,7 @@ class Trial:
         self.error_msg = None
         self.trial_name_creator = trial_name_creator
         self.custom_trial_name = None
+        self.custom_dirname = None

         # Checkpointing fields
         self.saving_to = None
@@ -283,6 +297,12 @@ class Trial:
         if trial_name_creator:
             self.custom_trial_name = trial_name_creator(self)

+        if trial_dirname_creator:
+            self.custom_dirname = trial_dirname_creator(self)
+            if os.path.sep in self.custom_dirname:
+                raise ValueError(f"Trial dirname must not contain '/'. "
+                                 f"Got {self.custom_dirname}")
+
     @property
     def node_ip(self):
         return self.location.hostname
@@ -314,14 +334,6 @@ class Trial:
         logdir_name = os.path.basename(self.logdir)
         return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)

-    @classmethod
-    def create_logdir(cls, identifier, local_dir):
-        local_dir = os.path.expanduser(local_dir)
-        os.makedirs(local_dir, exist_ok=True)
-        return tempfile.mkdtemp(
-            prefix="{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str()),
-            dir=local_dir)
-
     def reset(self):
         return Trial(
             self.trainable_name,
@@ -351,9 +363,8 @@ class Trial:
         """Init logger."""
         if not self.result_logger:
             if not self.logdir:
-                self.logdir = Trial.create_logdir(
-                    self._trainable_name() + "_" + self.experiment_tag,
-                    self.local_dir)
+                self.logdir = create_logdir(self._generate_dirname(),
+                                            self.local_dir)
             else:
                 os.makedirs(self.logdir, exist_ok=True)
@@ -592,6 +603,15 @@ class Trial:
             identifier += "_" + self.trial_id
         return identifier.replace("/", "_")

+    def _generate_dirname(self):
+        if self.custom_dirname:
+            generated_dirname = self.custom_dirname
+        else:
+            generated_dirname = f"{self.trainable_name}_{self.experiment_tag}"
+            generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
+            generated_dirname += f"_{date_str()}{uuid.uuid4().hex[:8]}"
+        return generated_dirname.replace("/", "_")
+
     def __getstate__(self):
         """Memento generator for Trial.

View file

@@ -384,8 +384,8 @@ class TrialRunner:
             try:
                 with warn_if_slow("experiment_checkpoint"):
                     self.checkpoint()
-            except Exception:
-                logger.exception("Trial Runner checkpointing failed.")
+            except Exception as e:
+                logger.warning(f"Trial Runner checkpointing failed: {str(e)}")

         self._iteration += 1
         if self._server:

View file

@@ -74,6 +74,7 @@ def run(run_or_experiment,
         local_dir=None,
         upload_dir=None,
         trial_name_creator=None,
+        trial_dirname_creator=None,
         loggers=None,
         log_to_file=False,
         sync_to_cloud=None,
@@ -166,8 +167,12 @@ def run(run_or_experiment,
             Defaults to ``~/ray_results``.
         upload_dir (str): Optional URI to sync training results and checkpoints
             to (e.g. ``s3://bucket`` or ``gs://bucket``).
-        trial_name_creator (func): Optional function for generating
-            the trial string representation.
+        trial_name_creator (Callable[[Trial], str]): Optional function
+            for generating the trial string representation.
+        trial_dirname_creator (Callable[[Trial], str]): Function
+            for generating the trial dirname. This function should take
+            in a Trial object and return a string representing the
+            name of the directory. The return value cannot be a path.
         loggers (list): List of logger creators to be used with
             each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
             See `ray/tune/logger.py`.
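
Since the return value cannot be a path, a dirname creator that encodes hyperparameters should strip separators. A sketch under the assumption that the trial's config contains an "lr" key (illustrative, not part of this commit):

    import os

    def trial_dirname_creator(trial):
        # Trial exposes trainable_name, trial_id, and config (see trial.py above).
        lr = trial.config.get("lr", "default")
        return f"{trial.trainable_name}_lr={lr}_{trial.trial_id}".replace(
            os.path.sep, "_")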
@@ -295,6 +300,7 @@ def run(run_or_experiment,
         upload_dir=upload_dir,
         sync_to_driver=sync_to_driver,
         trial_name_creator=trial_name_creator,
+        trial_dirname_creator=trial_dirname_creator,
         loggers=loggers,
         log_to_file=log_to_file,
         checkpoint_freq=checkpoint_freq,
@@ -375,8 +381,8 @@ def run(run_or_experiment,
     try:
         runner.checkpoint(force=True)
-    except Exception:
-        logger.exception("Trial Runner checkpointing failed.")
+    except Exception as e:
+        logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
     if verbose:
         _report_progress(runner, progress_reporter, done=True)