mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Better handling of tune.function in global checkpoint (#4519)
Enables result keys to be queried by CLI.
This commit is contained in:
parent
fb88f7efe6
commit
50b2aa0740
3 changed files with 40 additions and 6 deletions
|
@ -58,7 +58,7 @@ class sample_from(object):
|
|||
"""Specify that tune should sample configuration values from this function.
|
||||
|
||||
The use of function arguments in tune configs must be disambiguated by
|
||||
either wrapped the function in tune.eval() or tune.function().
|
||||
either wrapping the function in tune.sample_from() or tune.function().
|
||||
|
||||
Arguments:
|
||||
func: A callable function to draw a sample from.
|
||||
|
@ -67,12 +67,18 @@ class sample_from(object):
|
|||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __str__(self):
|
||||
return "tune.sample_from({})".format(str(self.func))
|
||||
|
||||
def __repr__(self):
|
||||
return "tune.sample_from({})".format(repr(self.func))
|
||||
|
||||
|
||||
class function(object):
|
||||
"""Wraps `func` to make sure it is not expanded during resolution.
|
||||
|
||||
The use of function arguments in tune configs must be disambiguated by
|
||||
either wrapped the function in tune.eval() or tune.function().
|
||||
either wrapping the function in tune.sample_from() or tune.function().
|
||||
|
||||
Arguments:
|
||||
func: A function literal.
|
||||
|
@ -84,6 +90,12 @@ class function(object):
|
|||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return "tune.function({})".format(str(self.func))
|
||||
|
||||
def __repr__(self):
|
||||
return "tune.function({})".format(repr(self.func))
|
||||
|
||||
|
||||
_STANDARD_IMPORTS = {
|
||||
"random": random,
|
||||
|
|
|
@ -310,10 +310,8 @@ class Trial(object):
|
|||
|
||||
self._nonjson_fields = [
|
||||
"_checkpoint",
|
||||
"config",
|
||||
"loggers",
|
||||
"sync_function",
|
||||
"last_result",
|
||||
"results",
|
||||
"best_result",
|
||||
"param_config",
|
||||
|
|
|
@ -11,12 +11,15 @@ import re
|
|||
import time
|
||||
import traceback
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.suggest import function
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.util import warn_if_slow
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
from ray.tune.web_server import TuneServer
|
||||
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
|
@ -39,6 +42,27 @@ def _find_newest_ckpt(ckpt_dir):
|
|||
return max(full_paths)
|
||||
|
||||
|
||||
class _TuneFunctionEncoder(json.JSONEncoder):
    """JSON encoder that knows how to serialize ``function`` wrappers.

    Instances of ``function`` are emitted as a tagged JSON object holding
    the hex-encoded cloudpickle dump of the wrapper; every other object is
    delegated to the stock ``json.JSONEncoder`` behavior.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Args:
            obj: An object the base encoder could not serialize natively.

        Returns:
            A dict with keys ``"_type"`` and ``"value"`` when ``obj`` is a
            ``function`` instance; otherwise whatever the base class
            ``default`` returns (which raises ``TypeError``).
        """
        # Anything that is not a wrapped function is not ours to handle.
        if not isinstance(obj, function):
            return super(_TuneFunctionEncoder, self).default(obj)

        pickled = cloudpickle.dumps(obj)
        return {"_type": "function", "value": binary_to_hex(pickled)}
|
||||
|
||||
|
||||
class _TuneFunctionDecoder(json.JSONDecoder):
    """JSON decoder that restores objects tagged ``"_type": "function"``.

    Every decoded JSON object (dict) is passed through ``object_hook``.
    Dicts tagged with ``"_type": "function"`` that also carry a ``"value"``
    payload are replaced by the result of unpickling that hex-encoded
    payload; all other dicts are returned untouched.
    """

    def __init__(self, *args, **kwargs):
        """Create the decoder with ``self.object_hook`` installed as hook.

        Args:
            *args: Forwarded to ``json.JSONDecoder.__init__``.
            **kwargs: Forwarded to ``json.JSONDecoder.__init__``.
        """
        json.JSONDecoder.__init__(
            self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, obj):
        """Convert a tagged dict back into the pickled object it encodes.

        Args:
            obj (dict): A JSON object that has just been decoded.

        Returns:
            The unpickled payload when ``obj`` is tagged as a function and
            has a ``"value"`` entry; otherwise ``obj`` itself.
        """
        # Require the payload key as well as the tag: a plain dict that
        # merely happens to contain {"_type": "function"} must pass through
        # unchanged instead of aborting the whole load with a KeyError.
        if obj.get("_type") == "function" and "value" in obj:
            # SECURITY NOTE: unpickling executes arbitrary code embedded in
            # the payload. Only load data from a trusted source.
            return cloudpickle.loads(hex_to_binary(obj["value"]))
        return obj
|
||||
|
||||
|
||||
class TrialRunner(object):
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
|
@ -150,7 +174,7 @@ class TrialRunner(object):
|
|||
tmp_file_name = os.path.join(metadata_checkpoint_dir,
|
||||
".tmp_checkpoint")
|
||||
with open(tmp_file_name, "w") as f:
|
||||
json.dump(runner_state, f, indent=2)
|
||||
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
|
||||
|
||||
os.rename(
|
||||
tmp_file_name,
|
||||
|
@ -183,7 +207,7 @@ class TrialRunner(object):
|
|||
|
||||
newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
|
||||
with open(newest_ckpt_path, "r") as f:
|
||||
runner_state = json.load(f)
|
||||
runner_state = json.load(f, cls=_TuneFunctionDecoder)
|
||||
|
||||
logger.warning("".join([
|
||||
"Attempting to resume experiment from {}. ".format(
|
||||
|
|
Loading…
Add table
Reference in a new issue