mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
[tune] custom trial directory name (#10214)
This commit is contained in:
parent
24a7a8a04d
commit
146d91385c
8 changed files with 91 additions and 21 deletions
|
@ -189,6 +189,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
|||
# str(None) doesn't create None
|
||||
restore_path=spec.get("restore"),
|
||||
trial_name_creator=spec.get("trial_name_creator"),
|
||||
trial_dirname_creator=spec.get("trial_dirname_creator"),
|
||||
loggers=spec.get("loggers"),
|
||||
log_to_file=spec.get("log_to_file"),
|
||||
# str(None) doesn't create None
|
||||
|
|
|
@ -108,6 +108,7 @@ class Experiment:
|
|||
local_dir=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
trial_dirname_creator=None,
|
||||
loggers=None,
|
||||
log_to_file=False,
|
||||
sync_to_driver=None,
|
||||
|
@ -173,6 +174,7 @@ class Experiment:
|
|||
"upload_dir": upload_dir,
|
||||
"remote_checkpoint_dir": self.remote_checkpoint_dir,
|
||||
"trial_name_creator": trial_name_creator,
|
||||
"trial_dirname_creator": trial_dirname_creator,
|
||||
"loggers": loggers,
|
||||
"log_to_file": (stdout_file, stderr_file),
|
||||
"sync_to_driver": sync_to_driver,
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import pickle
|
||||
import os
|
||||
import copy
|
||||
import logging
|
||||
import glob
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
|
@ -29,7 +29,7 @@ def _atomic_save(state, checkpoint_dir, file_name):
|
|||
tmp_search_ckpt_path = os.path.join(checkpoint_dir,
|
||||
".tmp_search_generator_ckpt")
|
||||
with open(tmp_search_ckpt_path, "wb") as f:
|
||||
pickle.dump(state, f)
|
||||
cloudpickle.dump(state, f)
|
||||
|
||||
os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
|
||||
|
||||
|
@ -41,7 +41,7 @@ def _find_newest_ckpt(dirpath, pattern):
|
|||
return
|
||||
most_recent_checkpoint = max(full_paths)
|
||||
with open(most_recent_checkpoint, "rb") as f:
|
||||
search_alg_state = pickle.load(f)
|
||||
search_alg_state = cloudpickle.load(f)
|
||||
return search_alg_state
|
||||
|
||||
|
||||
|
|
|
@ -33,10 +33,12 @@ from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
|
|||
class TrainableFunctionApiTest(unittest.TestCase):
|
||||
def setUp(self):
    """Start a small local Ray cluster and a scratch dir for each test."""
    # Fresh scratch directory per test; removed again in tearDown.
    self.tmpdir = tempfile.mkdtemp()
    # Fixed, small resource limits keep the tests cheap and deterministic.
    ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
||||
def tearDown(self):
    """Tear down the Ray cluster and delete the per-test scratch dir."""
    ray.shutdown()
    # Ray shutdown evicts registered trainables; restore them so later
    # tests see the default registry contents.
    _register_all()
    shutil.rmtree(self.tmpdir)
||||
def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None):
|
||||
"""Checks logging is the same between APIs.
|
||||
|
@ -547,6 +549,44 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
with self.assertRaises(TuneError):
|
||||
tune.run(train, stop=stop)
|
||||
|
||||
def testCustomTrialDir(self):
    """A user-supplied ``trial_dirname_creator`` should name every trial dir.

    Runs a 3-trial grid search with a creator that returns a fixed name and
    checks each trial logdir contains that name (collisions are expected to
    be de-duplicated, hence 3 distinct logdirs).
    """

    def train(config):
        for i in range(10):
            tune.report(test=i)

    custom_name = "TRAIL_TRIAL"

    def custom_trial_dir(trial):
        return custom_name

    trials = tune.run(
        train,
        config={
            "t1": tune.grid_search([1, 2, 3])
        },
        trial_dirname_creator=custom_trial_dir,
        local_dir=self.tmpdir).trials
    logdirs = {t.logdir for t in trials}
    # Use unittest assertions instead of plain ``assert``: bare asserts are
    # stripped under ``python -O`` and give poorer failure messages.
    self.assertEqual(len(logdirs), 3)
    self.assertTrue(all(custom_name in dirpath for dirpath in logdirs))
||||
def testTrialDirRegression(self):
    """Default trial dirnames still embed the config values and trainable name.

    Regression guard for the default (no ``trial_dirname_creator``) path.
    """

    def train(config, reporter):
        for i in range(10):
            reporter(test=i)

    trials = tune.run(
        train,
        config={
            "t1": tune.grid_search([1, 2, 3])
        },
        local_dir=self.tmpdir).trials
    logdirs = {t.logdir for t in trials}
    # Use unittest assertions instead of plain ``assert``: bare asserts are
    # stripped under ``python -O`` and give poorer failure messages.
    for i in [1, 2, 3]:
        self.assertTrue(any(f"t1={i}" in dirpath for dirpath in logdirs))
    for t in trials:
        self.assertTrue(
            any(t.trainable_name in dirpath for dirpath in logdirs))
||||
def testEarlyReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100, done=True)
|
||||
|
|
|
@ -700,6 +700,7 @@ class _MockTrial(Trial):
|
|||
self.restored_checkpoint = None
|
||||
self.resources = Resources(1, 0)
|
||||
self.custom_trial_name = None
|
||||
self.custom_dirname = None
|
||||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
|
|
@ -9,7 +9,6 @@ import platform
|
|||
import shutil
|
||||
import uuid
|
||||
import time
|
||||
import tempfile
|
||||
import os
|
||||
from numbers import Number
|
||||
from ray.tune import TuneError
|
||||
|
@ -129,6 +128,19 @@ class TrialInfo:
|
|||
return self._trial_id
|
||||
|
||||
|
||||
def create_logdir(dirname, local_dir):
    """Create (if needed) and return the log directory for a trial.

    Args:
        dirname (str): Desired directory name, e.g. from a
            ``trial_dirname_creator``.
        local_dir (str): Parent directory (``~`` is expanded).

    Returns:
        str: Absolute-ish path ``local_dir/dirname``; when that path already
        exists, a short random suffix is appended to keep trials separate.
    """
    base = os.path.expanduser(local_dir)
    candidate = os.path.join(base, dirname)
    if os.path.exists(candidate):
        # De-duplicate with a short uuid suffix so two trials never share
        # a logdir.
        old_dirname = dirname
        dirname = dirname + "_" + uuid.uuid4().hex[:4]
        logger.info(f"Creating a new dirname {dirname} because "
                    f"trial dirname '{old_dirname}' already exists.")
        candidate = os.path.join(base, dirname)
    os.makedirs(candidate, exist_ok=True)
    return candidate
||||
|
||||
class Trial:
|
||||
"""A trial object holds the state for one model training run.
|
||||
|
||||
|
@ -176,6 +188,7 @@ class Trial:
|
|||
export_formats=None,
|
||||
restore_path=None,
|
||||
trial_name_creator=None,
|
||||
trial_dirname_creator=None,
|
||||
loggers=None,
|
||||
log_to_file=None,
|
||||
sync_to_driver_fn=None,
|
||||
|
@ -245,6 +258,7 @@ class Trial:
|
|||
self.error_msg = None
|
||||
self.trial_name_creator = trial_name_creator
|
||||
self.custom_trial_name = None
|
||||
self.custom_dirname = None
|
||||
|
||||
# Checkpointing fields
|
||||
self.saving_to = None
|
||||
|
@ -283,6 +297,12 @@ class Trial:
|
|||
if trial_name_creator:
|
||||
self.custom_trial_name = trial_name_creator(self)
|
||||
|
||||
if trial_dirname_creator:
|
||||
self.custom_dirname = trial_dirname_creator(self)
|
||||
if os.path.sep in self.custom_dirname:
|
||||
raise ValueError(f"Trial dirname must not contain '/'. "
|
||||
"Got {self.custom_dirname}")
|
||||
|
||||
@property
|
||||
def node_ip(self):
|
||||
return self.location.hostname
|
||||
|
@ -314,14 +334,6 @@ class Trial:
|
|||
logdir_name = os.path.basename(self.logdir)
|
||||
return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)
|
||||
|
||||
@classmethod
def create_logdir(cls, identifier, local_dir):
    """Make a fresh, unique logdir under ``local_dir`` and return its path.

    The directory name is ``identifier`` (truncated to MAX_LEN_IDENTIFIER)
    plus a date string; ``tempfile.mkdtemp`` guarantees uniqueness by
    appending its own random suffix.
    """
    expanded = os.path.expanduser(local_dir)
    os.makedirs(expanded, exist_ok=True)
    prefix = "{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str())
    return tempfile.mkdtemp(prefix=prefix, dir=expanded)
||||
|
||||
def reset(self):
|
||||
return Trial(
|
||||
self.trainable_name,
|
||||
|
@ -351,9 +363,8 @@ class Trial:
|
|||
"""Init logger."""
|
||||
if not self.result_logger:
|
||||
if not self.logdir:
|
||||
self.logdir = Trial.create_logdir(
|
||||
self._trainable_name() + "_" + self.experiment_tag,
|
||||
self.local_dir)
|
||||
self.logdir = create_logdir(self._generate_dirname(),
|
||||
self.local_dir)
|
||||
else:
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
|
||||
|
@ -592,6 +603,15 @@ class Trial:
|
|||
identifier += "_" + self.trial_id
|
||||
return identifier.replace("/", "_")
|
||||
|
||||
def _generate_dirname(self):
    """Return the directory name to use for this trial's logdir.

    A user-provided ``custom_dirname`` wins outright; otherwise the name is
    built from the trainable name and experiment tag (truncated to
    MAX_LEN_IDENTIFIER) plus a date string and a short uuid for uniqueness.
    Any ``/`` is replaced so the result is a single path component.
    """
    if self.custom_dirname:
        dirname = self.custom_dirname
    else:
        stem = f"{self.trainable_name}_{self.experiment_tag}"
        dirname = stem[:MAX_LEN_IDENTIFIER]
        dirname += f"_{date_str()}{uuid.uuid4().hex[:8]}"
    return dirname.replace("/", "_")
||||
|
||||
def __getstate__(self):
|
||||
"""Memento generator for Trial.
|
||||
|
||||
|
|
|
@ -384,8 +384,8 @@ class TrialRunner:
|
|||
try:
|
||||
with warn_if_slow("experiment_checkpoint"):
|
||||
self.checkpoint()
|
||||
except Exception:
|
||||
logger.exception("Trial Runner checkpointing failed.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
|
||||
self._iteration += 1
|
||||
|
||||
if self._server:
|
||||
|
|
|
@ -74,6 +74,7 @@ def run(run_or_experiment,
|
|||
local_dir=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
trial_dirname_creator=None,
|
||||
loggers=None,
|
||||
log_to_file=False,
|
||||
sync_to_cloud=None,
|
||||
|
@ -166,8 +167,12 @@ def run(run_or_experiment,
|
|||
Defaults to ``~/ray_results``.
|
||||
upload_dir (str): Optional URI to sync training results and checkpoints
|
||||
to (e.g. ``s3://bucket`` or ``gs://bucket``).
|
||||
trial_name_creator (func): Optional function for generating
|
||||
the trial string representation.
|
||||
trial_name_creator (Callable[[Trial], str]): Optional function
|
||||
for generating the trial string representation.
|
||||
trial_dirname_creator (Callable[[Trial], str]): Function
|
||||
for generating the trial dirname. This function should take
|
||||
in a Trial object and return a string representing the
|
||||
name of the directory. The return value cannot be a path.
|
||||
loggers (list): List of logger creators to be used with
|
||||
each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
|
||||
See `ray/tune/logger.py`.
|
||||
|
@ -295,6 +300,7 @@ def run(run_or_experiment,
|
|||
upload_dir=upload_dir,
|
||||
sync_to_driver=sync_to_driver,
|
||||
trial_name_creator=trial_name_creator,
|
||||
trial_dirname_creator=trial_dirname_creator,
|
||||
loggers=loggers,
|
||||
log_to_file=log_to_file,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
|
@ -375,8 +381,8 @@ def run(run_or_experiment,
|
|||
|
||||
try:
|
||||
runner.checkpoint(force=True)
|
||||
except Exception:
|
||||
logger.exception("Trial Runner checkpointing failed.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
|
||||
|
||||
if verbose:
|
||||
_report_progress(runner, progress_reporter, done=True)
|
||||
|
|
Loading…
Add table
Reference in a new issue