[tune] Prevent errors with retained trainables in global registry (#19184)

This PR fixes #19183 by introducing three improvements:

String trainables are prefixed with Durable, e.g. DurablePPO
Durable trainables cannot be wrapped twice with tune.durable()
MRO resolution in _WrappedDurableTrainables indicates we already have a DurableTrainable - thus we catch this with a try/except block
This commit is contained in:
Kai Fricke 2021-10-08 01:17:01 +01:00 committed by GitHub
parent ca731d7c86
commit 8d89e2d546
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 14 deletions

View file

@ -171,14 +171,16 @@ def durable(trainable: Union[str, Type[Trainable], Callable]):
A durable trainable class wrapped around your trainable.
"""
overwrite_name = None
if isinstance(trainable, str):
trainable_cls = get_trainable_cls(trainable)
overwrite_name = f"Durable{trainable}"
else:
trainable_cls = trainable
if not inspect.isclass(trainable_cls):
# Function API
return wrap_function(trainable_cls, durable=True)
return wrap_function(trainable_cls, durable=True, name=overwrite_name)
if not issubclass(trainable_cls, Trainable):
raise ValueError(
@ -187,8 +189,14 @@ def durable(trainable: Union[str, Type[Trainable], Callable]):
f"it does. Got: {type(trainable_cls)}")
# else: Class API
# Class is already durable
if issubclass(trainable_cls, DurableTrainable):
return trainable_cls
class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
_name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
else "durable_trainable"
_name = (trainable_cls.__name__ if hasattr(trainable_cls, "__name__")
else "durable_trainable")
return _WrappedDurableTrainable

View file

@ -10,6 +10,8 @@ import uuid
from functools import partial
from numbers import Number
from typing import Any, Callable, Optional
from six.moves import queue
from ray.util.debug import log_once
@ -530,7 +532,10 @@ class FunctionRunner(Trainable):
pass
def wrap_function(train_func, durable=False, warn=True):
def wrap_function(train_func: Callable[[Any], Any],
durable: bool = False,
warn: bool = True,
name: Optional[str] = None):
inherit_from = (FunctionRunner, )
if hasattr(train_func, "__mixins__"):
@ -562,8 +567,8 @@ def wrap_function(train_func, durable=False, warn=True):
"arguments to be `func(config, checkpoint_dir=None)`.")
class ImplicitFunc(*inherit_from):
_name = train_func.__name__ if hasattr(train_func, "__name__") \
else "func"
_name = name or (train_func.__name__
if hasattr(train_func, "__name__") else "func")
def _trainable_func(self, config, reporter, checkpoint_dir):
if not use_checkpoint and not use_reporter:

View file

@ -1,5 +1,8 @@
import logging
import uuid
from types import FunctionType
from typing import Optional
import ray
import ray.cloudpickle as pickle
@ -114,23 +117,25 @@ def check_serializability(key, value):
_global_registry.register(TEST, key, value)
def _make_key(category, key):
def _make_key(prefix, category, key):
"""Generate a binary key for the given category and key.
Args:
prefix (str): Prefix
category (str): The category of the item
key (str): The unique identifier for the item
Returns:
The key to use for storing a the value.
"""
return (b"TuneRegistry:" + category.encode("ascii") + b"/" +
key.encode("ascii"))
return (b"TuneRegistry:" + prefix.encode("ascii") + b":" +
category.encode("ascii") + b"/" + key.encode("ascii"))
class _Registry:
def __init__(self):
def __init__(self, prefix: Optional[str] = None):
self._to_flush = {}
self._prefix = prefix or uuid.uuid4().hex[:8]
def register(self, category, key, value):
"""Registers the value with the global registry.
@ -148,14 +153,14 @@ class _Registry:
def contains(self, category, key):
if _internal_kv_initialized():
value = _internal_kv_get(_make_key(category, key))
value = _internal_kv_get(_make_key(self._prefix, category, key))
return value is not None
else:
return (category, key) in self._to_flush
def get(self, category, key):
if _internal_kv_initialized():
value = _internal_kv_get(_make_key(category, key))
value = _internal_kv_get(_make_key(self._prefix, category, key))
if value is None:
raise ValueError(
"Registry value for {}/{} doesn't exist.".format(
@ -166,11 +171,12 @@ class _Registry:
def flush_values(self):
for (category, key), value in self._to_flush.items():
_internal_kv_put(_make_key(category, key), value, overwrite=True)
_internal_kv_put(
_make_key(self._prefix, category, key), value, overwrite=True)
self._to_flush.clear()
_global_registry = _Registry()
_global_registry = _Registry(prefix="global")
ray.worker._post_init_hooks.append(_global_registry.flush_values)

View file

@ -226,6 +226,14 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
self.assertRaises(TypeError, lambda: Experiment("foo", A))
def testRegisterDurableTrainableTwice(self):
def train(config, reporter):
pass
register_trainable("foo", train)
register_trainable("foo", tune.durable("foo"))
register_trainable("foo", tune.durable("foo"))
def testTrainableCallable(self):
def dummy_fn(config, reporter, steps):
reporter(timesteps_total=steps, done=True)