[tune] Fix Trial Logging File name (#1466)

This commit is contained in:
Richard Liaw 2018-01-25 17:57:40 -08:00 committed by GitHub
parent f3d2dc0ad4
commit e5c4d9ea0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 3 deletions

View file

@ -17,6 +17,7 @@ from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print
from ray.utils import random_string, binary_to_hex
DEBUG_PRINT_INTERVAL = 5
MAX_LEN_IDENTIFIER = 130
class Resources(
@ -337,6 +338,11 @@ class Trial(object):
logger_creator=logger_creator)
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``.
Truncates to MAX_LEN_IDENTIFIER (default is 130) to avoid problems
when creating logging directories.
"""
if "env" in self.config:
identifier = "{}_{}".format(
self.trainable_name, self.config["env"])
@ -344,4 +350,4 @@ class Trial(object):
identifier = self.trainable_name
if self.experiment_tag:
identifier += "_" + self.experiment_tag
return identifier
return identifier[:MAX_LEN_IDENTIFIER]

View file

@ -128,10 +128,17 @@ def _format_vars(resolved_vars):
last_string = False
pieces.append(k)
pieces.reverse()
out.append("_".join(pieces) + "=" + str(value))
out.append("_".join(pieces) + "=" + _clean_value(value))
return ",".join(out)
def _clean_value(value):
if isinstance(value, float):
return "{:.5}".format(value)
else:
return str(value)
def _generate_variants(spec):
spec = copy.deepcopy(spec)
unresolved = _unresolved_values(spec)

View file

@ -13,7 +13,7 @@ from ray.tune import Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import Trial, Resources
from ray.tune.trial import Trial, Resources, MAX_LEN_IDENTIFIER
from ray.tune.trial_runner import TrialRunner
from ray.tune.variant_generator import generate_trials, grid_search, \
RecursiveDependencyError
@ -78,6 +78,19 @@ class TrainableFunctionApiTest(unittest.TestCase):
"config": {"a": "b"},
}})
def testLongFilename(self):
    """Run an experiment whose config has long keys and a long value.

    The trainable asserts that its working directory contains
    ``/tmp/logdir/foo`` before reporting a single timestep.
    """
    def train(config, reporter):
        cwd = os.getcwd()
        assert "/tmp/logdir/foo" in cwd, cwd
        reporter(timesteps_total=1)

    register_trainable("f1", train)

    # 50-character keys, a float value and a 160-character string value.
    long_config = {
        "a" * 50: lambda spec: 5.0 / 7,
        "b" * 50: lambda spec: "long" * 40,
    }
    experiment = {
        "run": "f1",
        "local_dir": "/tmp/logdir",
        "config": long_config,
    }
    run_experiments({"foo": experiment})
def testBadParams(self):
def f():
run_experiments({"foo": {}})
@ -334,6 +347,26 @@ class TrialRunnerTest(unittest.TestCase):
trial.stop(error=True)
self.assertEqual(trial.status, Trial.ERROR)
def testExperimentTagTruncation(self):
    """Check that ``str(trial)`` never exceeds MAX_LEN_IDENTIFIER.

    Trials are generated from a config with 50-character keys and a
    160-character string value.
    """
    ray.init()

    def train(config, reporter):
        reporter(timesteps_total=1)

    register_trainable("f1", train)

    long_config = {
        "a" * 50: lambda spec: 5.0 / 7,
        "b" * 50: lambda spec: "long" * 40,
    }
    experiments = {"foo": {"run": "f1", "config": long_config}}

    # Lazily walk every trial of every experiment, in the original order.
    all_trials = (
        trial
        for name, spec in experiments.items()
        for trial in generate_trials(spec, name))
    for trial in all_trials:
        self.assertLessEqual(len(str(trial)), MAX_LEN_IDENTIFIER)
def testTrialErrorOnStart(self):
ray.init()
_default_registry.register(TRAINABLE_CLASS, "asdf", None)