Enable actor methods to be decorated on the caller side also and get postprocessors. (#4732)

* Allow decorating ray actor methods.

* Add test.

* Add get postprocessors.

* Improve documentation.

* Make it work for remote functions.

* Temporary fix.
This commit is contained in:
Robert Nishihara 2019-05-04 11:53:47 -07:00 committed by Philipp Moritz
parent 897b35ce36
commit d81e71e297
5 changed files with 186 additions and 43 deletions

View file

@ -109,10 +109,36 @@ def method(*args, **kwargs):
# Create objects to wrap method invocations. This is done so that we can
# invoke methods with actor.method.remote() instead of actor.method().
class ActorMethod(object):
def __init__(self, actor, method_name, num_return_vals):
"""A class used to invoke an actor method.
Note: This class is instantiated only while the actor method is being
invoked (so that it doesn't keep a reference to the actor handle and
prevent it from going out of scope).
Attributes:
_actor: A handle to the actor.
_method_name: The name of the actor method.
_num_return_vals: The default number of return values that the method
invocation should return.
_decorator: An optional decorator that should be applied to the actor
method invocation (as opposed to the actor method execution) before
invoking the method. The decorator must return a function that
takes in two arguments ("args" and "kwargs"). In most cases, it
should call the function that was passed into the decorator and
return the resulting ObjectIDs. For an example, see
"test_decorated_method" in "python/ray/tests/test_actor.py".
"""
def __init__(self, actor, method_name, num_return_vals, decorator=None):
self._actor = actor
self._method_name = method_name
self._num_return_vals = num_return_vals
# This is a decorator that is used to wrap the function invocation (as
# opposed to the function execution). The decorator must return a
# function that takes in two arguments ("args" and "kwargs"). In most
# cases, it should call the function that was passed into the decorator
# and return the resulting ObjectIDs.
self._decorator = decorator
def __call__(self, *args, **kwargs):
raise Exception("Actor methods cannot be called directly. Instead "
@ -131,11 +157,18 @@ class ActorMethod(object):
if num_return_vals is None:
num_return_vals = self._num_return_vals
return self._actor._actor_method_call(
self._method_name,
args=args,
kwargs=kwargs,
num_return_vals=num_return_vals)
def invocation(args, kwargs):
return self._actor._actor_method_call(
self._method_name,
args=args,
kwargs=kwargs,
num_return_vals=num_return_vals)
# Apply the decorator if there is one.
if self._decorator is not None:
invocation = self._decorator(invocation)
return invocation(args, kwargs)
class ActorClass(object):
@ -157,6 +190,10 @@ class ActorClass(object):
_exported: True if the actor class has been exported and false
otherwise.
_actor_methods: The actor methods.
_method_decorators: Optional decorators that should be applied to the
method invocation function before invoking the actor methods. These
can be set by attaching the attribute
"__ray_invocation_decorator__" to the actor method.
_method_signatures: The signatures of the methods.
_actor_method_names: The names of the actor methods.
_actor_method_num_return_vals: The default number of return values for
@ -196,6 +233,7 @@ class ActorClass(object):
# Extract the signatures of each of the methods. This will be used
# to catch some errors if the methods are called with inappropriate
# arguments.
self._method_decorators = {}
self._method_signatures = {}
self._actor_method_num_return_vals = {}
for method_name, method in self._actor_methods:
@ -214,6 +252,10 @@ class ActorClass(object):
self._actor_method_num_return_vals[method_name] = (
ray_constants.DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS)
if hasattr(method, "__ray_invocation_decorator__"):
self._method_decorators[method_name] = (
method.__ray_invocation_decorator__)
def __call__(self, *args, **kwargs):
raise Exception("Actors methods cannot be instantiated directly. "
"Instead of running '{}()', try '{}.remote()'.".format(
@ -337,9 +379,9 @@ class ActorClass(object):
actor_handle = ActorHandle(
actor_id, self._modified_class.__module__, self._class_name,
actor_cursor, self._actor_method_names, self._method_signatures,
self._actor_method_num_return_vals, actor_cursor, actor_method_cpu,
worker.task_driver_id)
actor_cursor, self._actor_method_names, self._method_decorators,
self._method_signatures, self._actor_method_num_return_vals,
actor_cursor, actor_method_cpu, worker.task_driver_id)
# We increment the actor counter by 1 to account for the actor creation
# task.
actor_handle._ray_actor_counter += 1
@ -381,6 +423,10 @@ class ActorHandle(object):
_ray_actor_counter: The number of actor method invocations that we've
called so far.
_ray_actor_method_names: The names of the actor methods.
_ray_method_decorators: Optional decorators for the function
invocation. This can be used to change the behavior on the
invocation side, whereas a regular decorator can be used to change
the behavior on the execution side.
_ray_method_signatures: The signatures of the actor methods.
_ray_method_num_return_vals: The default number of return values for
each method.
@ -407,6 +453,7 @@ class ActorHandle(object):
class_name,
actor_cursor,
actor_method_names,
method_decorators,
method_signatures,
method_num_return_vals,
actor_creation_dummy_object_id,
@ -428,6 +475,7 @@ class ActorHandle(object):
self._ray_actor_cursor = actor_cursor
self._ray_actor_counter = 0
self._ray_actor_method_names = actor_method_names
self._ray_method_decorators = method_decorators
self._ray_method_signatures = method_signatures
self._ray_method_num_return_vals = method_num_return_vals
self._ray_class_name = class_name
@ -530,8 +578,11 @@ class ActorHandle(object):
# this was causing cyclic references which were prevent
# object deallocation from behaving in a predictable
# manner.
return ActorMethod(self, attr,
self._ray_method_num_return_vals[attr])
return ActorMethod(
self,
attr,
self._ray_method_num_return_vals[attr],
decorator=self._ray_method_decorators.get(attr))
except AttributeError:
pass
@ -600,6 +651,7 @@ class ActorHandle(object):
"class_name": self._ray_class_name,
"actor_cursor": self._ray_actor_cursor,
"actor_method_names": self._ray_actor_method_names,
"method_decorators": self._ray_method_decorators,
"method_signatures": self._ray_method_signatures,
"method_num_return_vals": self._ray_method_num_return_vals,
# Actors in local mode don't have dummy objects.
@ -662,6 +714,7 @@ class ActorHandle(object):
state["class_name"],
state["actor_cursor"],
state["actor_method_names"],
state["method_decorators"],
state["method_signatures"],
state["method_num_return_vals"],
state["actor_creation_dummy_object_id"],

View file

@ -35,6 +35,13 @@ class RemoteFunction(object):
of this remote function.
_max_calls: The number of times a worker can execute this function
before executing.
_decorator: An optional decorator that should be applied to the remote
function invocation (as opposed to the function execution) before
invoking the function. The decorator must return a function that
takes in two arguments ("args" and "kwargs"). In most cases, it
should call the function that was passed into the decorator and
return the resulting ObjectIDs. For an example, see
"test_decorated_function" in "python/ray/tests/test_basic.py".
_function_signature: The function signature.
"""
@ -52,6 +59,8 @@ class RemoteFunction(object):
num_return_vals is None else num_return_vals)
self._max_calls = (DEFAULT_REMOTE_FUNCTION_MAX_CALLS
if max_calls is None else max_calls)
self._decorator = getattr(function, "__ray_invocation_decorator__",
None)
ray.signature.check_signature_supported(self._function)
self._function_signature = ray.signature.extract_signature(
@ -108,8 +117,6 @@ class RemoteFunction(object):
kwargs = {} if kwargs is None else kwargs
args = [] if args is None else args
args = ray.signature.extend_args(self._function_signature, args,
kwargs)
if num_return_vals is None:
num_return_vals = self._num_return_vals
@ -117,19 +124,29 @@ class RemoteFunction(object):
resources = ray.utils.resources_from_resource_arguments(
self._num_cpus, self._num_gpus, self._resources, num_cpus,
num_gpus, resources)
if worker.mode == ray.worker.LOCAL_MODE:
# In LOCAL_MODE, remote calls simply execute the function.
# We copy the arguments to prevent the function call from
# mutating them and to match the usual behavior of
# immutable remote objects.
result = self._function(*copy.deepcopy(args))
return result
object_ids = worker.submit_task(
self._function_descriptor,
args,
num_return_vals=num_return_vals,
resources=resources)
if len(object_ids) == 1:
return object_ids[0]
elif len(object_ids) > 1:
return object_ids
def invocation(args, kwargs):
args = ray.signature.extend_args(self._function_signature, args,
kwargs)
if worker.mode == ray.worker.LOCAL_MODE:
# In LOCAL_MODE, remote calls simply execute the function.
# We copy the arguments to prevent the function call from
# mutating them and to match the usual behavior of
# immutable remote objects.
result = self._function(*copy.deepcopy(args))
return result
object_ids = worker.submit_task(
self._function_descriptor,
args,
num_return_vals=num_return_vals,
resources=resources)
if len(object_ids) == 1:
return object_ids[0]
elif len(object_ids) > 1:
return object_ids
if self._decorator is not None:
invocation = self._decorator(invocation)
return invocation(args, kwargs)

View file

@ -2576,3 +2576,35 @@ def test_init_exception_in_checkpointable_actor(ray_start_regular,
errors = relevant_errors(ray_constants.TASK_PUSH_ERROR)
assert len(errors) == 2
assert error_message1 in errors[1]["message"]
def test_decorated_method(ray_start_regular):
def method_invocation_decorator(f):
def new_f_invocation(args, kwargs):
# Split one argument into two. Return th kwargs without passing
# them into the actor.
return f([args[0], args[0]], {}), kwargs
return new_f_invocation
def method_execution_decorator(f):
def new_f_execution(self, b, c):
# Turn two arguments into one.
return f(self, b + c)
new_f_execution.__ray_invocation_decorator__ = (
method_invocation_decorator)
return new_f_execution
@ray.remote
class Actor(object):
@method_execution_decorator
def decorated_method(self, x):
return x + 1
a = Actor.remote()
object_id, extra = a.decorated_method.remote(3, kwarg=3)
assert isinstance(object_id, ray.ObjectID)
assert extra == {"kwarg": 3}
assert ray.get(object_id) == 7 # 2 * 3 + 1

View file

@ -2892,3 +2892,32 @@ def test_redis_lru_with_set(ray_start_object_store_memory):
# Now evict the object from the object store.
ray.put(x) # This should not crash.
def test_decorated_function(ray_start_regular):
def function_invocation_decorator(f):
def new_f(args, kwargs):
# Reverse the arguments.
return f(args[::-1], {"d": 5}), kwargs
return new_f
def f(a, b, c, d=None):
return a, b, c, d
f.__ray_invocation_decorator__ = function_invocation_decorator
f = ray.remote(f)
result_id, kwargs = f.remote(1, 2, 3, d=4)
assert kwargs == {"d": 4}
assert ray.get(result_id) == (3, 2, 1, 5)
def test_get_postprocess(ray_start_regular):
def get_postprocessor(object_ids, values):
return [value for value in values if value > 0]
ray.worker.global_worker._post_get_hooks.append(get_postprocessor)
assert ray.get(
[ray.put(i) for i in [0, 1, 3, 5, -1, -3, 4]]) == [1, 3, 5, 4]

View file

@ -156,6 +156,9 @@ class Worker(object):
# increment every time when `ray.shutdown` is called.
self._session_index = 0
self._current_task = None
# Functions to run to process the values returned by ray.get. Each
# postprocessor must take two arguments ("object_ids", and "values").
self._post_get_hooks = []
@property
def connected(self):
@ -1455,7 +1458,7 @@ def init(redis_address=None,
return _global_node.address_info
# Functions to run as callback after a successful ray init
# Functions to run as callback after a successful ray init.
_post_init_hooks = []
@ -1493,7 +1496,10 @@ def shutdown(exiting_interpreter=False):
_global_node.kill_all_processes(check_alive=False, allow_graceful=True)
_global_node = None
# TODO(rkn): Instead of manually reseting some of the worker fields, we
# should simply set "global_worker" to equal "None" or something like that.
global_worker.set_mode(None)
global_worker._post_get_hooks = []
atexit.register(shutdown, True)
@ -2175,23 +2181,29 @@ def get(object_ids):
# In LOCAL_MODE, ray.get is the identity operation (the input will
# actually be a value not an objectid).
return object_ids
is_individual_id = isinstance(object_ids, ray.ObjectID)
if is_individual_id:
object_ids = [object_ids]
if not isinstance(object_ids, list):
raise ValueError("'object_ids' must either by an object ID "
"or a list of object IDs.")
global last_task_error_raise_time
if isinstance(object_ids, list):
values = worker.get_object(object_ids)
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
raise value
return values
else:
value = worker.get_object([object_ids])[0]
values = worker.get_object(object_ids)
for i, value in enumerate(values):
if isinstance(value, RayError):
# If the result is a RayError, then the task that created
# this object failed, and we should propagate the error message
# here.
last_task_error_raise_time = time.time()
raise value
return value
# Run post processors.
for post_processor in worker._post_get_hooks:
values = post_processor(object_ids, values)
if is_individual_id:
values = values[0]
return values
def put(value):