diff --git a/python/ray/tune/durable_trainable.py b/python/ray/tune/durable_trainable.py
index db822434f..5e4f3f080 100644
--- a/python/ray/tune/durable_trainable.py
+++ b/python/ray/tune/durable_trainable.py
@@ -171,14 +171,16 @@ def durable(trainable: Union[str, Type[Trainable], Callable]):
         A durable trainable class wrapped around your trainable.
 
     """
+    overwrite_name = None
     if isinstance(trainable, str):
         trainable_cls = get_trainable_cls(trainable)
+        overwrite_name = f"Durable{trainable}"
     else:
         trainable_cls = trainable
 
     if not inspect.isclass(trainable_cls):
         # Function API
-        return wrap_function(trainable_cls, durable=True)
+        return wrap_function(trainable_cls, durable=True, name=overwrite_name)
 
     if not issubclass(trainable_cls, Trainable):
         raise ValueError(
@@ -187,8 +189,14 @@ def durable(trainable: Union[str, Type[Trainable], Callable]):
             f"it does. Got: {type(trainable_cls)}")
 
     # else: Class API
+
+    # Class is already durable
+
+    if issubclass(trainable_cls, DurableTrainable):
+        return trainable_cls
+
     class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
-        _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
-            else "durable_trainable"
+        _name = (trainable_cls.__name__ if hasattr(trainable_cls, "__name__")
+                 else "durable_trainable")
 
     return _WrappedDurableTrainable
diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py
index ae4235aa8..e4c201806 100644
--- a/python/ray/tune/function_runner.py
+++ b/python/ray/tune/function_runner.py
@@ -10,6 +10,8 @@ import uuid
 from functools import partial
 from numbers import Number
 
+from typing import Any, Callable, Optional
+
 from six.moves import queue
 
 from ray.util.debug import log_once
@@ -530,7 +532,10 @@ class FunctionRunner(Trainable):
         pass
 
 
-def wrap_function(train_func, durable=False, warn=True):
+def wrap_function(train_func: Callable[[Any], Any],
+                  durable: bool = False,
+                  warn: bool = True,
+                  name: Optional[str] = None):
     inherit_from = (FunctionRunner, )
 
     if hasattr(train_func, "__mixins__"):
@@ -562,8 +567,8 @@ def wrap_function(train_func, durable=False, warn=True):
                 "arguments to be `func(config, checkpoint_dir=None)`.")
 
     class ImplicitFunc(*inherit_from):
-        _name = train_func.__name__ if hasattr(train_func, "__name__") \
-            else "func"
+        _name = name or (train_func.__name__
+                         if hasattr(train_func, "__name__") else "func")
 
         def _trainable_func(self, config, reporter, checkpoint_dir):
             if not use_checkpoint and not use_reporter:
diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py
index d83c72717..9f143db42 100644
--- a/python/ray/tune/registry.py
+++ b/python/ray/tune/registry.py
@@ -1,5 +1,8 @@
 import logging
+import uuid
+
 from types import FunctionType
+from typing import Optional
 
 import ray
 import ray.cloudpickle as pickle
@@ -114,23 +117,25 @@ def check_serializability(key, value):
     _global_registry.register(TEST, key, value)
 
 
-def _make_key(category, key):
+def _make_key(prefix, category, key):
     """Generate a binary key for the given category and key.
 
     Args:
+        prefix (str): Prefix
         category (str): The category of the item
         key (str): The unique identifier for the item
 
     Returns:
         The key to use for storing a the value.
     """
-    return (b"TuneRegistry:" + category.encode("ascii") + b"/" +
-            key.encode("ascii"))
+    return (b"TuneRegistry:" + prefix.encode("ascii") + b":" +
+            category.encode("ascii") + b"/" + key.encode("ascii"))
 
 
 class _Registry:
-    def __init__(self):
+    def __init__(self, prefix: Optional[str] = None):
         self._to_flush = {}
+        self._prefix = prefix or uuid.uuid4().hex[:8]
 
     def register(self, category, key, value):
         """Registers the value with the global registry.
@@ -148,14 +153,14 @@ class _Registry:
 
     def contains(self, category, key):
         if _internal_kv_initialized():
-            value = _internal_kv_get(_make_key(category, key))
+            value = _internal_kv_get(_make_key(self._prefix, category, key))
             return value is not None
         else:
             return (category, key) in self._to_flush
 
     def get(self, category, key):
         if _internal_kv_initialized():
-            value = _internal_kv_get(_make_key(category, key))
+            value = _internal_kv_get(_make_key(self._prefix, category, key))
             if value is None:
                 raise ValueError(
                     "Registry value for {}/{} doesn't exist.".format(
@@ -166,11 +171,12 @@ class _Registry:
 
     def flush_values(self):
         for (category, key), value in self._to_flush.items():
-            _internal_kv_put(_make_key(category, key), value, overwrite=True)
+            _internal_kv_put(
+                _make_key(self._prefix, category, key), value, overwrite=True)
         self._to_flush.clear()
 
 
-_global_registry = _Registry()
+_global_registry = _Registry(prefix="global")
 ray.worker._post_init_hooks.append(_global_registry.flush_values)
 
 
diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py
index bb49de900..598b2a2dc 100644
--- a/python/ray/tune/tests/test_api.py
+++ b/python/ray/tune/tests/test_api.py
@@ -226,6 +226,14 @@ class TrainableFunctionApiTest(unittest.TestCase):
         self.assertRaises(TypeError, lambda: register_trainable("foo", A))
         self.assertRaises(TypeError, lambda: Experiment("foo", A))
 
+    def testRegisterDurableTrainableTwice(self):
+        def train(config, reporter):
+            pass
+
+        register_trainable("foo", train)
+        register_trainable("foo", tune.durable("foo"))
+        register_trainable("foo", tune.durable("foo"))
+
     def testTrainableCallable(self):
         def dummy_fn(config, reporter, steps):
             reporter(timesteps_total=steps, done=True)