mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Enable actor methods to be decorated on the caller side also and get postprocessors. (#4732)
* Allow decorating ray actor methods. * Add test. * Add get postprocessors. * Improve documentation. * Make it work for remote functions. * Temporary fix.
This commit is contained in:
parent
897b35ce36
commit
d81e71e297
5 changed files with 186 additions and 43 deletions
|
@ -109,10 +109,36 @@ def method(*args, **kwargs):
|
|||
# Create objects to wrap method invocations. This is done so that we can
|
||||
# invoke methods with actor.method.remote() instead of actor.method().
|
||||
class ActorMethod(object):
|
||||
def __init__(self, actor, method_name, num_return_vals):
|
||||
"""A class used to invoke an actor method.
|
||||
|
||||
Note: This class is instantiated only while the actor method is being
|
||||
invoked (so that it doesn't keep a reference to the actor handle and
|
||||
prevent it from going out of scope).
|
||||
|
||||
Attributes:
|
||||
_actor: A handle to the actor.
|
||||
_method_name: The name of the actor method.
|
||||
_num_return_vals: The default number of return values that the method
|
||||
invocation should return.
|
||||
_decorator: An optional decorator that should be applied to the actor
|
||||
method invocation (as opposed to the actor method execution) before
|
||||
invoking the method. The decorator must return a function that
|
||||
takes in two arguments ("args" and "kwargs"). In most cases, it
|
||||
should call the function that was passed into the decorator and
|
||||
return the resulting ObjectIDs. For an example, see
|
||||
"test_decorated_method" in "python/ray/tests/test_actor.py".
|
||||
"""
|
||||
|
||||
def __init__(self, actor, method_name, num_return_vals, decorator=None):
|
||||
self._actor = actor
|
||||
self._method_name = method_name
|
||||
self._num_return_vals = num_return_vals
|
||||
# This is a decorator that is used to wrap the function invocation (as
|
||||
# opposed to the function execution). The decorator must return a
|
||||
# function that takes in two arguments ("args" and "kwargs"). In most
|
||||
# cases, it should call the function that was passed into the decorator
|
||||
# and return the resulting ObjectIDs.
|
||||
self._decorator = decorator
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Actor methods cannot be called directly. Instead "
|
||||
|
@ -131,11 +157,18 @@ class ActorMethod(object):
|
|||
if num_return_vals is None:
|
||||
num_return_vals = self._num_return_vals
|
||||
|
||||
return self._actor._actor_method_call(
|
||||
self._method_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
num_return_vals=num_return_vals)
|
||||
def invocation(args, kwargs):
|
||||
return self._actor._actor_method_call(
|
||||
self._method_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
num_return_vals=num_return_vals)
|
||||
|
||||
# Apply the decorator if there is one.
|
||||
if self._decorator is not None:
|
||||
invocation = self._decorator(invocation)
|
||||
|
||||
return invocation(args, kwargs)
|
||||
|
||||
|
||||
class ActorClass(object):
|
||||
|
@ -157,6 +190,10 @@ class ActorClass(object):
|
|||
_exported: True if the actor class has been exported and false
|
||||
otherwise.
|
||||
_actor_methods: The actor methods.
|
||||
_method_decorators: Optional decorators that should be applied to the
|
||||
method invocation function before invoking the actor methods. These
|
||||
can be set by attaching the attribute
|
||||
"__ray_invocation_decorator__" to the actor method.
|
||||
_method_signatures: The signatures of the methods.
|
||||
_actor_method_names: The names of the actor methods.
|
||||
_actor_method_num_return_vals: The default number of return values for
|
||||
|
@ -196,6 +233,7 @@ class ActorClass(object):
|
|||
# Extract the signatures of each of the methods. This will be used
|
||||
# to catch some errors if the methods are called with inappropriate
|
||||
# arguments.
|
||||
self._method_decorators = {}
|
||||
self._method_signatures = {}
|
||||
self._actor_method_num_return_vals = {}
|
||||
for method_name, method in self._actor_methods:
|
||||
|
@ -214,6 +252,10 @@ class ActorClass(object):
|
|||
self._actor_method_num_return_vals[method_name] = (
|
||||
ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS)
|
||||
|
||||
if hasattr(method, "__ray_invocation_decorator__"):
|
||||
self._method_decorators[method_name] = (
|
||||
method.__ray_invocation_decorator__)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Actors methods cannot be instantiated directly. "
|
||||
"Instead of running '{}()', try '{}.remote()'.".format(
|
||||
|
@ -337,9 +379,9 @@ class ActorClass(object):
|
|||
|
||||
actor_handle = ActorHandle(
|
||||
actor_id, self._modified_class.__module__, self._class_name,
|
||||
actor_cursor, self._actor_method_names, self._method_signatures,
|
||||
self._actor_method_num_return_vals, actor_cursor, actor_method_cpu,
|
||||
worker.task_driver_id)
|
||||
actor_cursor, self._actor_method_names, self._method_decorators,
|
||||
self._method_signatures, self._actor_method_num_return_vals,
|
||||
actor_cursor, actor_method_cpu, worker.task_driver_id)
|
||||
# We increment the actor counter by 1 to account for the actor creation
|
||||
# task.
|
||||
actor_handle._ray_actor_counter += 1
|
||||
|
@ -381,6 +423,10 @@ class ActorHandle(object):
|
|||
_ray_actor_counter: The number of actor method invocations that we've
|
||||
called so far.
|
||||
_ray_actor_method_names: The names of the actor methods.
|
||||
_ray_method_decorators: Optional decorators for the function
|
||||
invocation. This can be used to change the behavior on the
|
||||
invocation side, whereas a regular decorator can be used to change
|
||||
the behavior on the execution side.
|
||||
_ray_method_signatures: The signatures of the actor methods.
|
||||
_ray_method_num_return_vals: The default number of return values for
|
||||
each method.
|
||||
|
@ -407,6 +453,7 @@ class ActorHandle(object):
|
|||
class_name,
|
||||
actor_cursor,
|
||||
actor_method_names,
|
||||
method_decorators,
|
||||
method_signatures,
|
||||
method_num_return_vals,
|
||||
actor_creation_dummy_object_id,
|
||||
|
@ -428,6 +475,7 @@ class ActorHandle(object):
|
|||
self._ray_actor_cursor = actor_cursor
|
||||
self._ray_actor_counter = 0
|
||||
self._ray_actor_method_names = actor_method_names
|
||||
self._ray_method_decorators = method_decorators
|
||||
self._ray_method_signatures = method_signatures
|
||||
self._ray_method_num_return_vals = method_num_return_vals
|
||||
self._ray_class_name = class_name
|
||||
|
@ -530,8 +578,11 @@ class ActorHandle(object):
|
|||
# this was causing cyclic references which were prevent
|
||||
# object deallocation from behaving in a predictable
|
||||
# manner.
|
||||
return ActorMethod(self, attr,
|
||||
self._ray_method_num_return_vals[attr])
|
||||
return ActorMethod(
|
||||
self,
|
||||
attr,
|
||||
self._ray_method_num_return_vals[attr],
|
||||
decorator=self._ray_method_decorators.get(attr))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
@ -600,6 +651,7 @@ class ActorHandle(object):
|
|||
"class_name": self._ray_class_name,
|
||||
"actor_cursor": self._ray_actor_cursor,
|
||||
"actor_method_names": self._ray_actor_method_names,
|
||||
"method_decorators": self._ray_method_decorators,
|
||||
"method_signatures": self._ray_method_signatures,
|
||||
"method_num_return_vals": self._ray_method_num_return_vals,
|
||||
# Actors in local mode don't have dummy objects.
|
||||
|
@ -662,6 +714,7 @@ class ActorHandle(object):
|
|||
state["class_name"],
|
||||
state["actor_cursor"],
|
||||
state["actor_method_names"],
|
||||
state["method_decorators"],
|
||||
state["method_signatures"],
|
||||
state["method_num_return_vals"],
|
||||
state["actor_creation_dummy_object_id"],
|
||||
|
|
|
@ -35,6 +35,13 @@ class RemoteFunction(object):
|
|||
of this remote function.
|
||||
_max_calls: The number of times a worker can execute this function
|
||||
before executing.
|
||||
_decorator: An optional decorator that should be applied to the remote
|
||||
function invocation (as opposed to the function execution) before
|
||||
invoking the function. The decorator must return a function that
|
||||
takes in two arguments ("args" and "kwargs"). In most cases, it
|
||||
should call the function that was passed into the decorator and
|
||||
return the resulting ObjectIDs. For an example, see
|
||||
"test_decorated_function" in "python/ray/tests/test_basic.py".
|
||||
_function_signature: The function signature.
|
||||
"""
|
||||
|
||||
|
@ -52,6 +59,8 @@ class RemoteFunction(object):
|
|||
num_return_vals is None else num_return_vals)
|
||||
self._max_calls = (DEFAULT_REMOTE_FUNCTION_MAX_CALLS
|
||||
if max_calls is None else max_calls)
|
||||
self._decorator = getattr(function, "__ray_invocation_decorator__",
|
||||
None)
|
||||
|
||||
ray.signature.check_signature_supported(self._function)
|
||||
self._function_signature = ray.signature.extract_signature(
|
||||
|
@ -108,8 +117,6 @@ class RemoteFunction(object):
|
|||
|
||||
kwargs = {} if kwargs is None else kwargs
|
||||
args = [] if args is None else args
|
||||
args = ray.signature.extend_args(self._function_signature, args,
|
||||
kwargs)
|
||||
|
||||
if num_return_vals is None:
|
||||
num_return_vals = self._num_return_vals
|
||||
|
@ -117,19 +124,29 @@ class RemoteFunction(object):
|
|||
resources = ray.utils.resources_from_resource_arguments(
|
||||
self._num_cpus, self._num_gpus, self._resources, num_cpus,
|
||||
num_gpus, resources)
|
||||
if worker.mode == ray.worker.LOCAL_MODE:
|
||||
# In LOCAL_MODE, remote calls simply execute the function.
|
||||
# We copy the arguments to prevent the function call from
|
||||
# mutating them and to match the usual behavior of
|
||||
# immutable remote objects.
|
||||
result = self._function(*copy.deepcopy(args))
|
||||
return result
|
||||
object_ids = worker.submit_task(
|
||||
self._function_descriptor,
|
||||
args,
|
||||
num_return_vals=num_return_vals,
|
||||
resources=resources)
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
elif len(object_ids) > 1:
|
||||
return object_ids
|
||||
|
||||
def invocation(args, kwargs):
|
||||
args = ray.signature.extend_args(self._function_signature, args,
|
||||
kwargs)
|
||||
|
||||
if worker.mode == ray.worker.LOCAL_MODE:
|
||||
# In LOCAL_MODE, remote calls simply execute the function.
|
||||
# We copy the arguments to prevent the function call from
|
||||
# mutating them and to match the usual behavior of
|
||||
# immutable remote objects.
|
||||
result = self._function(*copy.deepcopy(args))
|
||||
return result
|
||||
object_ids = worker.submit_task(
|
||||
self._function_descriptor,
|
||||
args,
|
||||
num_return_vals=num_return_vals,
|
||||
resources=resources)
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
elif len(object_ids) > 1:
|
||||
return object_ids
|
||||
|
||||
if self._decorator is not None:
|
||||
invocation = self._decorator(invocation)
|
||||
|
||||
return invocation(args, kwargs)
|
||||
|
|
|
@ -2576,3 +2576,35 @@ def test_init_exception_in_checkpointable_actor(ray_start_regular,
|
|||
errors = relevant_errors(ray_constants.TASK_PUSH_ERROR)
|
||||
assert len(errors) == 2
|
||||
assert error_message1 in errors[1]["message"]
|
||||
|
||||
|
||||
def test_decorated_method(ray_start_regular):
|
||||
def method_invocation_decorator(f):
|
||||
def new_f_invocation(args, kwargs):
|
||||
# Split one argument into two. Return th kwargs without passing
|
||||
# them into the actor.
|
||||
return f([args[0], args[0]], {}), kwargs
|
||||
|
||||
return new_f_invocation
|
||||
|
||||
def method_execution_decorator(f):
|
||||
def new_f_execution(self, b, c):
|
||||
# Turn two arguments into one.
|
||||
return f(self, b + c)
|
||||
|
||||
new_f_execution.__ray_invocation_decorator__ = (
|
||||
method_invocation_decorator)
|
||||
return new_f_execution
|
||||
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
@method_execution_decorator
|
||||
def decorated_method(self, x):
|
||||
return x + 1
|
||||
|
||||
a = Actor.remote()
|
||||
|
||||
object_id, extra = a.decorated_method.remote(3, kwarg=3)
|
||||
assert isinstance(object_id, ray.ObjectID)
|
||||
assert extra == {"kwarg": 3}
|
||||
assert ray.get(object_id) == 7 # 2 * 3 + 1
|
||||
|
|
|
@ -2892,3 +2892,32 @@ def test_redis_lru_with_set(ray_start_object_store_memory):
|
|||
|
||||
# Now evict the object from the object store.
|
||||
ray.put(x) # This should not crash.
|
||||
|
||||
|
||||
def test_decorated_function(ray_start_regular):
|
||||
def function_invocation_decorator(f):
|
||||
def new_f(args, kwargs):
|
||||
# Reverse the arguments.
|
||||
return f(args[::-1], {"d": 5}), kwargs
|
||||
|
||||
return new_f
|
||||
|
||||
def f(a, b, c, d=None):
|
||||
return a, b, c, d
|
||||
|
||||
f.__ray_invocation_decorator__ = function_invocation_decorator
|
||||
f = ray.remote(f)
|
||||
|
||||
result_id, kwargs = f.remote(1, 2, 3, d=4)
|
||||
assert kwargs == {"d": 4}
|
||||
assert ray.get(result_id) == (3, 2, 1, 5)
|
||||
|
||||
|
||||
def test_get_postprocess(ray_start_regular):
|
||||
def get_postprocessor(object_ids, values):
|
||||
return [value for value in values if value > 0]
|
||||
|
||||
ray.worker.global_worker._post_get_hooks.append(get_postprocessor)
|
||||
|
||||
assert ray.get(
|
||||
[ray.put(i) for i in [0, 1, 3, 5, -1, -3, 4]]) == [1, 3, 5, 4]
|
||||
|
|
|
@ -156,6 +156,9 @@ class Worker(object):
|
|||
# increment every time when `ray.shutdown` is called.
|
||||
self._session_index = 0
|
||||
self._current_task = None
|
||||
# Functions to run to process the values returned by ray.get. Each
|
||||
# postprocessor must take two arguments ("object_ids", and "values").
|
||||
self._post_get_hooks = []
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
|
@ -1455,7 +1458,7 @@ def init(redis_address=None,
|
|||
return _global_node.address_info
|
||||
|
||||
|
||||
# Functions to run as callback after a successful ray init
|
||||
# Functions to run as callback after a successful ray init.
|
||||
_post_init_hooks = []
|
||||
|
||||
|
||||
|
@ -1493,7 +1496,10 @@ def shutdown(exiting_interpreter=False):
|
|||
_global_node.kill_all_processes(check_alive=False, allow_graceful=True)
|
||||
_global_node = None
|
||||
|
||||
# TODO(rkn): Instead of manually reseting some of the worker fields, we
|
||||
# should simply set "global_worker" to equal "None" or something like that.
|
||||
global_worker.set_mode(None)
|
||||
global_worker._post_get_hooks = []
|
||||
|
||||
|
||||
atexit.register(shutdown, True)
|
||||
|
@ -2175,23 +2181,29 @@ def get(object_ids):
|
|||
# In LOCAL_MODE, ray.get is the identity operation (the input will
|
||||
# actually be a value not an objectid).
|
||||
return object_ids
|
||||
|
||||
is_individual_id = isinstance(object_ids, ray.ObjectID)
|
||||
if is_individual_id:
|
||||
object_ids = [object_ids]
|
||||
|
||||
if not isinstance(object_ids, list):
|
||||
raise ValueError("'object_ids' must either by an object ID "
|
||||
"or a list of object IDs.")
|
||||
|
||||
global last_task_error_raise_time
|
||||
if isinstance(object_ids, list):
|
||||
values = worker.get_object(object_ids)
|
||||
for i, value in enumerate(values):
|
||||
if isinstance(value, RayError):
|
||||
last_task_error_raise_time = time.time()
|
||||
raise value
|
||||
return values
|
||||
else:
|
||||
value = worker.get_object([object_ids])[0]
|
||||
values = worker.get_object(object_ids)
|
||||
for i, value in enumerate(values):
|
||||
if isinstance(value, RayError):
|
||||
# If the result is a RayError, then the task that created
|
||||
# this object failed, and we should propagate the error message
|
||||
# here.
|
||||
last_task_error_raise_time = time.time()
|
||||
raise value
|
||||
return value
|
||||
|
||||
# Run post processors.
|
||||
for post_processor in worker._post_get_hooks:
|
||||
values = post_processor(object_ids, values)
|
||||
|
||||
if is_individual_id:
|
||||
values = values[0]
|
||||
return values
|
||||
|
||||
|
||||
def put(value):
|
||||
|
|
Loading…
Add table
Reference in a new issue