mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[ray_client] Implement optional arguments to ray.remote() and f.options() (#12985)
This commit is contained in:
parent
11f34f72d8
commit
80f6dd16b2
12 changed files with 336 additions and 90 deletions
|
@ -28,6 +28,7 @@ import sys
|
|||
|
||||
from typing import NamedTuple
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from ray.experimental.client import RayAPIStub
|
||||
|
@ -37,6 +38,7 @@ from ray.experimental.client.common import ClientActorRef
|
|||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientRemoteMethod
|
||||
from ray.experimental.client.common import OptionWrapper
|
||||
from ray.experimental.client.common import SelfReferenceSentinel
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
|
@ -52,7 +54,8 @@ else:
|
|||
# the data for an exectuion, with no arguments. Combine the two?
|
||||
PickleStub = NamedTuple("PickleStub",
|
||||
[("type", str), ("client_id", str), ("ref_id", bytes),
|
||||
("name", Optional[str])])
|
||||
("name", Optional[str]),
|
||||
("baseline_options", Optional[Dict])])
|
||||
|
||||
|
||||
class ClientPickler(cloudpickle.CloudPickler):
|
||||
|
@ -67,6 +70,7 @@ class ClientPickler(cloudpickle.CloudPickler):
|
|||
client_id=self.client_id,
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, ClientObjectRef):
|
||||
return PickleStub(
|
||||
|
@ -74,6 +78,7 @@ class ClientPickler(cloudpickle.CloudPickler):
|
|||
client_id=self.client_id,
|
||||
ref_id=obj.id,
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, ClientActorHandle):
|
||||
return PickleStub(
|
||||
|
@ -81,6 +86,7 @@ class ClientPickler(cloudpickle.CloudPickler):
|
|||
client_id=self.client_id,
|
||||
ref_id=obj._actor_id,
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, ClientRemoteFunc):
|
||||
# TODO(barakmich): This is going to have trouble with mutually
|
||||
|
@ -95,12 +101,14 @@ class ClientPickler(cloudpickle.CloudPickler):
|
|||
client_id=self.client_id,
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
return PickleStub(
|
||||
type="RemoteFunc",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._ref.id,
|
||||
name=None,
|
||||
baseline_options=obj._options,
|
||||
)
|
||||
elif isinstance(obj, ClientActorClass):
|
||||
# TODO(barakmich): Mutual recursion, as above.
|
||||
|
@ -112,12 +120,14 @@ class ClientPickler(cloudpickle.CloudPickler):
|
|||
client_id=self.client_id,
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
return PickleStub(
|
||||
type="RemoteActor",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._ref.id,
|
||||
name=None,
|
||||
baseline_options=obj._options,
|
||||
)
|
||||
elif isinstance(obj, ClientRemoteMethod):
|
||||
return PickleStub(
|
||||
|
@ -125,7 +135,11 @@ class ClientPickler(cloudpickle.CloudPickler):
|
|||
client_id=self.client_id,
|
||||
ref_id=obj.actor_handle.actor_ref.id,
|
||||
name=obj.method_name,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, OptionWrapper):
|
||||
raise NotImplementedError(
|
||||
"Sending a partial option is unimplemented")
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,21 @@
|
|||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.experimental.client import ray
|
||||
from ray.experimental.client.options import validate_options
|
||||
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
|
||||
class ClientBaseRef:
|
||||
def __init__(self, id: bytes):
|
||||
self.id = None
|
||||
if not isinstance(id, bytes):
|
||||
raise TypeError("ClientRefs must be created with bytes IDs")
|
||||
self.id: bytes = id
|
||||
ray.call_retain(id)
|
||||
|
||||
|
@ -23,7 +35,7 @@ class ClientBaseRef:
|
|||
return hash(self.id)
|
||||
|
||||
def __del__(self):
|
||||
if ray.is_connected():
|
||||
if ray.is_connected() and self.id is not None:
|
||||
ray.call_release(self.id)
|
||||
|
||||
|
||||
|
@ -52,33 +64,42 @@ class ClientRemoteFunc(ClientStub):
|
|||
_ref: The ClientObjectRef of the pickled code of the function, _func
|
||||
"""
|
||||
|
||||
def __init__(self, f):
|
||||
def __init__(self, f, options=None):
|
||||
self._lock = threading.Lock()
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self._ref = None
|
||||
self._options = validate_options(options)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote function cannot be called directly. "
|
||||
"Use {self._name}.remote method instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ClientObjectRef(ray.call_remote(self, *args, **kwargs))
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def options(self, **kwargs):
|
||||
return OptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def _ensure_ref(self):
|
||||
if self._ref is None:
|
||||
# While calling ray.put() on our function, if
|
||||
# our function is recursive, it will attempt to
|
||||
# encode the ClientRemoteFunc -- itself -- and
|
||||
# infinitely recurse on _ensure_ref.
|
||||
#
|
||||
# So we set the state of the reference to be an
|
||||
# in-progress self reference value, which
|
||||
# the encoding can detect and handle correctly.
|
||||
self._ref = SelfReferenceSentinel()
|
||||
self._ref = ray.put(self._func)
|
||||
with self._lock:
|
||||
if self._ref is None:
|
||||
# While calling ray.put() on our function, if
|
||||
# our function is recursive, it will attempt to
|
||||
# encode the ClientRemoteFunc -- itself -- and
|
||||
# infinitely recurse on _ensure_ref.
|
||||
#
|
||||
# So we set the state of the reference to be an
|
||||
# in-progress self reference value, which
|
||||
# the encoding can detect and handle correctly.
|
||||
self._ref = SelfReferenceSentinel()
|
||||
self._ref = ray.put(self._func)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
self._ensure_ref()
|
||||
|
@ -86,6 +107,7 @@ class ClientRemoteFunc(ClientStub):
|
|||
task.type = ray_client_pb2.ClientTask.FUNCTION
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.id
|
||||
set_task_options(task, self._options, "baseline_options")
|
||||
return task
|
||||
|
||||
|
||||
|
@ -100,10 +122,11 @@ class ClientActorClass(ClientStub):
|
|||
_ref: The ClientObjectRef of the pickled `actor_cls`
|
||||
"""
|
||||
|
||||
def __init__(self, actor_cls):
|
||||
def __init__(self, actor_cls, options=None):
|
||||
self.actor_cls = actor_cls
|
||||
self._name = actor_cls.__name__
|
||||
self._ref = None
|
||||
self._options = validate_options(options)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote actor cannot be instantiated directly. "
|
||||
|
@ -119,8 +142,15 @@ class ClientActorClass(ClientStub):
|
|||
|
||||
def remote(self, *args, **kwargs) -> "ClientActorHandle":
|
||||
# Actually instantiate the actor
|
||||
ref_id = ray.call_remote(self, *args, **kwargs)
|
||||
return ClientActorHandle(ClientActorRef(ref_id), self)
|
||||
ref_ids = ray.call_remote(self, *args, **kwargs)
|
||||
assert len(ref_ids) == 1
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]), self)
|
||||
|
||||
def options(self, **kwargs):
|
||||
return ActorOptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
|
||||
|
@ -136,6 +166,7 @@ class ClientActorClass(ClientStub):
|
|||
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.id
|
||||
set_task_options(task, self._options, "baseline_options")
|
||||
return task
|
||||
|
||||
|
||||
|
@ -160,7 +191,8 @@ class ClientActorHandle(ClientStub):
|
|||
self.actor_ref = actor_ref
|
||||
|
||||
def __del__(self) -> None:
|
||||
ray.call_release(self.actor_ref.id)
|
||||
if ray.is_connected():
|
||||
ray.call_release(self.actor_ref.id)
|
||||
|
||||
@property
|
||||
def _actor_id(self):
|
||||
|
@ -193,12 +225,18 @@ class ClientRemoteMethod(ClientStub):
|
|||
f"Use {self._name}.remote() instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ClientObjectRef(ray.call_remote(self, *args, **kwargs))
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteMethod(%s, %s)" % (self.method_name,
|
||||
self.actor_handle)
|
||||
|
||||
def options(self, **kwargs):
|
||||
return OptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
|
@ -207,6 +245,49 @@ class ClientRemoteMethod(ClientStub):
|
|||
return task
|
||||
|
||||
|
||||
class OptionWrapper:
|
||||
def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
|
||||
self.remote_stub = stub
|
||||
self.options = validate_options(options)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.remote_stub, key)
|
||||
|
||||
def _prepare_client_task(self):
|
||||
task = self.remote_stub._prepare_client_task()
|
||||
set_task_options(task, self.options)
|
||||
return task
|
||||
|
||||
|
||||
class ActorOptionWrapper(OptionWrapper):
|
||||
def remote(self, *args, **kwargs):
|
||||
ref_ids = ray.call_remote(self, *args, **kwargs)
|
||||
assert len(ref_ids) == 1
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]), self)
|
||||
|
||||
|
||||
def set_task_options(task: ray_client_pb2.ClientTask,
|
||||
options: Optional[Dict[str, Any]],
|
||||
field: str = "options") -> None:
|
||||
if options is None:
|
||||
task.ClearField(field)
|
||||
return
|
||||
options_str = json.dumps(options)
|
||||
getattr(task, field).json_options = options_str
|
||||
|
||||
|
||||
def return_refs(ids: List[bytes]
|
||||
) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
|
||||
if len(ids) == 1:
|
||||
return ClientObjectRef(ids[0])
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
return [ClientObjectRef(id) for id in ids]
|
||||
|
||||
|
||||
class DataEncodingSentinel:
|
||||
def __repr__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
|
54
python/ray/experimental/client/options.py
Normal file
54
python/ray/experimental/client/options.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
options = {
|
||||
"num_returns": (int, lambda x: x >= 0,
|
||||
"The keyword 'num_returns' only accepts 0 "
|
||||
"or a positive integer"),
|
||||
"num_cpus": (),
|
||||
"num_gpus": (),
|
||||
"resources": (),
|
||||
"accelerator_type": (),
|
||||
"max_calls": (int, lambda x: x >= 0,
|
||||
"The keyword 'max_calls' only accepts 0 "
|
||||
"or a positive integer"),
|
||||
"max_restarts": (int, lambda x: x >= -1,
|
||||
"The keyword 'max_restarts' only accepts -1, 0 "
|
||||
"or a positive integer"),
|
||||
"max_task_retries": (int, lambda x: x >= -1,
|
||||
"The keyword 'max_task_retries' only accepts -1, 0 "
|
||||
"or a positive integer"),
|
||||
"max_retries": (int, lambda x: x >= -1,
|
||||
"The keyword 'max_retries' only accepts 0, -1 "
|
||||
"or a positive integer"),
|
||||
"max_concurrency": (),
|
||||
"name": (),
|
||||
"lifetime": (),
|
||||
"memory": (),
|
||||
"object_store_memory": (),
|
||||
"placement_group": (),
|
||||
"placement_group_bundle_index": (),
|
||||
"placement_group_capture_child_tasks": (),
|
||||
"override_environment_variables": (),
|
||||
}
|
||||
|
||||
|
||||
def validate_options(
|
||||
kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if kwargs_dict is None:
|
||||
return None
|
||||
if len(kwargs_dict) == 0:
|
||||
return None
|
||||
out = {}
|
||||
for k, v in kwargs_dict.items():
|
||||
if k not in options.keys():
|
||||
raise TypeError(f"Invalid option passed to remote(): {k}")
|
||||
validator = options[k]
|
||||
if len(validator) != 0:
|
||||
if not isinstance(v, validator[0]):
|
||||
raise ValueError(validator[2])
|
||||
if not validator[1](v):
|
||||
raise ValueError(validator[2])
|
||||
out[k] = v
|
||||
return out
|
|
@ -4,8 +4,10 @@ import grpc
|
|||
import base64
|
||||
from collections import defaultdict
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Set
|
||||
from typing import Optional
|
||||
|
||||
from ray import cloudpickle
|
||||
import ray
|
||||
|
@ -187,9 +189,11 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None)
|
||||
except Exception:
|
||||
timeout=timeout if timeout != -1 else None,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO(ameer): improve exception messages.
|
||||
logger.error(f"Exception {e}")
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
logger.debug("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
|
@ -206,9 +210,10 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
remaining_object_ids=remaining_object_ids)
|
||||
|
||||
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
logger.info("schedule: %s %s" %
|
||||
(task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
||||
logger.debug(
|
||||
"schedule: %s %s" % (task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(
|
||||
task.type)))
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
try:
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
|
@ -226,6 +231,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Caught schedule exception {e}")
|
||||
raise e
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
|
||||
|
@ -236,34 +242,44 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
raise Exception(
|
||||
"Can't run an actor the server doesn't have a handle for")
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
output = getattr(actor_handle, task.name).remote(*arglist, **kwargs)
|
||||
self.object_refs[task.client_id][output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
method = getattr(actor_handle, task.name)
|
||||
opts = decode_options(task.options)
|
||||
if opts is not None:
|
||||
method = method.options(**opts)
|
||||
output = method.remote(*arglist, **kwargs)
|
||||
ids = self.unify_and_track_outputs(output, task.client_id)
|
||||
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
|
||||
|
||||
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
remote_class = self.lookup_or_register_actor(task.payload_id,
|
||||
task.client_id)
|
||||
remote_class = self.lookup_or_register_actor(
|
||||
task.payload_id, task.client_id,
|
||||
decode_options(task.baseline_options))
|
||||
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
opts = decode_options(task.options)
|
||||
if opts is not None:
|
||||
remote_class = remote_class.options(**opts)
|
||||
with current_remote(remote_class):
|
||||
actor = remote_class.remote(*arglist, **kwargs)
|
||||
self.actor_refs[actor._actor_id.binary()] = actor
|
||||
self.actor_owners[task.client_id].add(actor._actor_id.binary())
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_id=actor._actor_id.binary())
|
||||
return_ids=[actor._actor_id.binary()])
|
||||
|
||||
def _schedule_function(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
remote_func = self.lookup_or_register_func(task.payload_id,
|
||||
task.client_id)
|
||||
remote_func = self.lookup_or_register_func(
|
||||
task.payload_id, task.client_id,
|
||||
decode_options(task.baseline_options))
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
opts = decode_options(task.options)
|
||||
if opts is not None:
|
||||
remote_func = remote_func.options(**opts)
|
||||
with current_remote(remote_func):
|
||||
output = remote_func.remote(*arglist, **kwargs)
|
||||
if output.binary() in self.object_refs[task.client_id]:
|
||||
raise Exception("already found it")
|
||||
self.object_refs[task.client_id][output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
ids = self.unify_and_track_outputs(output, task.client_id)
|
||||
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
|
||||
|
||||
def _convert_args(self, arg_list, kwarg_map):
|
||||
argout = []
|
||||
|
@ -275,28 +291,50 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
kwargout[k] = convert_from_arg(kwarg_map[k], self)
|
||||
return argout, kwargout
|
||||
|
||||
def lookup_or_register_func(self, id: bytes, client_id: str
|
||||
) -> ray.remote_function.RemoteFunction:
|
||||
def lookup_or_register_func(
|
||||
self, id: bytes, client_id: str,
|
||||
options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
|
||||
if id not in self.function_refs:
|
||||
funcref = self.object_refs[client_id][id]
|
||||
func = ray.get(funcref)
|
||||
if not inspect.isfunction(func):
|
||||
raise Exception("Attempting to register function that "
|
||||
"isn't a function.")
|
||||
self.function_refs[id] = ray.remote(func)
|
||||
if options is None or len(options) == 0:
|
||||
self.function_refs[id] = ray.remote(func)
|
||||
else:
|
||||
self.function_refs[id] = ray.remote(**options)(func)
|
||||
return self.function_refs[id]
|
||||
|
||||
def lookup_or_register_actor(self, id: bytes, client_id: str):
|
||||
def lookup_or_register_actor(self, id: bytes, client_id: str,
|
||||
options: Optional[Dict]):
|
||||
if id not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[client_id][id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a class.")
|
||||
reg_class = ray.remote(actor_class)
|
||||
if options is None or len(options) == 0:
|
||||
reg_class = ray.remote(actor_class)
|
||||
else:
|
||||
reg_class = ray.remote(**options)(actor_class)
|
||||
self.registered_actor_classes[id] = reg_class
|
||||
|
||||
return self.registered_actor_classes[id]
|
||||
|
||||
def unify_and_track_outputs(self, output, client_id):
|
||||
if output is None:
|
||||
outputs = []
|
||||
elif isinstance(output, list):
|
||||
outputs = output
|
||||
else:
|
||||
outputs = [output]
|
||||
for out in outputs:
|
||||
if out.binary() in self.object_refs[client_id]:
|
||||
logger.warning(f"Already saw object_ref {out}")
|
||||
self.object_refs[client_id][out.binary()] = out
|
||||
return [out.binary() for out in outputs]
|
||||
|
||||
|
||||
def return_exception_in_context(err, context):
|
||||
if context is not None:
|
||||
|
@ -309,6 +347,15 @@ def encode_exception(exception) -> str:
|
|||
return base64.standard_b64encode(data).decode()
|
||||
|
||||
|
||||
def decode_options(
|
||||
options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]:
|
||||
if options.json_options == "":
|
||||
return None
|
||||
opts = json.loads(options.json_options)
|
||||
assert isinstance(opts, dict)
|
||||
return opts
|
||||
|
||||
|
||||
def serve(connection_str, test_mode=False):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer(test_mode=test_mode)
|
||||
|
|
|
@ -56,6 +56,7 @@ class ServerPickler(cloudpickle.CloudPickler):
|
|||
client_id=self.client_id,
|
||||
ref_id=obj_id,
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, ray.actor.ActorHandle):
|
||||
actor_id = obj._actor_id.binary()
|
||||
|
@ -69,6 +70,7 @@ class ServerPickler(cloudpickle.CloudPickler):
|
|||
client_id=self.client_id,
|
||||
ref_id=obj._actor_id.binary(),
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
return None
|
||||
|
||||
|
@ -89,13 +91,13 @@ class ClientUnpickler(pickle.Unpickler):
|
|||
elif pid.type == "RemoteFuncSelfReference":
|
||||
return ServerSelfReferenceSentinel()
|
||||
elif pid.type == "RemoteFunc":
|
||||
return self.server.lookup_or_register_func(pid.ref_id,
|
||||
pid.client_id)
|
||||
return self.server.lookup_or_register_func(
|
||||
pid.ref_id, pid.client_id, pid.baseline_options)
|
||||
elif pid.type == "RemoteActorSelfReference":
|
||||
return ServerSelfReferenceSentinel()
|
||||
elif pid.type == "RemoteActor":
|
||||
return self.server.lookup_or_register_actor(
|
||||
pid.ref_id, pid.client_id)
|
||||
pid.ref_id, pid.client_id, pid.baseline_options)
|
||||
elif pid.type == "RemoteMethod":
|
||||
actor = self.server.actor_refs[pid.ref_id]
|
||||
return getattr(actor, pid.name)
|
||||
|
|
|
@ -21,12 +21,13 @@ import ray.cloudpickle as cloudpickle
|
|||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.client_pickler import convert_to_arg
|
||||
from ray.experimental.client.client_pickler import loads_from_server
|
||||
from ray.experimental.client.client_pickler import dumps_from_client
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.client_pickler import loads_from_server
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientStub
|
||||
from ray.experimental.client.dataclient import DataClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -80,7 +81,9 @@ class Worker:
|
|||
except grpc.RpcError as e:
|
||||
raise e.details()
|
||||
if not data.valid:
|
||||
raise cloudpickle.loads(data.error)
|
||||
err = cloudpickle.loads(data.error)
|
||||
logger.error(err)
|
||||
raise err
|
||||
return loads_from_server(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
|
@ -98,6 +101,13 @@ class Worker:
|
|||
return out
|
||||
|
||||
def _put(self, val):
|
||||
if isinstance(val, ClientObjectRef):
|
||||
raise TypeError(
|
||||
"Calling 'put' on an ObjectRef is not allowed "
|
||||
"(similarly, returning an ObjectRef from a remote "
|
||||
"function is not allowed). If you really want to "
|
||||
"do this, you can wrap the ObjectRef in a list and "
|
||||
"call 'put' on it (or return it).")
|
||||
data = dumps_from_client(val, self._client_id)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.data_client.PutObject(req)
|
||||
|
@ -107,7 +117,8 @@ class Worker:
|
|||
object_refs: List[ClientObjectRef],
|
||||
*,
|
||||
num_returns: int = 1,
|
||||
timeout: float = None
|
||||
timeout: float = None,
|
||||
fetch_local: bool = True
|
||||
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
|
||||
if not isinstance(object_refs, list):
|
||||
raise TypeError("wait() expected a list of ClientObjectRef, "
|
||||
|
@ -136,19 +147,22 @@ class Worker:
|
|||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
|
||||
def remote(self, function_or_class, *args, **kwargs):
|
||||
# TODO(barakmich): Arguments to ray.remote
|
||||
# get captured here.
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
return ClientRemoteFunc(function_or_class)
|
||||
elif inspect.isclass(function_or_class):
|
||||
return ClientActorClass(function_or_class)
|
||||
else:
|
||||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
def remote(self, *args, **kwargs):
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
return remote_decorator(options=None)(args[0])
|
||||
error_string = ("The @ray.remote decorator must be applied either "
|
||||
"with no arguments and no parentheses, for example "
|
||||
"'@ray.remote', or it must be applied using some of "
|
||||
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
|
||||
"'memory', 'object_store_memory', 'resources', "
|
||||
"'max_calls', or 'max_restarts', like "
|
||||
"'@ray.remote(num_returns=2, "
|
||||
"resources={\"CustomResource\": 1})'.")
|
||||
assert len(args) == 0 and len(kwargs) > 0, error_string
|
||||
return remote_decorator(options=kwargs)
|
||||
|
||||
def call_remote(self, instance, *args, **kwargs) -> bytes:
|
||||
def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
|
||||
task = instance._prepare_client_task()
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg, self._client_id)
|
||||
|
@ -160,10 +174,10 @@ class Worker:
|
|||
try:
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
except grpc.RpcError as e:
|
||||
raise e.details()
|
||||
raise decode_exception(e.details)
|
||||
if not ticket.valid:
|
||||
raise cloudpickle.loads(ticket.error)
|
||||
return ticket.return_id
|
||||
return ticket.return_ids
|
||||
|
||||
def call_release(self, id: bytes) -> None:
|
||||
self.reference_count[id] -= 1
|
||||
|
@ -234,6 +248,20 @@ class Worker:
|
|||
return False
|
||||
|
||||
|
||||
def remote_decorator(options: Optional[Dict[str, Any]]):
|
||||
def decorator(function_or_class) -> ClientStub:
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
return ClientRemoteFunc(function_or_class, options=options)
|
||||
elif inspect.isclass(function_or_class):
|
||||
return ClientActorClass(function_or_class, options=options)
|
||||
else:
|
||||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def make_client_id() -> str:
|
||||
id = uuid.uuid4()
|
||||
return id.hex
|
||||
|
|
|
@ -158,6 +158,7 @@ py_test(
|
|||
py_test_module_list(
|
||||
files = [
|
||||
"test_actor.py",
|
||||
"test_advanced.py",
|
||||
"test_basic.py",
|
||||
"test_basic_2.py",
|
||||
],
|
||||
|
|
|
@ -25,7 +25,9 @@ else:
|
|||
import setproctitle # noqa
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="test setup order")
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(),
|
||||
reason="defining early, no ray package injection yet")
|
||||
def test_caching_actors(shutdown_only):
|
||||
# Test defining actors before ray.init() has been called.
|
||||
|
||||
|
@ -564,7 +566,6 @@ def test_actor_static_attributes(ray_start_regular_shared):
|
|||
assert ray.get(t.g.remote()) == 3
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_decorator_args(ray_start_regular_shared):
|
||||
# This is an invalid way of using the actor decorator.
|
||||
with pytest.raises(Exception):
|
||||
|
@ -655,7 +656,7 @@ def test_actor_inheritance(ray_start_regular_shared):
|
|||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="ray.method unimplemented")
|
||||
def test_multiple_return_values(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
|
@ -689,7 +690,6 @@ def test_multiple_return_values(ray_start_regular_shared):
|
|||
assert ray.get([id3a, id3b, id3c]) == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_options_num_returns(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
|
|
|
@ -10,16 +10,22 @@ import time
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
import ray.cluster_utils
|
||||
import ray.test_utils
|
||||
|
||||
from ray.test_utils import client_test_enabled
|
||||
from ray.test_utils import RayTestTimeoutException
|
||||
|
||||
if client_test_enabled():
|
||||
from ray.experimental.client import ray
|
||||
else:
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# issue https://github.com/ray-project/ray/issues/7105
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_internal_free(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
|
@ -60,14 +66,14 @@ def test_multiple_waits_and_gets(shutdown_only):
|
|||
return 1
|
||||
|
||||
@ray.remote
|
||||
def g(l):
|
||||
# The argument l should be a list containing one object ref.
|
||||
ray.wait([l[0]])
|
||||
def g(input_list):
|
||||
# The argument input_list should be a list containing one object ref.
|
||||
ray.wait([input_list[0]])
|
||||
|
||||
@ray.remote
|
||||
def h(l):
|
||||
# The argument l should be a list containing one object ref.
|
||||
ray.get(l[0])
|
||||
def h(input_list):
|
||||
# The argument input_list should be a list containing one object ref.
|
||||
ray.get(input_list[0])
|
||||
|
||||
# Make sure that multiple wait requests involving the same object ref
|
||||
# all return.
|
||||
|
@ -80,6 +86,7 @@ def test_multiple_waits_and_gets(shutdown_only):
|
|||
ray.get([h.remote([x]), h.remote([x])])
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_caching_functions_to_run(shutdown_only):
|
||||
# Test that we export functions to run on all workers before the driver
|
||||
# is connected.
|
||||
|
@ -125,6 +132,7 @@ def test_caching_functions_to_run(shutdown_only):
|
|||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_running_function_on_all_workers(ray_start_regular):
|
||||
def f(worker_info):
|
||||
sys.path.append("fake_directory")
|
||||
|
@ -152,6 +160,7 @@ def test_running_function_on_all_workers(ray_start_regular):
|
|||
assert "fake_directory" not in ray.get(get_path2.remote())
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="ray.timeline")
|
||||
def test_profiling_api(ray_start_2_cpus):
|
||||
@ray.remote
|
||||
def f():
|
||||
|
@ -482,6 +491,7 @@ def test_multithreading(ray_start_2_cpus):
|
|||
ray.get(actor.join.remote()) == "ok"
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_wait_makes_object_local(ray_start_cluster):
|
||||
cluster = ray_start_cluster
|
||||
cluster.add_node(num_cpus=0)
|
||||
|
|
|
@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/6662
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc")
|
||||
def test_ignore_http_proxy(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
os.environ["http_proxy"] = "http://example.com"
|
||||
|
@ -55,14 +55,12 @@ def test_grpc_message_size(shutdown_only):
|
|||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/7287
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_omp_threads_set(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
# Should have been auto set by ray init.
|
||||
assert os.environ["OMP_NUM_THREADS"] == "1"
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_submit_api(shutdown_only):
|
||||
ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1})
|
||||
|
||||
|
@ -121,7 +119,6 @@ def test_submit_api(shutdown_only):
|
|||
assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_invalid_arguments(shutdown_only):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
|
@ -176,7 +173,6 @@ def test_invalid_arguments(shutdown_only):
|
|||
x = 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_many_fractional_resources(shutdown_only):
|
||||
ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2})
|
||||
|
||||
|
@ -244,7 +240,6 @@ def test_many_fractional_resources(shutdown_only):
|
|||
assert False, "Did not get correct available resources."
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_background_tasks_with_max_calls(shutdown_only):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
|
@ -360,8 +355,9 @@ def test_function_descriptor():
|
|||
assert d.get(python_descriptor2) == 123
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_ray_options(shutdown_only):
|
||||
ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
|
||||
|
||||
@ray.remote(
|
||||
num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1})
|
||||
def foo():
|
||||
|
@ -370,8 +366,6 @@ def test_ray_options(shutdown_only):
|
|||
time.sleep(0.1)
|
||||
return ray.available_resources()
|
||||
|
||||
ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
|
||||
|
||||
without_options = ray.get(foo.remote())
|
||||
with_options = ray.get(
|
||||
foo.options(
|
||||
|
|
|
@ -537,7 +537,6 @@ def test_actor_recursive(ray_start_regular_shared):
|
|||
assert result == [x * 2 for x in range(100)]
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_actor_concurrent(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Batcher:
|
||||
|
|
|
@ -35,6 +35,18 @@ message Arg {
|
|||
Type type = 4;
|
||||
}
|
||||
|
||||
// A message representing the valid options to modify a task exectution
|
||||
//
|
||||
// TODO(barakmich): In the longer term, if everything were a client,
|
||||
// this message could be the actual standard for which options are
|
||||
// allowed in the API. Today, however, it's a bit flexible and defined in the
|
||||
// Python code. So for now, it's a stand-in message with a json field, but
|
||||
// this is forwards-compatible with deprecating that field and instituting
|
||||
// strongly defined and typed fields, without migrating the original ClientTask.
|
||||
message TaskOptions {
|
||||
string json_options = 1;
|
||||
}
|
||||
|
||||
// Represents one unit of work to be executed by the server.
|
||||
message ClientTask {
|
||||
enum RemoteExecType {
|
||||
|
@ -45,8 +57,8 @@ message ClientTask {
|
|||
}
|
||||
// Which type of work this request represents.
|
||||
RemoteExecType type = 1;
|
||||
// A name parameter, if the payload can be called in more than one way (like a method on
|
||||
// a payload object).
|
||||
// A name parameter, if the payload can be called in more than one way
|
||||
// (like a method on a payload object).
|
||||
string name = 2;
|
||||
// A reference to the payload.
|
||||
bytes payload_id = 3;
|
||||
|
@ -54,16 +66,20 @@ message ClientTask {
|
|||
repeated Arg args = 4;
|
||||
// Keyword parameters to pass to this call.
|
||||
map<string, Arg> kwargs = 5;
|
||||
// The ID of the client namespace associated with the Datapath stream making this
|
||||
// request.
|
||||
// The ID of the client namespace associated with the Datapath stream
|
||||
// making this request.
|
||||
string client_id = 6;
|
||||
// Options for modifying the remote task execution environment.
|
||||
TaskOptions options = 7;
|
||||
// Options passed to create the default remote task excution environment.
|
||||
TaskOptions baseline_options = 8;
|
||||
}
|
||||
|
||||
message ClientTaskTicket {
|
||||
// Was the task successful?
|
||||
bool valid = 1;
|
||||
// A reference to the returned value from the execution.
|
||||
bytes return_id = 2;
|
||||
// A reference to the returned values from the execution.
|
||||
repeated bytes return_ids = 2;
|
||||
// If unsuccessful, an encoding of the error.
|
||||
bytes error = 3;
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue