[tune] Better handling of tune.function in global checkpoint (#4519)

Enables result keys to be queried by CLI.
This commit is contained in:
Richard Liaw 2019-04-04 21:08:47 -07:00 committed by GitHub
parent fb88f7efe6
commit 50b2aa0740
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 6 deletions

View file

@ -58,7 +58,7 @@ class sample_from(object):
"""Specify that tune should sample configuration values from this function.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.eval() or tune.function().
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: An callable function to draw a sample from.
@ -67,12 +67,18 @@ class sample_from(object):
def __init__(self, func):
self.func = func
def __str__(self):
return "tune.sample_from({})".format(str(self.func))
def __repr__(self):
return "tune.sample_from({})".format(repr(self.func))
class function(object):
"""Wraps `func` to make sure it is not expanded during resolution.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.eval() or tune.function().
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: A function literal.
@ -84,6 +90,12 @@ class function(object):
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __str__(self):
return "tune.function({})".format(str(self.func))
def __repr__(self):
return "tune.function({})".format(repr(self.func))
_STANDARD_IMPORTS = {
"random": random,

View file

@ -310,10 +310,8 @@ class Trial(object):
self._nonjson_fields = [
"_checkpoint",
"config",
"loggers",
"sync_function",
"last_result",
"results",
"best_result",
"param_config",

View file

@ -11,12 +11,15 @@ import re
import time
import traceback
import ray.cloudpickle as cloudpickle
from ray.tune import TuneError
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
from ray.tune.trial import Trial, Checkpoint
from ray.tune.suggest import function
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.util import warn_if_slow
from ray.utils import binary_to_hex, hex_to_binary
from ray.tune.web_server import TuneServer
MAX_DEBUG_TRIALS = 20
@ -39,6 +42,27 @@ def _find_newest_ckpt(ckpt_dir):
return max(full_paths)
class _TuneFunctionEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, function):
return {
"_type": "function",
"value": binary_to_hex(cloudpickle.dumps(obj))
}
return super(_TuneFunctionEncoder, self).default(obj)
class _TuneFunctionDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(
self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, obj):
if obj.get("_type") == "function":
return cloudpickle.loads(hex_to_binary(obj["value"]))
return obj
class TrialRunner(object):
"""A TrialRunner implements the event loop for scheduling trials on Ray.
@ -150,7 +174,7 @@ class TrialRunner(object):
tmp_file_name = os.path.join(metadata_checkpoint_dir,
".tmp_checkpoint")
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2)
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
os.rename(
tmp_file_name,
@ -183,7 +207,7 @@ class TrialRunner(object):
newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f)
runner_state = json.load(f, cls=_TuneFunctionDecoder)
logger.warning("".join([
"Attempting to resume experiment from {}. ".format(