diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index 5b8e2f293..72b867c8f 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -58,7 +58,7 @@ class sample_from(object): """Specify that tune should sample configuration values from this function. The use of function arguments in tune configs must be disambiguated by - either wrapped the function in tune.eval() or tune.function(). + either wrapped the function in tune.sample_from() or tune.function(). Arguments: func: An callable function to draw a sample from. @@ -67,12 +67,18 @@ class sample_from(object): def __init__(self, func): self.func = func + def __str__(self): + return "tune.sample_from({})".format(str(self.func)) + + def __repr__(self): + return "tune.sample_from({})".format(repr(self.func)) + class function(object): """Wraps `func` to make sure it is not expanded during resolution. The use of function arguments in tune configs must be disambiguated by - either wrapped the function in tune.eval() or tune.function(). + either wrapped the function in tune.sample_from() or tune.function(). Arguments: func: A function literal. @@ -84,6 +90,12 @@ class function(object): def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) + def __str__(self): + return "tune.function({})".format(str(self.func)) + + def __repr__(self): + return "tune.function({})".format(repr(self.func)) + _STANDARD_IMPORTS = { "random": random, diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index fb324fcb7..bb2c3ae50 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -310,10 +310,8 @@ class Trial(object): self._nonjson_fields = [ "_checkpoint", - "config", "loggers", "sync_function", - "last_result", "results", "best_result", "param_config", diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 6be2b7375..4348f7e0d 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -11,12 +11,15 @@ import re import time import traceback +import ray.cloudpickle as cloudpickle from ray.tune import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE from ray.tune.trial import Trial, Checkpoint +from ray.tune.suggest import function from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.util import warn_if_slow +from ray.utils import binary_to_hex, hex_to_binary from ray.tune.web_server import TuneServer MAX_DEBUG_TRIALS = 20 @@ -39,6 +42,27 @@ def _find_newest_ckpt(ckpt_dir): return max(full_paths) +class _TuneFunctionEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, function): + return { + "_type": "function", + "value": binary_to_hex(cloudpickle.dumps(obj)) + } + return super(_TuneFunctionEncoder, self).default(obj) + + +class _TuneFunctionDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__( + self, object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, obj): + if obj.get("_type") == "function": + return cloudpickle.loads(hex_to_binary(obj["value"])) + return obj + + class TrialRunner(object): """A TrialRunner implements the event loop for scheduling trials on Ray. @@ -150,7 +174,7 @@ class TrialRunner(object): tmp_file_name = os.path.join(metadata_checkpoint_dir, ".tmp_checkpoint") with open(tmp_file_name, "w") as f: - json.dump(runner_state, f, indent=2) + json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder) os.rename( tmp_file_name, @@ -183,7 +207,7 @@ class TrialRunner(object): newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir) with open(newest_ckpt_path, "r") as f: - runner_state = json.load(f) + runner_state = json.load(f, cls=_TuneFunctionDecoder) logger.warning("".join([ "Attempting to resume experiment from {}. ".format(