[tune] custom trial directory name (#10214)

Richard Liaw 2020-08-25 12:52:54 -07:00 committed by GitHub
parent 24a7a8a04d
commit 146d91385c
8 changed files with 91 additions and 21 deletions

View file

@@ -189,6 +189,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # str(None) doesn't create None
         restore_path=spec.get("restore"),
         trial_name_creator=spec.get("trial_name_creator"),
+        trial_dirname_creator=spec.get("trial_dirname_creator"),
         loggers=spec.get("loggers"),
         log_to_file=spec.get("log_to_file"),
         # str(None) doesn't create None

View file

@@ -108,6 +108,7 @@ class Experiment:
                  local_dir=None,
                  upload_dir=None,
                  trial_name_creator=None,
+                 trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=False,
                  sync_to_driver=None,
@@ -173,6 +174,7 @@ class Experiment:
             "upload_dir": upload_dir,
             "remote_checkpoint_dir": self.remote_checkpoint_dir,
             "trial_name_creator": trial_name_creator,
+            "trial_dirname_creator": trial_dirname_creator,
             "loggers": loggers,
             "log_to_file": (stdout_file, stderr_file),
             "sync_to_driver": sync_to_driver,

View file

@@ -1,9 +1,9 @@
-import pickle
 import os
 import copy
 import logging
 import glob
+import ray.cloudpickle as cloudpickle

 from ray.tune.error import TuneError
 from ray.tune.experiment import convert_to_experiment_list
 from ray.tune.config_parser import make_parser, create_trial_from_spec
@@ -29,7 +29,7 @@ def _atomic_save(state, checkpoint_dir, file_name):
     tmp_search_ckpt_path = os.path.join(checkpoint_dir,
                                         ".tmp_search_generator_ckpt")
     with open(tmp_search_ckpt_path, "wb") as f:
-        pickle.dump(state, f)
+        cloudpickle.dump(state, f)
     os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
@@ -41,7 +41,7 @@ def _find_newest_ckpt(dirpath, pattern):
         return
     most_recent_checkpoint = max(full_paths)
     with open(most_recent_checkpoint, "rb") as f:
-        search_alg_state = pickle.load(f)
+        search_alg_state = cloudpickle.load(f)
     return search_alg_state
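
The pickle-to-cloudpickle switch likely matters because checkpointed searcher state can now carry the user's creator callables, and those are often lambdas or closures that the stdlib pickler rejects. A minimal sketch of the difference, using the standalone cloudpickle package as a stand-in for the vendored ray.cloudpickle:

import pickle

import cloudpickle  # standalone stand-in for ray.cloudpickle

dirname_creator = lambda trial: "custom_dir"

try:
    pickle.dumps(dirname_creator)
except (pickle.PicklingError, AttributeError) as e:
    # stdlib pickle cannot serialize lambdas defined in __main__
    print(f"stdlib pickle rejects lambdas: {e}")

blob = cloudpickle.dumps(dirname_creator)  # cloudpickle handles lambdas
assert cloudpickle.loads(blob)(None) == "custom_dir"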

View file

@@ -33,10 +33,12 @@ from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
 class TrainableFunctionApiTest(unittest.TestCase):
     def setUp(self):
         ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
+        self.tmpdir = tempfile.mkdtemp()

     def tearDown(self):
         ray.shutdown()
         _register_all()  # re-register the evicted objects
+        shutil.rmtree(self.tmpdir)

     def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None):
         """Checks logging is the same between APIs.
@@ -547,6 +549,44 @@ class TrainableFunctionApiTest(unittest.TestCase):
         with self.assertRaises(TuneError):
             tune.run(train, stop=stop)

+    def testCustomTrialDir(self):
+        def train(config):
+            for i in range(10):
+                tune.report(test=i)
+
+        custom_name = "TRAIL_TRIAL"
+
+        def custom_trial_dir(trial):
+            return custom_name
+
+        trials = tune.run(
+            train,
+            config={
+                "t1": tune.grid_search([1, 2, 3])
+            },
+            trial_dirname_creator=custom_trial_dir,
+            local_dir=self.tmpdir).trials
+        logdirs = {t.logdir for t in trials}
+        assert len(logdirs) == 3
+        assert all(custom_name in dirpath for dirpath in logdirs)
+
+    def testTrialDirRegression(self):
+        def train(config, reporter):
+            for i in range(10):
+                reporter(test=i)
+
+        trials = tune.run(
+            train,
+            config={
+                "t1": tune.grid_search([1, 2, 3])
+            },
+            local_dir=self.tmpdir).trials
+        logdirs = {t.logdir for t in trials}
+        for i in [1, 2, 3]:
+            assert any(f"t1={i}" in dirpath for dirpath in logdirs)
+        for t in trials:
+            assert any(t.trainable_name in dirpath for dirpath in logdirs)
+
     def testEarlyReturn(self):
         def train(config, reporter):
             reporter(timesteps_total=100, done=True)

View file

@@ -700,6 +700,7 @@ class _MockTrial(Trial):
         self.restored_checkpoint = None
         self.resources = Resources(1, 0)
         self.custom_trial_name = None
+        self.custom_dirname = None

 class PopulationBasedTestingSuite(unittest.TestCase):

View file

@@ -9,7 +9,6 @@ import platform
 import shutil
 import uuid
 import time
-import tempfile
 import os
 from numbers import Number

 from ray.tune import TuneError
@@ -129,6 +128,19 @@ class TrialInfo:
         return self._trial_id

+def create_logdir(dirname, local_dir):
+    local_dir = os.path.expanduser(local_dir)
+    logdir = os.path.join(local_dir, dirname)
+    if os.path.exists(logdir):
+        old_dirname = dirname
+        dirname += "_" + uuid.uuid4().hex[:4]
+        logger.info(f"Creating a new dirname {dirname} because "
+                    f"trial dirname '{old_dirname}' already exists.")
+        logdir = os.path.join(local_dir, dirname)
+    os.makedirs(logdir, exist_ok=True)
+    return logdir
+
 class Trial:
     """A trial object holds the state for one model training run.
@@ -176,6 +188,7 @@ class Trial:
                  export_formats=None,
                  restore_path=None,
                  trial_name_creator=None,
+                 trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=None,
                  sync_to_driver_fn=None,
@@ -245,6 +258,7 @@ class Trial:
         self.error_msg = None
         self.trial_name_creator = trial_name_creator
         self.custom_trial_name = None
+        self.custom_dirname = None

         # Checkpointing fields
         self.saving_to = None
@@ -283,6 +297,12 @@ class Trial:
         if trial_name_creator:
             self.custom_trial_name = trial_name_creator(self)

+        if trial_dirname_creator:
+            self.custom_dirname = trial_dirname_creator(self)
+            if os.path.sep in self.custom_dirname:
+                raise ValueError("Trial dirname must not contain '/'. "
+                                 f"Got {self.custom_dirname}")

     @property
     def node_ip(self):
         return self.location.hostname
@@ -314,14 +334,6 @@ class Trial:
         logdir_name = os.path.basename(self.logdir)
         return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)

-    @classmethod
-    def create_logdir(cls, identifier, local_dir):
-        local_dir = os.path.expanduser(local_dir)
-        os.makedirs(local_dir, exist_ok=True)
-        return tempfile.mkdtemp(
-            prefix="{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str()),
-            dir=local_dir)

     def reset(self):
         return Trial(
             self.trainable_name,
@@ -351,9 +363,8 @@ class Trial:
         """Init logger."""
         if not self.result_logger:
             if not self.logdir:
-                self.logdir = Trial.create_logdir(
-                    self._trainable_name() + "_" + self.experiment_tag,
-                    self.local_dir)
+                self.logdir = create_logdir(self._generate_dirname(),
+                                            self.local_dir)
             else:
                 os.makedirs(self.logdir, exist_ok=True)
@@ -592,6 +603,15 @@ class Trial:
         identifier += "_" + self.trial_id
         return identifier.replace("/", "_")

+    def _generate_dirname(self):
+        if self.custom_dirname:
+            generated_dirname = self.custom_dirname
+        else:
+            generated_dirname = f"{self.trainable_name}_{self.experiment_tag}"
+            generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
+            generated_dirname += f"_{date_str()}{uuid.uuid4().hex[:8]}"
+        return generated_dirname.replace("/", "_")

     def __getstate__(self):
         """Memento generator for Trial.

View file

@@ -384,8 +384,8 @@ class TrialRunner:
             try:
                 with warn_if_slow("experiment_checkpoint"):
                     self.checkpoint()
-            except Exception:
-                logger.exception("Trial Runner checkpointing failed.")
+            except Exception as e:
+                logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
             self._iteration += 1

             if self._server:
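
One trade-off in this except-block change: logger.exception logs the full traceback at ERROR level, while logger.warning with only str(e) drops it. If the traceback is still wanted at WARNING level, exc_info=True would keep it; a sketch:

import logging

logger = logging.getLogger(__name__)

try:
    raise RuntimeError("disk full")  # stand-in for a checkpoint failure
except Exception as e:
    # Same WARNING level as the new code, but exc_info=True retains
    # the traceback that logger.exception used to print.
    logger.warning(f"Trial Runner checkpointing failed: {str(e)}",
                   exc_info=True)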

View file

@@ -74,6 +74,7 @@ def run(run_or_experiment,
         local_dir=None,
         upload_dir=None,
         trial_name_creator=None,
+        trial_dirname_creator=None,
         loggers=None,
         log_to_file=False,
         sync_to_cloud=None,
@@ -166,8 +167,12 @@ def run(run_or_experiment,
             Defaults to ``~/ray_results``.
         upload_dir (str): Optional URI to sync training results and checkpoints
             to (e.g. ``s3://bucket`` or ``gs://bucket``).
-        trial_name_creator (func): Optional function for generating
-            the trial string representation.
+        trial_name_creator (Callable[[Trial], str]): Optional function
+            for generating the trial string representation.
+        trial_dirname_creator (Callable[[Trial], str]): Function
+            for generating the trial dirname. This function should take
+            in a Trial object and return a string representing the
+            name of the directory. The return value cannot be a path.
         loggers (list): List of logger creators to be used with
             each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
             See `ray/tune/logger.py`.
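
A usage sketch for the documented argument, mirroring testCustomTrialDir from this PR; the trainable body is illustrative:

from ray import tune

def trainable(config):  # illustrative trainable
    tune.report(test=config["t1"])

analysis = tune.run(
    trainable,
    config={"t1": tune.grid_search([1, 2, 3])},
    # Must be a bare directory name; a value containing os.path.sep
    # makes Trial.__init__ raise ValueError.
    trial_dirname_creator=lambda trial: f"trial_{trial.trial_id}")
print([t.logdir for t in analysis.trials])
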
@@ -295,6 +300,7 @@ def run(run_or_experiment,
         upload_dir=upload_dir,
         sync_to_driver=sync_to_driver,
         trial_name_creator=trial_name_creator,
+        trial_dirname_creator=trial_dirname_creator,
         loggers=loggers,
         log_to_file=log_to_file,
         checkpoint_freq=checkpoint_freq,
@@ -375,8 +381,8 @@ def run(run_or_experiment,
     try:
         runner.checkpoint(force=True)
-    except Exception:
-        logger.exception("Trial Runner checkpointing failed.")
+    except Exception as e:
+        logger.warning(f"Trial Runner checkpointing failed: {str(e)}")

     if verbose:
         _report_progress(runner, progress_reporter, done=True)