Mirror of https://github.com/vale981/ray, synced 2025-03-08 11:31:40 -05:00
[tune] custom trial directory name (#10214)
Parent: 24a7a8a04d
Commit: 146d91385c
8 changed files with 91 additions and 21 deletions
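What the change enables, as a minimal sketch (not part of the diff; modeled on the testCustomTrialDir test added below, with illustrative helper names):

from ray import tune

def train(config):
    for i in range(10):
        tune.report(test=i)

def trial_dirname(trial):
    # Must return a plain directory name, not a path (no os.path.sep).
    return f"trial_{trial.trial_id}"

analysis = tune.run(
    train,
    config={"t1": tune.grid_search([1, 2, 3])},
    trial_dirname_creator=trial_dirname,
    local_dir="~/ray_results")
# Each trial's logdir now ends with "trial_<trial_id>" instead of the
# auto-generated "<trainable>_<experiment_tag>_<date>" name.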
@@ -189,6 +189,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
         # str(None) doesn't create None
         restore_path=spec.get("restore"),
         trial_name_creator=spec.get("trial_name_creator"),
+        trial_dirname_creator=spec.get("trial_dirname_creator"),
         loggers=spec.get("loggers"),
         log_to_file=spec.get("log_to_file"),
         # str(None) doesn't create None
@@ -108,6 +108,7 @@ class Experiment:
                  local_dir=None,
                  upload_dir=None,
                  trial_name_creator=None,
+                 trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=False,
                  sync_to_driver=None,
@@ -173,6 +174,7 @@ class Experiment:
             "upload_dir": upload_dir,
             "remote_checkpoint_dir": self.remote_checkpoint_dir,
             "trial_name_creator": trial_name_creator,
+            "trial_dirname_creator": trial_dirname_creator,
             "loggers": loggers,
             "log_to_file": (stdout_file, stderr_file),
             "sync_to_driver": sync_to_driver,
@@ -1,9 +1,9 @@
-import pickle
 import os
 import copy
 import logging
 import glob

+import ray.cloudpickle as cloudpickle
 from ray.tune.error import TuneError
 from ray.tune.experiment import convert_to_experiment_list
 from ray.tune.config_parser import make_parser, create_trial_from_spec
@@ -29,7 +29,7 @@ def _atomic_save(state, checkpoint_dir, file_name):
     tmp_search_ckpt_path = os.path.join(checkpoint_dir,
                                         ".tmp_search_generator_ckpt")
     with open(tmp_search_ckpt_path, "wb") as f:
-        pickle.dump(state, f)
+        cloudpickle.dump(state, f)

     os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))

@@ -41,7 +41,7 @@ def _find_newest_ckpt(dirpath, pattern):
         return
     most_recent_checkpoint = max(full_paths)
     with open(most_recent_checkpoint, "rb") as f:
-        search_alg_state = pickle.load(f)
+        search_alg_state = cloudpickle.load(f)
     return search_alg_state

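The pickle → cloudpickle switch in the two hunks above matters because search-algorithm checkpoints can contain user-defined callables; a rough standalone illustration (assumption: the checkpointed state holds a closure):

import pickle
import ray.cloudpickle as cloudpickle

state = {"metric_transform": lambda x: 2 * x}  # stand-in for searcher state

try:
    pickle.dumps(state)  # the stdlib pickler rejects lambdas defined in __main__
except Exception as err:
    print("pickle failed:", err)

blob = cloudpickle.dumps(state)  # cloudpickle serializes the function by value
assert cloudpickle.loads(blob)["metric_transform"](3) == 6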
@@ -33,10 +33,12 @@ from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
 class TrainableFunctionApiTest(unittest.TestCase):
     def setUp(self):
         ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
+        self.tmpdir = tempfile.mkdtemp()

     def tearDown(self):
         ray.shutdown()
         _register_all()  # re-register the evicted objects
+        shutil.rmtree(self.tmpdir)

     def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None):
         """Checks logging is the same between APIs.
@@ -547,6 +549,44 @@ class TrainableFunctionApiTest(unittest.TestCase):
         with self.assertRaises(TuneError):
             tune.run(train, stop=stop)

+    def testCustomTrialDir(self):
+        def train(config):
+            for i in range(10):
+                tune.report(test=i)
+
+        custom_name = "TRAIL_TRIAL"
+
+        def custom_trial_dir(trial):
+            return custom_name
+
+        trials = tune.run(
+            train,
+            config={
+                "t1": tune.grid_search([1, 2, 3])
+            },
+            trial_dirname_creator=custom_trial_dir,
+            local_dir=self.tmpdir).trials
+        logdirs = {t.logdir for t in trials}
+        assert len(logdirs) == 3
+        assert all(custom_name in dirpath for dirpath in logdirs)
+
+    def testTrialDirRegression(self):
+        def train(config, reporter):
+            for i in range(10):
+                reporter(test=i)
+
+        trials = tune.run(
+            train,
+            config={
+                "t1": tune.grid_search([1, 2, 3])
+            },
+            local_dir=self.tmpdir).trials
+        logdirs = {t.logdir for t in trials}
+        for i in [1, 2, 3]:
+            assert any(f"t1={i}" in dirpath for dirpath in logdirs)
+        for t in trials:
+            assert any(t.trainable_name in dirpath for dirpath in logdirs)
+
     def testEarlyReturn(self):
         def train(config, reporter):
             reporter(timesteps_total=100, done=True)
@@ -700,6 +700,7 @@ class _MockTrial(Trial):
         self.restored_checkpoint = None
         self.resources = Resources(1, 0)
         self.custom_trial_name = None
+        self.custom_dirname = None


 class PopulationBasedTestingSuite(unittest.TestCase):
@@ -9,7 +9,6 @@ import platform
 import shutil
 import uuid
 import time
-import tempfile
 import os
 from numbers import Number
 from ray.tune import TuneError
@@ -129,6 +128,19 @@ class TrialInfo:
         return self._trial_id


+def create_logdir(dirname, local_dir):
+    local_dir = os.path.expanduser(local_dir)
+    logdir = os.path.join(local_dir, dirname)
+    if os.path.exists(logdir):
+        old_dirname = dirname
+        dirname += "_" + uuid.uuid4().hex[:4]
+        logger.info(f"Creating a new dirname {dirname} because "
+                    f"trial dirname '{old_dirname}' already exists.")
+        logdir = os.path.join(local_dir, dirname)
+    os.makedirs(logdir, exist_ok=True)
+    return logdir
+
+
 class Trial:
     """A trial object holds the state for one model training run.

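The new module-level create_logdir replaces the tempfile.mkdtemp-based classmethod (removed in a later hunk) and de-duplicates clashing names; a quick sketch of the expected behavior, assuming it is imported from the module this hunk edits:

import os
from ray.tune.trial import create_logdir  # defined at module level by this commit

first = create_logdir("TRAIL_TRIAL", "~/ray_results/exp")
second = create_logdir("TRAIL_TRIAL", "~/ray_results/exp")  # name already taken

assert os.path.basename(first) == "TRAIL_TRIAL"
# On collision, "_" plus 4 hex chars of uuid4() is appended, e.g. "TRAIL_TRIAL_3f2a".
assert os.path.basename(second).startswith("TRAIL_TRIAL_")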
@@ -176,6 +188,7 @@ class Trial:
                  export_formats=None,
                  restore_path=None,
                  trial_name_creator=None,
+                 trial_dirname_creator=None,
                  loggers=None,
                  log_to_file=None,
                  sync_to_driver_fn=None,
@@ -245,6 +258,7 @@ class Trial:
         self.error_msg = None
         self.trial_name_creator = trial_name_creator
         self.custom_trial_name = None
+        self.custom_dirname = None

         # Checkpointing fields
         self.saving_to = None
@@ -283,6 +297,12 @@ class Trial:
         if trial_name_creator:
             self.custom_trial_name = trial_name_creator(self)

+        if trial_dirname_creator:
+            self.custom_dirname = trial_dirname_creator(self)
+            if os.path.sep in self.custom_dirname:
+                raise ValueError(f"Trial dirname must not contain '/'. "
+                                 "Got {self.custom_dirname}")
+
     @property
     def node_ip(self):
         return self.location.hostname
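The check added above rejects dirnames containing a path separator when the Trial is constructed; a hedged sketch of what callers see (helper names are illustrative):

from ray import tune

def train(config):
    tune.report(score=1)

def bad_dirname(trial):
    return f"nested/{trial.trial_id}"   # contains os.path.sep

def good_dirname(trial):
    return f"nested_{trial.trial_id}"   # plain name, accepted

# Expected to fail while trials are created, per the ValueError above:
# tune.run(train, trial_dirname_creator=bad_dirname)
# Works:
# tune.run(train, trial_dirname_creator=good_dirname)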
@@ -314,14 +334,6 @@ class Trial:
         logdir_name = os.path.basename(self.logdir)
         return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)

-    @classmethod
-    def create_logdir(cls, identifier, local_dir):
-        local_dir = os.path.expanduser(local_dir)
-        os.makedirs(local_dir, exist_ok=True)
-        return tempfile.mkdtemp(
-            prefix="{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str()),
-            dir=local_dir)
-
     def reset(self):
         return Trial(
             self.trainable_name,
@@ -351,9 +363,8 @@ class Trial:
         """Init logger."""
         if not self.result_logger:
             if not self.logdir:
-                self.logdir = Trial.create_logdir(
-                    self._trainable_name() + "_" + self.experiment_tag,
-                    self.local_dir)
+                self.logdir = create_logdir(self._generate_dirname(),
+                                            self.local_dir)
             else:
                 os.makedirs(self.logdir, exist_ok=True)

@@ -592,6 +603,15 @@ class Trial:
             identifier += "_" + self.trial_id
         return identifier.replace("/", "_")

+    def _generate_dirname(self):
+        if self.custom_dirname:
+            generated_dirname = self.custom_dirname
+        else:
+            generated_dirname = f"{self.trainable_name}_{self.experiment_tag}"
+            generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
+            generated_dirname += f"_{date_str()}{uuid.uuid4().hex[:8]}"
+        return generated_dirname.replace("/", "_")
+
     def __getstate__(self):
         """Memento generator for Trial.

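_generate_dirname above picks between the custom name and the legacy auto-generated one; a standalone sketch of that branch for illustration (the MAX_LEN_IDENTIFIER default and the date format are assumptions about tune's settings at the time):

import uuid
from datetime import datetime

MAX_LEN_IDENTIFIER = 130  # assumed default cap; tune reads it from the environment

def sketch_generate_dirname(custom_dirname, trainable_name, experiment_tag):
    if custom_dirname:
        generated = custom_dirname
    else:
        generated = f"{trainable_name}_{experiment_tag}"[:MAX_LEN_IDENTIFIER]
        generated += "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + uuid.uuid4().hex[:8]
    return generated.replace("/", "_")

print(sketch_generate_dirname("TRAIL_TRIAL", "train", "2_t1=2"))  # -> TRAIL_TRIAL
print(sketch_generate_dirname(None, "train", "2_t1=2"))           # -> train_2_t1=2_<date><8 hex chars>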
@@ -384,8 +384,8 @@ class TrialRunner:
             try:
                 with warn_if_slow("experiment_checkpoint"):
                     self.checkpoint()
-            except Exception:
-                logger.exception("Trial Runner checkpointing failed.")
+            except Exception as e:
+                logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
             self._iteration += 1

         if self._server:
@@ -74,6 +74,7 @@ def run(run_or_experiment,
         local_dir=None,
         upload_dir=None,
         trial_name_creator=None,
+        trial_dirname_creator=None,
         loggers=None,
         log_to_file=False,
         sync_to_cloud=None,
@@ -166,8 +167,12 @@ def run(run_or_experiment,
             Defaults to ``~/ray_results``.
         upload_dir (str): Optional URI to sync training results and checkpoints
             to (e.g. ``s3://bucket`` or ``gs://bucket``).
-        trial_name_creator (func): Optional function for generating
-            the trial string representation.
+        trial_name_creator (Callable[[Trial], str]): Optional function
+            for generating the trial string representation.
+        trial_dirname_creator (Callable[[Trial], str]): Function
+            for generating the trial dirname. This function should take
+            in a Trial object and return a string representing the
+            name of the directory. The return value cannot be a path.
         loggers (list): List of logger creators to be used with
             each Trial. If None, defaults to ray.tune.logger.DEFAULT_LOGGERS.
             See `ray/tune/logger.py`.
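One way to satisfy this docstring is a creator that folds hyperparameters into the name, e.g. (sketch; helper name and config key are illustrative):

def dirname_with_params(trial):
    # Trial -> str; the result must be a single directory name, not a path.
    lr = trial.config.get("lr", "na")
    return f"{trial.trainable_name}_lr={lr}_{trial.trial_id}"

# tune.run(my_trainable,
#          config={"lr": tune.grid_search([0.01, 0.1])},
#          trial_dirname_creator=dirname_with_params)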
@@ -295,6 +300,7 @@ def run(run_or_experiment,
             upload_dir=upload_dir,
             sync_to_driver=sync_to_driver,
             trial_name_creator=trial_name_creator,
+            trial_dirname_creator=trial_dirname_creator,
             loggers=loggers,
             log_to_file=log_to_file,
             checkpoint_freq=checkpoint_freq,
@@ -375,8 +381,8 @@ def run(run_or_experiment,

     try:
         runner.checkpoint(force=True)
-    except Exception:
-        logger.exception("Trial Runner checkpointing failed.")
+    except Exception as e:
+        logger.warning(f"Trial Runner checkpointing failed: {str(e)}")

     if verbose:
         _report_progress(runner, progress_reporter, done=True)