mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[tune] Prevent errors with retained trainables in global registry (#19184)
This PR fixes #19183 by introducing three improvements:
- String trainables are prefixed with Durable, e.g. DurablePPO.
- Durable trainables cannot be wrapped twice with tune.durable().
- MRO resolution in _WrappedDurableTrainables indicates we already have a DurableTrainable - thus we catch this with a try/except block.
This commit is contained in:
parent
ca731d7c86
commit
8d89e2d546
4 changed files with 41 additions and 14 deletions
|
@ -171,14 +171,16 @@ def durable(trainable: Union[str, Type[Trainable], Callable]):
|
|||
A durable trainable class wrapped around your trainable.
|
||||
|
||||
"""
|
||||
overwrite_name = None
|
||||
if isinstance(trainable, str):
|
||||
trainable_cls = get_trainable_cls(trainable)
|
||||
overwrite_name = f"Durable{trainable}"
|
||||
else:
|
||||
trainable_cls = trainable
|
||||
|
||||
if not inspect.isclass(trainable_cls):
|
||||
# Function API
|
||||
return wrap_function(trainable_cls, durable=True)
|
||||
return wrap_function(trainable_cls, durable=True, name=overwrite_name)
|
||||
|
||||
if not issubclass(trainable_cls, Trainable):
|
||||
raise ValueError(
|
||||
|
@ -187,8 +189,14 @@ def durable(trainable: Union[str, Type[Trainable], Callable]):
|
|||
f"it does. Got: {type(trainable_cls)}")
|
||||
|
||||
# else: Class API
|
||||
|
||||
# Class is already durable
|
||||
|
||||
if issubclass(trainable_cls, DurableTrainable):
|
||||
return trainable_cls
|
||||
|
||||
class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
|
||||
_name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
|
||||
else "durable_trainable"
|
||||
_name = (trainable_cls.__name__ if hasattr(trainable_cls, "__name__")
|
||||
else "durable_trainable")
|
||||
|
||||
return _WrappedDurableTrainable
|
||||
|
|
|
@ -10,6 +10,8 @@ import uuid
|
|||
from functools import partial
|
||||
from numbers import Number
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
from ray.util.debug import log_once
|
||||
|
@ -530,7 +532,10 @@ class FunctionRunner(Trainable):
|
|||
pass
|
||||
|
||||
|
||||
def wrap_function(train_func, durable=False, warn=True):
|
||||
def wrap_function(train_func: Callable[[Any], Any],
|
||||
durable: bool = False,
|
||||
warn: bool = True,
|
||||
name: Optional[str] = None):
|
||||
inherit_from = (FunctionRunner, )
|
||||
|
||||
if hasattr(train_func, "__mixins__"):
|
||||
|
@ -562,8 +567,8 @@ def wrap_function(train_func, durable=False, warn=True):
|
|||
"arguments to be `func(config, checkpoint_dir=None)`.")
|
||||
|
||||
class ImplicitFunc(*inherit_from):
|
||||
_name = train_func.__name__ if hasattr(train_func, "__name__") \
|
||||
else "func"
|
||||
_name = name or (train_func.__name__
|
||||
if hasattr(train_func, "__name__") else "func")
|
||||
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
if not use_checkpoint and not use_reporter:
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
from types import FunctionType
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as pickle
|
||||
|
@ -114,23 +117,25 @@ def check_serializability(key, value):
|
|||
_global_registry.register(TEST, key, value)
|
||||
|
||||
|
||||
def _make_key(category, key):
|
||||
def _make_key(prefix, category, key):
|
||||
"""Generate a binary key for the given category and key.
|
||||
|
||||
Args:
|
||||
prefix (str): Prefix
|
||||
category (str): The category of the item
|
||||
key (str): The unique identifier for the item
|
||||
|
||||
Returns:
|
||||
The key to use for storing a the value.
|
||||
"""
|
||||
return (b"TuneRegistry:" + category.encode("ascii") + b"/" +
|
||||
key.encode("ascii"))
|
||||
return (b"TuneRegistry:" + prefix.encode("ascii") + b":" +
|
||||
category.encode("ascii") + b"/" + key.encode("ascii"))
|
||||
|
||||
|
||||
class _Registry:
|
||||
def __init__(self):
|
||||
def __init__(self, prefix: Optional[str] = None):
|
||||
self._to_flush = {}
|
||||
self._prefix = prefix or uuid.uuid4().hex[:8]
|
||||
|
||||
def register(self, category, key, value):
|
||||
"""Registers the value with the global registry.
|
||||
|
@ -148,14 +153,14 @@ class _Registry:
|
|||
|
||||
def contains(self, category, key):
|
||||
if _internal_kv_initialized():
|
||||
value = _internal_kv_get(_make_key(category, key))
|
||||
value = _internal_kv_get(_make_key(self._prefix, category, key))
|
||||
return value is not None
|
||||
else:
|
||||
return (category, key) in self._to_flush
|
||||
|
||||
def get(self, category, key):
|
||||
if _internal_kv_initialized():
|
||||
value = _internal_kv_get(_make_key(category, key))
|
||||
value = _internal_kv_get(_make_key(self._prefix, category, key))
|
||||
if value is None:
|
||||
raise ValueError(
|
||||
"Registry value for {}/{} doesn't exist.".format(
|
||||
|
@ -166,11 +171,12 @@ class _Registry:
|
|||
|
||||
def flush_values(self):
|
||||
for (category, key), value in self._to_flush.items():
|
||||
_internal_kv_put(_make_key(category, key), value, overwrite=True)
|
||||
_internal_kv_put(
|
||||
_make_key(self._prefix, category, key), value, overwrite=True)
|
||||
self._to_flush.clear()
|
||||
|
||||
|
||||
_global_registry = _Registry()
|
||||
_global_registry = _Registry(prefix="global")
|
||||
ray.worker._post_init_hooks.append(_global_registry.flush_values)
|
||||
|
||||
|
||||
|
|
|
@ -226,6 +226,14 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
|
||||
self.assertRaises(TypeError, lambda: Experiment("foo", A))
|
||||
|
||||
def testRegisterDurableTrainableTwice(self):
|
||||
def train(config, reporter):
|
||||
pass
|
||||
|
||||
register_trainable("foo", train)
|
||||
register_trainable("foo", tune.durable("foo"))
|
||||
register_trainable("foo", tune.durable("foo"))
|
||||
|
||||
def testTrainableCallable(self):
|
||||
def dummy_fn(config, reporter, steps):
|
||||
reporter(timesteps_total=steps, done=True)
|
||||
|
|
Loading…
Add table
Reference in a new issue