mirror of
https://github.com/vale981/ray
synced 2025-04-23 06:25:52 -04:00
[tune] Fix Trial Logging File name (#1466)
This commit is contained in:
parent
f3d2dc0ad4
commit
e5c4d9ea0c
3 changed files with 49 additions and 3 deletions
|
@ -17,6 +17,7 @@ from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
|
|||
from ray.utils import random_string, binary_to_hex
|
||||
|
||||
DEBUG_PRINT_INTERVAL = 5
|
||||
MAX_LEN_IDENTIFIER = 130
|
||||
|
||||
|
||||
class Resources(
|
||||
|
@ -337,6 +338,11 @@ class Trial(object):
|
|||
logger_creator=logger_creator)
|
||||
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``.
|
||||
|
||||
Truncates to MAX_LEN_IDENTIFIER (default is 130) to avoid problems
|
||||
when creating logging directories.
|
||||
"""
|
||||
if "env" in self.config:
|
||||
identifier = "{}_{}".format(
|
||||
self.trainable_name, self.config["env"])
|
||||
|
@ -344,4 +350,4 @@ class Trial(object):
|
|||
identifier = self.trainable_name
|
||||
if self.experiment_tag:
|
||||
identifier += "_" + self.experiment_tag
|
||||
return identifier
|
||||
return identifier[:MAX_LEN_IDENTIFIER]
|
||||
|
|
|
@ -128,10 +128,17 @@ def _format_vars(resolved_vars):
|
|||
last_string = False
|
||||
pieces.append(k)
|
||||
pieces.reverse()
|
||||
out.append("_".join(pieces) + "=" + str(value))
|
||||
out.append("_".join(pieces) + "=" + _clean_value(value))
|
||||
return ",".join(out)
|
||||
|
||||
|
||||
def _clean_value(value):
|
||||
if isinstance(value, float):
|
||||
return "{:.5}".format(value)
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
|
||||
def _generate_variants(spec):
|
||||
spec = copy.deepcopy(spec)
|
||||
unresolved = _unresolved_values(spec)
|
||||
|
|
|
@ -13,7 +13,7 @@ from ray.tune import Trainable, TuneError
|
|||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial import Trial, Resources, MAX_LEN_IDENTIFIER
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.variant_generator import generate_trials, grid_search, \
|
||||
RecursiveDependencyError
|
||||
|
@ -78,6 +78,19 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
"config": {"a": "b"},
|
||||
}})
|
||||
|
||||
def testLongFilename(self):
|
||||
def train(config, reporter):
|
||||
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
|
||||
reporter(timesteps_total=1)
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40},
|
||||
}})
|
||||
|
||||
def testBadParams(self):
|
||||
def f():
|
||||
run_experiments({"foo": {}})
|
||||
|
@ -334,6 +347,26 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
trial.stop(error=True)
|
||||
self.assertEqual(trial.status, Trial.ERROR)
|
||||
|
||||
def testExperimentTagTruncation(self):
|
||||
ray.init()
|
||||
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
experiments = {"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40},
|
||||
}}
|
||||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
self.assertLessEqual(
|
||||
len(str(trial)), MAX_LEN_IDENTIFIER)
|
||||
|
||||
def testTrialErrorOnStart(self):
|
||||
ray.init()
|
||||
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
|
|
Loading…
Add table
Reference in a new issue