From 146d91385c0a06eb1b7690045b222c93fd1059ba Mon Sep 17 00:00:00 2001
From: Richard Liaw
Date: Tue, 25 Aug 2020 12:52:54 -0700
Subject: [PATCH] [tune] custom trial directory name (#10214)

---
 python/ray/tune/config_parser.py              |  1 +
 python/ray/tune/experiment.py                 |  2 +
 python/ray/tune/suggest/search_generator.py   |  6 +--
 python/ray/tune/tests/test_api.py             | 40 +++++++++++++++++
 python/ray/tune/tests/test_trial_scheduler.py |  1 +
 python/ray/tune/trial.py                      | 44 ++++++++++++++-----
 python/ray/tune/trial_runner.py               |  4 +-
 python/ray/tune/tune.py                       | 14 ++++--
 8 files changed, 91 insertions(+), 21 deletions(-)

diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py
index 72903fd9d..bb1efe836 100644
--- a/python/ray/tune/config_parser.py
+++ b/python/ray/tune/config_parser.py
@@ -189,6 +189,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # str(None) doesn't create None
         restore_path=spec.get("restore"),
         trial_name_creator=spec.get("trial_name_creator"),
+        trial_dirname_creator=spec.get("trial_dirname_creator"),
         loggers=spec.get("loggers"),
         log_to_file=spec.get("log_to_file"),
         # str(None) doesn't create None
diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py
index b6034eb67..66410d542 100644
--- a/python/ray/tune/experiment.py
+++ b/python/ray/tune/experiment.py
@@ -108,6 +108,7 @@ class Experiment:
                  local_dir=None,
                  upload_dir=None,
                  trial_name_creator=None,
+                 trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=False,
                  sync_to_driver=None,
@@ -173,6 +174,7 @@
             "upload_dir": upload_dir,
             "remote_checkpoint_dir": self.remote_checkpoint_dir,
             "trial_name_creator": trial_name_creator,
+            "trial_dirname_creator": trial_dirname_creator,
             "loggers": loggers,
             "log_to_file": (stdout_file, stderr_file),
             "sync_to_driver": sync_to_driver,
diff --git a/python/ray/tune/suggest/search_generator.py b/python/ray/tune/suggest/search_generator.py
index f9763d2ed..bc23247d9 100644
--- a/python/ray/tune/suggest/search_generator.py
+++ b/python/ray/tune/suggest/search_generator.py
@@ -1,9 +1,9 @@
-import pickle
 import os
 import copy
 import logging
 import glob
 
+import ray.cloudpickle as cloudpickle
 from ray.tune.error import TuneError
 from ray.tune.experiment import convert_to_experiment_list
 from ray.tune.config_parser import make_parser, create_trial_from_spec
@@ -29,7 +29,7 @@ def _atomic_save(state, checkpoint_dir, file_name):
     tmp_search_ckpt_path = os.path.join(checkpoint_dir,
                                         ".tmp_search_generator_ckpt")
     with open(tmp_search_ckpt_path, "wb") as f:
-        pickle.dump(state, f)
+        cloudpickle.dump(state, f)
     os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
 
 
@@ -41,7 +41,7 @@ def _find_newest_ckpt(dirpath, pattern):
         return
     most_recent_checkpoint = max(full_paths)
     with open(most_recent_checkpoint, "rb") as f:
-        search_alg_state = pickle.load(f)
+        search_alg_state = cloudpickle.load(f)
     return search_alg_state
 
 
diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py
index 4a2fa30cf..5c4322658 100644
--- a/python/ray/tune/tests/test_api.py
+++ b/python/ray/tune/tests/test_api.py
@@ -33,10 +33,12 @@ from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
 class TrainableFunctionApiTest(unittest.TestCase):
     def setUp(self):
         ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
+        self.tmpdir = tempfile.mkdtemp()
 
     def tearDown(self):
         ray.shutdown()
         _register_all()  # re-register the evicted objects
+        shutil.rmtree(self.tmpdir)
 
     def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None):
         """Checks logging is the same between APIs.
@@ -547,6 +549,44 @@ class TrainableFunctionApiTest(unittest.TestCase):
         with self.assertRaises(TuneError):
             tune.run(train, stop=stop)
 
+    def testCustomTrialDir(self):
+        def train(config):
+            for i in range(10):
+                tune.report(test=i)
+
+        custom_name = "TRAIL_TRIAL"
+
+        def custom_trial_dir(trial):
+            return custom_name
+
+        trials = tune.run(
+            train,
+            config={
+                "t1": tune.grid_search([1, 2, 3])
+            },
+            trial_dirname_creator=custom_trial_dir,
+            local_dir=self.tmpdir).trials
+        logdirs = {t.logdir for t in trials}
+        assert len(logdirs) == 3
+        assert all(custom_name in dirpath for dirpath in logdirs)
+
+    def testTrialDirRegression(self):
+        def train(config, reporter):
+            for i in range(10):
+                reporter(test=i)
+
+        trials = tune.run(
+            train,
+            config={
+                "t1": tune.grid_search([1, 2, 3])
+            },
+            local_dir=self.tmpdir).trials
+        logdirs = {t.logdir for t in trials}
+        for i in [1, 2, 3]:
+            assert any(f"t1={i}" in dirpath for dirpath in logdirs)
+        for t in trials:
+            assert any(t.trainable_name in dirpath for dirpath in logdirs)
+
     def testEarlyReturn(self):
         def train(config, reporter):
             reporter(timesteps_total=100, done=True)
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py
index 7a9b35f81..30c3649bc 100644
--- a/python/ray/tune/tests/test_trial_scheduler.py
+++ b/python/ray/tune/tests/test_trial_scheduler.py
@@ -700,6 +700,7 @@ class _MockTrial(Trial):
         self.restored_checkpoint = None
         self.resources = Resources(1, 0)
         self.custom_trial_name = None
+        self.custom_dirname = None
 
 
 class PopulationBasedTestingSuite(unittest.TestCase):
diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py
index 879623adb..409a8b120 100644
--- a/python/ray/tune/trial.py
+++ b/python/ray/tune/trial.py
@@ -9,7 +9,6 @@ import platform
 import shutil
 import uuid
 import time
-import tempfile
 import os
 from numbers import Number
 from ray.tune import TuneError
@@ -129,6 +128,19 @@ class TrialInfo:
         return self._trial_id
 
 
+def create_logdir(dirname, local_dir):
+    local_dir = os.path.expanduser(local_dir)
+    logdir = os.path.join(local_dir, dirname)
+    if os.path.exists(logdir):
+        old_dirname = dirname
+        dirname += "_" + uuid.uuid4().hex[:4]
+        logger.info(f"Creating a new dirname {dirname} because "
+                    f"trial dirname '{old_dirname}' already exists.")
+        logdir = os.path.join(local_dir, dirname)
+    os.makedirs(logdir, exist_ok=True)
+    return logdir
+
+
 class Trial:
     """A trial object holds the state for one model training run.
 
@@ -176,6 +188,7 @@
                  export_formats=None,
                  restore_path=None,
                  trial_name_creator=None,
+                 trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=None,
                  sync_to_driver_fn=None,
@@ -245,6 +258,7 @@
         self.error_msg = None
         self.trial_name_creator = trial_name_creator
         self.custom_trial_name = None
+        self.custom_dirname = None
 
         # Checkpointing fields
         self.saving_to = None
@@ -283,6 +297,12 @@
         if trial_name_creator:
             self.custom_trial_name = trial_name_creator(self)
 
+        if trial_dirname_creator:
+            self.custom_dirname = trial_dirname_creator(self)
+            if os.path.sep in self.custom_dirname:
+                raise ValueError(f"Trial dirname must not contain '/'. "
" + "Got {self.custom_dirname}") + @property def node_ip(self): return self.location.hostname @@ -314,14 +334,6 @@ class Trial: logdir_name = os.path.basename(self.logdir) return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name) - @classmethod - def create_logdir(cls, identifier, local_dir): - local_dir = os.path.expanduser(local_dir) - os.makedirs(local_dir, exist_ok=True) - return tempfile.mkdtemp( - prefix="{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str()), - dir=local_dir) - def reset(self): return Trial( self.trainable_name, @@ -351,9 +363,8 @@ class Trial: """Init logger.""" if not self.result_logger: if not self.logdir: - self.logdir = Trial.create_logdir( - self._trainable_name() + "_" + self.experiment_tag, - self.local_dir) + self.logdir = create_logdir(self._generate_dirname(), + self.local_dir) else: os.makedirs(self.logdir, exist_ok=True) @@ -592,6 +603,15 @@ class Trial: identifier += "_" + self.trial_id return identifier.replace("/", "_") + def _generate_dirname(self): + if self.custom_dirname: + generated_dirname = self.custom_dirname + else: + generated_dirname = f"{self.trainable_name}_{self.experiment_tag}" + generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER] + generated_dirname += f"_{date_str()}{uuid.uuid4().hex[:8]}" + return generated_dirname.replace("/", "_") + def __getstate__(self): """Memento generator for Trial. diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 222731714..8e644c452 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -384,8 +384,8 @@ class TrialRunner: try: with warn_if_slow("experiment_checkpoint"): self.checkpoint() - except Exception: - logger.exception("Trial Runner checkpointing failed.") + except Exception as e: + logger.warning(f"Trial Runner checkpointing failed: {str(e)}") self._iteration += 1 if self._server: diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 6fd9e8580..543de99a5 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -74,6 +74,7 @@ def run(run_or_experiment, local_dir=None, upload_dir=None, trial_name_creator=None, + trial_dirname_creator=None, loggers=None, log_to_file=False, sync_to_cloud=None, @@ -166,8 +167,12 @@ def run(run_or_experiment, Defaults to ``~/ray_results``. upload_dir (str): Optional URI to sync training results and checkpoints to (e.g. ``s3://bucket`` or ``gs://bucket``). - trial_name_creator (func): Optional function for generating - the trial string representation. + trial_name_creator (Callable[[Trial], str]): Optional function + for generating the trial string representation. + trial_dirname_creator (Callable[[Trial], str]): Function + for generating the trial dirname. This function should take + in a Trial object and return a string representing the + name of the directory. The return value cannot be a path. loggers (list): List of logger creators to be used with each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS. See `ray/tune/logger.py`. 
@@ -295,6 +300,7 @@ def run(run_or_experiment,
         upload_dir=upload_dir,
         sync_to_driver=sync_to_driver,
         trial_name_creator=trial_name_creator,
+        trial_dirname_creator=trial_dirname_creator,
         loggers=loggers,
         log_to_file=log_to_file,
         checkpoint_freq=checkpoint_freq,
@@ -375,8 +381,8 @@ def run(run_or_experiment,
 
     try:
         runner.checkpoint(force=True)
-    except Exception:
-        logger.exception("Trial Runner checkpointing failed.")
+    except Exception as e:
+        logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
 
     if verbose:
         _report_progress(runner, progress_reporter, done=True)
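
---

Usage note: a minimal sketch of the `trial_dirname_creator` hook this
patch adds. It relies only on APIs that appear in the patch itself
(tune.run, tune.report, tune.grid_search, trial.trainable_name,
trial.trial_id); the trainable, its config, and the function names are
illustrative placeholders, not part of the change.

    from ray import tune

    def objective(config):
        # Placeholder trainable: report one metric per training step.
        for step in range(10):
            tune.report(score=config["lr"] * step)

    def short_dirname(trial):
        # Hypothetical creator. It must return a directory *name*:
        # per this patch, Trial.__init__ raises ValueError if the
        # result contains os.path.sep.
        return f"{trial.trainable_name}_{trial.trial_id}"

    analysis = tune.run(
        objective,
        config={"lr": tune.grid_search([0.01, 0.1])},
        trial_dirname_creator=short_dirname)

If two trials resolve to the same dirname, create_logdir in trial.py
does not fail: it appends a short uuid4 suffix and logs the new name.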