[tune] fix conditional identifier (#5971)

* fix conditional identifier

* fix

* doc
This commit is contained in:
Richard Liaw 2019-10-22 02:00:49 -07:00 committed by GitHub
parent 832b5ce1f6
commit 81dd0dfb0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 4 deletions

View file

@ -102,9 +102,9 @@ class Experiment(object):
"criteria must take exactly 2 parameters.".format(stop))
config = config or {}
run_identifier = Experiment._register_if_needed(run)
self._run_identifier = Experiment._register_if_needed(run)
spec = {
"run": run_identifier,
"run": self._run_identifier,
"stop": stop,
"config": config,
"resources_per_trial": resources_per_trial,
@ -125,7 +125,7 @@ class Experiment(object):
if restore else None
}
self.name = name or run_identifier
self.name = name or self._run_identifier
self.spec = spec
@classmethod
@ -202,6 +202,11 @@ class Experiment(object):
if self.spec["upload_dir"]:
return os.path.join(self.spec["upload_dir"], self.name)
@property
def run_identifier(self):
"""Returns a string representing the trainable identifier."""
return self._run_identifier
def convert_to_experiment_list(experiments):
"""Produces a list of Experiment objects.

View file

@ -4,6 +4,7 @@ from __future__ import print_function
import logging
import time
import six
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
@ -45,6 +46,9 @@ def _make_scheduler(args):
def _check_default_resources_override(run_identifier):
if not isinstance(run_identifier, six.string_types):
# If obscure dtype, assume it is overriden.
return True
trainable_cls = get_trainable_cls(run_identifier)
return hasattr(trainable_cls, "default_resource_request") and (
trainable_cls.default_resource_request.__code__ !=
@ -265,7 +269,7 @@ def run(run_or_experiment,
dict) and "gpu" in resources_per_trial:
# "gpu" is manually set.
pass
elif _check_default_resources_override(run_identifier):
elif _check_default_resources_override(experiment.run_identifier):
# "default_resources" is manually overriden.
pass
else: