[ray_client] Implement optional arguments to ray.remote() and f.options() (#12985)

This commit is contained in:
Barak Michener 2020-12-20 15:43:48 -08:00 committed by GitHub
parent 11f34f72d8
commit 80f6dd16b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 336 additions and 90 deletions

View file

@ -28,6 +28,7 @@ import sys
from typing import NamedTuple
from typing import Any
from typing import Dict
from typing import Optional
from ray.experimental.client import RayAPIStub
@ -37,6 +38,7 @@ from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientRemoteMethod
from ray.experimental.client.common import OptionWrapper
from ray.experimental.client.common import SelfReferenceSentinel
import ray.core.generated.ray_client_pb2 as ray_client_pb2
@ -52,7 +54,8 @@ else:
# the data for an execution, with no arguments. Combine the two?
PickleStub = NamedTuple("PickleStub",
[("type", str), ("client_id", str), ("ref_id", bytes),
("name", Optional[str])])
("name", Optional[str]),
("baseline_options", Optional[Dict])])
class ClientPickler(cloudpickle.CloudPickler):
@ -67,6 +70,7 @@ class ClientPickler(cloudpickle.CloudPickler):
client_id=self.client_id,
ref_id=b"",
name=None,
baseline_options=None,
)
elif isinstance(obj, ClientObjectRef):
return PickleStub(
@ -74,6 +78,7 @@ class ClientPickler(cloudpickle.CloudPickler):
client_id=self.client_id,
ref_id=obj.id,
name=None,
baseline_options=None,
)
elif isinstance(obj, ClientActorHandle):
return PickleStub(
@ -81,6 +86,7 @@ class ClientPickler(cloudpickle.CloudPickler):
client_id=self.client_id,
ref_id=obj._actor_id,
name=None,
baseline_options=None,
)
elif isinstance(obj, ClientRemoteFunc):
# TODO(barakmich): This is going to have trouble with mutually
@ -95,12 +101,14 @@ class ClientPickler(cloudpickle.CloudPickler):
client_id=self.client_id,
ref_id=b"",
name=None,
baseline_options=None,
)
return PickleStub(
type="RemoteFunc",
client_id=self.client_id,
ref_id=obj._ref.id,
name=None,
baseline_options=obj._options,
)
elif isinstance(obj, ClientActorClass):
# TODO(barakmich): Mutual recursion, as above.
@ -112,12 +120,14 @@ class ClientPickler(cloudpickle.CloudPickler):
client_id=self.client_id,
ref_id=b"",
name=None,
baseline_options=None,
)
return PickleStub(
type="RemoteActor",
client_id=self.client_id,
ref_id=obj._ref.id,
name=None,
baseline_options=obj._options,
)
elif isinstance(obj, ClientRemoteMethod):
return PickleStub(
@ -125,7 +135,11 @@ class ClientPickler(cloudpickle.CloudPickler):
client_id=self.client_id,
ref_id=obj.actor_handle.actor_ref.id,
name=obj.method_name,
baseline_options=None,
)
elif isinstance(obj, OptionWrapper):
raise NotImplementedError(
"Sending a partial option is unimplemented")
return None

View file

@ -1,9 +1,21 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import ray
from ray.experimental.client.options import validate_options
import json
import threading
from typing import Any
from typing import List
from typing import Dict
from typing import Optional
from typing import Union
class ClientBaseRef:
def __init__(self, id: bytes):
self.id = None
if not isinstance(id, bytes):
raise TypeError("ClientRefs must be created with bytes IDs")
self.id: bytes = id
ray.call_retain(id)
@ -23,7 +35,7 @@ class ClientBaseRef:
return hash(self.id)
def __del__(self):
if ray.is_connected():
if ray.is_connected() and self.id is not None:
ray.call_release(self.id)
@ -52,33 +64,42 @@ class ClientRemoteFunc(ClientStub):
_ref: The ClientObjectRef of the pickled code of the function, _func
"""
def __init__(self, f):
def __init__(self, f, options=None):
self._lock = threading.Lock()
self._func = f
self._name = f.__name__
self._ref = None
self._options = validate_options(options)
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote function cannot be called directly. "
"Use {self._name}.remote method instead")
def remote(self, *args, **kwargs):
return ClientObjectRef(ray.call_remote(self, *args, **kwargs))
return return_refs(ray.call_remote(self, *args, **kwargs))
def options(self, **kwargs):
return OptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
def _ensure_ref(self):
if self._ref is None:
# While calling ray.put() on our function, if
# our function is recursive, it will attempt to
# encode the ClientRemoteFunc -- itself -- and
# infinitely recurse on _ensure_ref.
#
# So we set the state of the reference to be an
# in-progress self reference value, which
# the encoding can detect and handle correctly.
self._ref = SelfReferenceSentinel()
self._ref = ray.put(self._func)
with self._lock:
if self._ref is None:
# While calling ray.put() on our function, if
# our function is recursive, it will attempt to
# encode the ClientRemoteFunc -- itself -- and
# infinitely recurse on _ensure_ref.
#
# So we set the state of the reference to be an
# in-progress self reference value, which
# the encoding can detect and handle correctly.
self._ref = SelfReferenceSentinel()
self._ref = ray.put(self._func)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
self._ensure_ref()
@ -86,6 +107,7 @@ class ClientRemoteFunc(ClientStub):
task.type = ray_client_pb2.ClientTask.FUNCTION
task.name = self._name
task.payload_id = self._ref.id
set_task_options(task, self._options, "baseline_options")
return task
@ -100,10 +122,11 @@ class ClientActorClass(ClientStub):
_ref: The ClientObjectRef of the pickled `actor_cls`
"""
def __init__(self, actor_cls):
def __init__(self, actor_cls, options=None):
self.actor_cls = actor_cls
self._name = actor_cls.__name__
self._ref = None
self._options = validate_options(options)
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote actor cannot be instantiated directly. "
@ -119,8 +142,15 @@ class ClientActorClass(ClientStub):
def remote(self, *args, **kwargs) -> "ClientActorHandle":
# Actually instantiate the actor
ref_id = ray.call_remote(self, *args, **kwargs)
return ClientActorHandle(ClientActorRef(ref_id), self)
ref_ids = ray.call_remote(self, *args, **kwargs)
assert len(ref_ids) == 1
return ClientActorHandle(ClientActorRef(ref_ids[0]), self)
def options(self, **kwargs):
return ActorOptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def __repr__(self):
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
@ -136,6 +166,7 @@ class ClientActorClass(ClientStub):
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name
task.payload_id = self._ref.id
set_task_options(task, self._options, "baseline_options")
return task
@ -160,7 +191,8 @@ class ClientActorHandle(ClientStub):
self.actor_ref = actor_ref
def __del__(self) -> None:
ray.call_release(self.actor_ref.id)
if ray.is_connected():
ray.call_release(self.actor_ref.id)
@property
def _actor_id(self):
@ -193,12 +225,18 @@ class ClientRemoteMethod(ClientStub):
f"Use {self._name}.remote() instead")
def remote(self, *args, **kwargs):
return ClientObjectRef(ray.call_remote(self, *args, **kwargs))
return return_refs(ray.call_remote(self, *args, **kwargs))
def __repr__(self):
return "ClientRemoteMethod(%s, %s)" % (self.method_name,
self.actor_handle)
def options(self, **kwargs):
return OptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
@ -207,6 +245,49 @@ class ClientRemoteMethod(ClientStub):
return task
class OptionWrapper:
    """Pair a client stub with per-call options given through ``.options()``.

    Instances are returned by the ``options(**kwargs)`` method of the client
    stubs. The wrapper validates the options once at construction time and
    attaches them to the task built by the wrapped stub when ``remote()`` is
    called. Any attribute the wrapper does not define itself is looked up on
    the wrapped stub.
    """

    def __init__(self, stub: "ClientStub",
                 options: Optional[Dict[str, Any]]):
        """Store the stub and the validated options.

        Args:
            stub: The client stub whose invocation is being customised.
            options: Keyword options for this invocation, or None.

        Raises:
            TypeError / ValueError: propagated from ``validate_options``
                when an option is unknown or has an invalid value.
        """
        self.remote_stub = stub
        self.options = validate_options(options)

    def remote(self, *args, **kwargs):
        """Schedule the wrapped stub with these options applied.

        Returns whatever ``return_refs`` builds from the IDs handed back by
        ``ray.call_remote``.
        """
        return return_refs(ray.call_remote(self, *args, **kwargs))

    def __getattr__(self, key):
        """Delegate unknown attribute lookups to the wrapped stub."""
        # __getattr__ only runs when normal lookup failed. If the missing
        # attribute is ``remote_stub`` itself (instance created without
        # __init__, e.g. by copy/pickle machinery), reading
        # ``self.remote_stub`` below would re-enter this method forever.
        if key == "remote_stub":
            raise AttributeError(key)
        return getattr(self.remote_stub, key)

    def _prepare_client_task(self):
        """Build the wrapped stub's task and set its ``options`` field."""
        task = self.remote_stub._prepare_client_task()
        set_task_options(task, self.options)
        return task
class ActorOptionWrapper(OptionWrapper):
    """Option wrapper whose ``remote()`` produces an actor handle."""

    def remote(self, *args, **kwargs):
        """Schedule the wrapped stub and wrap the single returned ID.

        Returns:
            A ClientActorHandle built from the one ID returned by
            ``ray.call_remote``; this wrapper is passed as the handle's
            second argument.
        """
        returned_ids = ray.call_remote(self, *args, **kwargs)
        # Exactly one ID is expected back from the scheduling call.
        assert len(returned_ids) == 1
        actor_ref = ClientActorRef(returned_ids[0])
        return ClientActorHandle(actor_ref, self)
def set_task_options(task: ray_client_pb2.ClientTask,
                     options: Optional[Dict[str, Any]],
                     field: str = "options") -> None:
    """Write an options dictionary into one options field of a task.

    Args:
        task: The task message to modify in place.
        options: Options to attach. None clears the field instead.
        field: Name of the field on ``task`` to set or clear; defaults to
            "options".
    """
    if options is not None:
        # Encode first, then assign onto the sub-message named by `field`.
        encoded = json.dumps(options)
        getattr(task, field).json_options = encoded
    else:
        task.ClearField(field)
def return_refs(ids: List[bytes]
                ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
    """Wrap returned IDs in ClientObjectRef objects.

    Returns:
        None for an empty sequence, a single ClientObjectRef when there is
        exactly one ID, otherwise a list with one ClientObjectRef per ID.
    """
    count = len(ids)
    if count == 0:
        return None
    if count == 1:
        return ClientObjectRef(ids[0])
    return [ClientObjectRef(ref_id) for ref_id in ids]
class DataEncodingSentinel:
def __repr__(self) -> str:
return self.__class__.__name__

View file

@ -0,0 +1,54 @@
from typing import Any
from typing import Dict
from typing import Optional
# Validation table for the keyword options accepted by validate_options().
# Each key is an accepted option name. The value is either an empty tuple
# (only the name is checked) or a triple of
# (required type, predicate on the value, error message).
options = {
    "num_returns": (int, lambda x: x >= 0,
                    "The keyword 'num_returns' only accepts 0 "
                    "or a positive integer"),
    "num_cpus": (),
    "num_gpus": (),
    "resources": (),
    "accelerator_type": (),
    "max_calls": (int, lambda x: x >= 0,
                  "The keyword 'max_calls' only accepts 0 "
                  "or a positive integer"),
    "max_restarts": (int, lambda x: x >= -1,
                     "The keyword 'max_restarts' only accepts -1, 0 "
                     "or a positive integer"),
    "max_task_retries": (int, lambda x: x >= -1,
                         "The keyword 'max_task_retries' only accepts -1, 0 "
                         "or a positive integer"),
    "max_retries": (int, lambda x: x >= -1,
                    "The keyword 'max_retries' only accepts 0, -1 "
                    "or a positive integer"),
    "max_concurrency": (),
    "name": (),
    "lifetime": (),
    "memory": (),
    "object_store_memory": (),
    "placement_group": (),
    "placement_group_bundle_index": (),
    "placement_group_capture_child_tasks": (),
    "override_environment_variables": (),
}


def validate_options(
        kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Check a dictionary of options against the ``options`` table.

    Args:
        kwargs_dict: Option names mapped to values, or None.

    Returns:
        None when ``kwargs_dict`` is None or empty; otherwise a new
        dictionary holding the same items.

    Raises:
        TypeError: If a key is not a known option name.
        ValueError: If a value has the wrong type for its option or fails
            the option's range check.
    """
    if not kwargs_dict:
        return None
    out = {}
    for k, v in kwargs_dict.items():
        if k not in options:
            raise TypeError(f"Invalid option passed to remote(): {k}")
        validator = options[k]
        if validator:
            expected_type, predicate, message = validator
            # bool is a subclass of int, so True/False would otherwise pass
            # as an integer count; reject them explicitly.
            bool_as_int = expected_type is int and isinstance(v, bool)
            if bool_as_int or not isinstance(v, expected_type):
                raise ValueError(message)
            if not predicate(v):
                raise ValueError(message)
        out[k] = v
    return out

View file

@ -4,8 +4,10 @@ import grpc
import base64
from collections import defaultdict
from typing import Any
from typing import Dict
from typing import Set
from typing import Optional
from ray import cloudpickle
import ray
@ -187,9 +189,11 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ready_object_refs, remaining_object_refs = ray.wait(
object_refs,
num_returns=num_returns,
timeout=timeout if timeout != -1 else None)
except Exception:
timeout=timeout if timeout != -1 else None,
)
except Exception as e:
# TODO(ameer): improve exception messages.
logger.error(f"Exception {e}")
return ray_client_pb2.WaitResponse(valid=False)
logger.debug("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
@ -206,9 +210,10 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
logger.info("schedule: %s %s" %
(task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
logger.debug(
"schedule: %s %s" % (task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(
task.type)))
with stash_api_for_tests(self._test_mode):
try:
if task.type == ray_client_pb2.ClientTask.FUNCTION:
@ -226,6 +231,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
return result
except Exception as e:
logger.error(f"Caught schedule exception {e}")
raise e
return ray_client_pb2.ClientTaskTicket(
valid=False, error=cloudpickle.dumps(e))
@ -236,34 +242,44 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
raise Exception(
"Can't run an actor the server doesn't have a handle for")
arglist, kwargs = self._convert_args(task.args, task.kwargs)
output = getattr(actor_handle, task.name).remote(*arglist, **kwargs)
self.object_refs[task.client_id][output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
method = getattr(actor_handle, task.name)
opts = decode_options(task.options)
if opts is not None:
method = method.options(**opts)
output = method.remote(*arglist, **kwargs)
ids = self.unify_and_track_outputs(output, task.client_id)
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
remote_class = self.lookup_or_register_actor(task.payload_id,
task.client_id)
remote_class = self.lookup_or_register_actor(
task.payload_id, task.client_id,
decode_options(task.baseline_options))
arglist, kwargs = self._convert_args(task.args, task.kwargs)
opts = decode_options(task.options)
if opts is not None:
remote_class = remote_class.options(**opts)
with current_remote(remote_class):
actor = remote_class.remote(*arglist, **kwargs)
self.actor_refs[actor._actor_id.binary()] = actor
self.actor_owners[task.client_id].add(actor._actor_id.binary())
return ray_client_pb2.ClientTaskTicket(
return_id=actor._actor_id.binary())
return_ids=[actor._actor_id.binary()])
def _schedule_function(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
remote_func = self.lookup_or_register_func(task.payload_id,
task.client_id)
remote_func = self.lookup_or_register_func(
task.payload_id, task.client_id,
decode_options(task.baseline_options))
arglist, kwargs = self._convert_args(task.args, task.kwargs)
opts = decode_options(task.options)
if opts is not None:
remote_func = remote_func.options(**opts)
with current_remote(remote_func):
output = remote_func.remote(*arglist, **kwargs)
if output.binary() in self.object_refs[task.client_id]:
raise Exception("already found it")
self.object_refs[task.client_id][output.binary()] = output
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
ids = self.unify_and_track_outputs(output, task.client_id)
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
def _convert_args(self, arg_list, kwarg_map):
argout = []
@ -275,28 +291,50 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
kwargout[k] = convert_from_arg(kwarg_map[k], self)
return argout, kwargout
def lookup_or_register_func(self, id: bytes, client_id: str
) -> ray.remote_function.RemoteFunction:
def lookup_or_register_func(
self, id: bytes, client_id: str,
options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
if id not in self.function_refs:
funcref = self.object_refs[client_id][id]
func = ray.get(funcref)
if not inspect.isfunction(func):
raise Exception("Attempting to register function that "
"isn't a function.")
self.function_refs[id] = ray.remote(func)
if options is None or len(options) == 0:
self.function_refs[id] = ray.remote(func)
else:
self.function_refs[id] = ray.remote(**options)(func)
return self.function_refs[id]
def lookup_or_register_actor(self, id: bytes, client_id: str):
def lookup_or_register_actor(self, id: bytes, client_id: str,
options: Optional[Dict]):
if id not in self.registered_actor_classes:
actor_class_ref = self.object_refs[client_id][id]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
if options is None or len(options) == 0:
reg_class = ray.remote(actor_class)
else:
reg_class = ray.remote(**options)(actor_class)
self.registered_actor_classes[id] = reg_class
return self.registered_actor_classes[id]
def unify_and_track_outputs(self, output, client_id):
    """Normalise a task's output to a list of refs and record each one.

    ``output`` may be None, a single ref, or a list of refs. Every ref is
    stored in ``self.object_refs[client_id]`` under its ``binary()`` ID; a
    warning is logged when that ID is already present.

    Returns:
        The list of ``binary()`` IDs, in the same order as the refs.
    """
    if isinstance(output, list):
        refs = output
    elif output is None:
        refs = []
    else:
        refs = [output]
    for ref in refs:
        key = ref.binary()
        # Looked up per ref, so nothing is touched when there are no refs.
        tracked = self.object_refs[client_id]
        if key in tracked:
            logger.warning(f"Already saw object_ref {ref}")
        tracked[key] = ref
    return [ref.binary() for ref in refs]
def return_exception_in_context(err, context):
if context is not None:
@ -309,6 +347,15 @@ def encode_exception(exception) -> str:
return base64.standard_b64encode(data).decode()
def decode_options(
        options: "ray_client_pb2.TaskOptions") -> Optional[Dict[str, Any]]:
    """Decode the JSON options carried by a task options message.

    Args:
        options: Message whose ``json_options`` string holds the encoded
            options.

    Returns:
        None when ``json_options`` is empty, otherwise the decoded
        dictionary.

    Raises:
        TypeError: If the JSON decodes to something other than a dictionary.
        json.JSONDecodeError: If ``json_options`` is not valid JSON.
    """
    if options.json_options == "":
        return None
    opts = json.loads(options.json_options)
    # The payload comes from the client, so check it with a real exception:
    # an assert would be stripped when Python runs with -O.
    if not isinstance(opts, dict):
        raise TypeError("Task options must decode to a JSON object, "
                        f"got {type(opts).__name__}")
    return opts
def serve(connection_str, test_mode=False):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer(test_mode=test_mode)

View file

@ -56,6 +56,7 @@ class ServerPickler(cloudpickle.CloudPickler):
client_id=self.client_id,
ref_id=obj_id,
name=None,
baseline_options=None,
)
elif isinstance(obj, ray.actor.ActorHandle):
actor_id = obj._actor_id.binary()
@ -69,6 +70,7 @@ class ServerPickler(cloudpickle.CloudPickler):
client_id=self.client_id,
ref_id=obj._actor_id.binary(),
name=None,
baseline_options=None,
)
return None
@ -89,13 +91,13 @@ class ClientUnpickler(pickle.Unpickler):
elif pid.type == "RemoteFuncSelfReference":
return ServerSelfReferenceSentinel()
elif pid.type == "RemoteFunc":
return self.server.lookup_or_register_func(pid.ref_id,
pid.client_id)
return self.server.lookup_or_register_func(
pid.ref_id, pid.client_id, pid.baseline_options)
elif pid.type == "RemoteActorSelfReference":
return ServerSelfReferenceSentinel()
elif pid.type == "RemoteActor":
return self.server.lookup_or_register_actor(
pid.ref_id, pid.client_id)
pid.ref_id, pid.client_id, pid.baseline_options)
elif pid.type == "RemoteMethod":
actor = self.server.actor_refs[pid.ref_id]
return getattr(actor, pid.name)

View file

@ -21,12 +21,13 @@ import ray.cloudpickle as cloudpickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.client_pickler import convert_to_arg
from ray.experimental.client.client_pickler import loads_from_server
from ray.experimental.client.client_pickler import dumps_from_client
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.client_pickler import loads_from_server
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientStub
from ray.experimental.client.dataclient import DataClient
logger = logging.getLogger(__name__)
@ -80,7 +81,9 @@ class Worker:
except grpc.RpcError as e:
raise e.details()
if not data.valid:
raise cloudpickle.loads(data.error)
err = cloudpickle.loads(data.error)
logger.error(err)
raise err
return loads_from_server(data.data)
def put(self, vals):
@ -98,6 +101,13 @@ class Worker:
return out
def _put(self, val):
if isinstance(val, ClientObjectRef):
raise TypeError(
"Calling 'put' on an ObjectRef is not allowed "
"(similarly, returning an ObjectRef from a remote "
"function is not allowed). If you really want to "
"do this, you can wrap the ObjectRef in a list and "
"call 'put' on it (or return it).")
data = dumps_from_client(val, self._client_id)
req = ray_client_pb2.PutRequest(data=data)
resp = self.data_client.PutObject(req)
@ -107,7 +117,8 @@ class Worker:
object_refs: List[ClientObjectRef],
*,
num_returns: int = 1,
timeout: float = None
timeout: float = None,
fetch_local: bool = True
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
if not isinstance(object_refs, list):
raise TypeError("wait() expected a list of ClientObjectRef, "
@ -136,19 +147,22 @@ class Worker:
return (client_ready_object_ids, client_remaining_object_ids)
def remote(self, function_or_class, *args, **kwargs):
# TODO(barakmich): Arguments to ray.remote
# get captured here.
if (inspect.isfunction(function_or_class)
or is_cython(function_or_class)):
return ClientRemoteFunc(function_or_class)
elif inspect.isclass(function_or_class):
return ClientActorClass(function_or_class)
else:
raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.")
def remote(self, *args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote.
return remote_decorator(options=None)(args[0])
error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
"'memory', 'object_store_memory', 'resources', "
"'max_calls', or 'max_restarts', like "
"'@ray.remote(num_returns=2, "
"resources={\"CustomResource\": 1})'.")
assert len(args) == 0 and len(kwargs) > 0, error_string
return remote_decorator(options=kwargs)
def call_remote(self, instance, *args, **kwargs) -> bytes:
def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
task = instance._prepare_client_task()
for arg in args:
pb_arg = convert_to_arg(arg, self._client_id)
@ -160,10 +174,10 @@ class Worker:
try:
ticket = self.server.Schedule(task, metadata=self.metadata)
except grpc.RpcError as e:
raise e.details()
raise decode_exception(e.details)
if not ticket.valid:
raise cloudpickle.loads(ticket.error)
return ticket.return_id
return ticket.return_ids
def call_release(self, id: bytes) -> None:
self.reference_count[id] -= 1
@ -234,6 +248,20 @@ class Worker:
return False
def remote_decorator(options: Optional[Dict[str, Any]]):
    """Build a decorator that turns a function or class into a client stub.

    Args:
        options: Options forwarded to the stub constructor, or None.

    Returns:
        A one-argument decorator. It returns a ClientRemoteFunc for
        functions (including those recognised by ``is_cython``), a
        ClientActorClass for classes, and raises TypeError otherwise.
    """

    def decorator(function_or_class) -> ClientStub:
        is_function = (inspect.isfunction(function_or_class)
                       or is_cython(function_or_class))
        if is_function:
            return ClientRemoteFunc(function_or_class, options=options)
        if inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class, options=options)
        raise TypeError("The @ray.remote decorator must be applied to "
                        "either a function or to a class.")

    return decorator
def make_client_id() -> str:
id = uuid.uuid4()
return id.hex

View file

@ -158,6 +158,7 @@ py_test(
py_test_module_list(
files = [
"test_actor.py",
"test_advanced.py",
"test_basic.py",
"test_basic_2.py",
],

View file

@ -25,7 +25,9 @@ else:
import setproctitle # noqa
@pytest.mark.skipif(client_test_enabled(), reason="test setup order")
@pytest.mark.skipif(
client_test_enabled(),
reason="defining early, no ray package injection yet")
def test_caching_actors(shutdown_only):
# Test defining actors before ray.init() has been called.
@ -564,7 +566,6 @@ def test_actor_static_attributes(ray_start_regular_shared):
assert ray.get(t.g.remote()) == 3
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_decorator_args(ray_start_regular_shared):
# This is an invalid way of using the actor decorator.
with pytest.raises(Exception):
@ -655,7 +656,7 @@ def test_actor_inheritance(ray_start_regular_shared):
pass
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
@pytest.mark.skipif(client_test_enabled(), reason="ray.method unimplemented")
def test_multiple_return_values(ray_start_regular_shared):
@ray.remote
class Foo:
@ -689,7 +690,6 @@ def test_multiple_return_values(ray_start_regular_shared):
assert ray.get([id3a, id3b, id3c]) == [1, 2, 3]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_options_num_returns(ray_start_regular_shared):
@ray.remote
class Foo:

View file

@ -10,16 +10,22 @@ import time
import numpy as np
import pytest
import ray
import ray.cluster_utils
import ray.test_utils
from ray.test_utils import client_test_enabled
from ray.test_utils import RayTestTimeoutException
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
logger = logging.getLogger(__name__)
# issue https://github.com/ray-project/ray/issues/7105
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_internal_free(shutdown_only):
ray.init(num_cpus=1)
@ -60,14 +66,14 @@ def test_multiple_waits_and_gets(shutdown_only):
return 1
@ray.remote
def g(l):
# The argument l should be a list containing one object ref.
ray.wait([l[0]])
def g(input_list):
# The argument input_list should be a list containing one object ref.
ray.wait([input_list[0]])
@ray.remote
def h(l):
# The argument l should be a list containing one object ref.
ray.get(l[0])
def h(input_list):
# The argument input_list should be a list containing one object ref.
ray.get(input_list[0])
# Make sure that multiple wait requests involving the same object ref
# all return.
@ -80,6 +86,7 @@ def test_multiple_waits_and_gets(shutdown_only):
ray.get([h.remote([x]), h.remote([x])])
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_caching_functions_to_run(shutdown_only):
# Test that we export functions to run on all workers before the driver
# is connected.
@ -125,6 +132,7 @@ def test_caching_functions_to_run(shutdown_only):
ray.worker.global_worker.run_function_on_all_workers(f)
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_running_function_on_all_workers(ray_start_regular):
def f(worker_info):
sys.path.append("fake_directory")
@ -152,6 +160,7 @@ def test_running_function_on_all_workers(ray_start_regular):
assert "fake_directory" not in ray.get(get_path2.remote())
@pytest.mark.skipif(client_test_enabled(), reason="ray.timeline")
def test_profiling_api(ray_start_2_cpus):
@ray.remote
def f():
@ -482,6 +491,7 @@ def test_multithreading(ray_start_2_cpus):
ray.get(actor.join.remote()) == "ok"
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_wait_makes_object_local(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=0)

View file

@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
# https://github.com/ray-project/ray/issues/6662
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc")
def test_ignore_http_proxy(shutdown_only):
ray.init(num_cpus=1)
os.environ["http_proxy"] = "http://example.com"
@ -55,14 +55,12 @@ def test_grpc_message_size(shutdown_only):
# https://github.com/ray-project/ray/issues/7287
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_omp_threads_set(shutdown_only):
ray.init(num_cpus=1)
# Should have been auto set by ray init.
assert os.environ["OMP_NUM_THREADS"] == "1"
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_submit_api(shutdown_only):
ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1})
@ -121,7 +119,6 @@ def test_submit_api(shutdown_only):
assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_invalid_arguments(shutdown_only):
ray.init(num_cpus=2)
@ -176,7 +173,6 @@ def test_invalid_arguments(shutdown_only):
x = 1
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_many_fractional_resources(shutdown_only):
ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2})
@ -244,7 +240,6 @@ def test_many_fractional_resources(shutdown_only):
assert False, "Did not get correct available resources."
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_background_tasks_with_max_calls(shutdown_only):
ray.init(num_cpus=2)
@ -360,8 +355,9 @@ def test_function_descriptor():
assert d.get(python_descriptor2) == 123
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_ray_options(shutdown_only):
ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
@ray.remote(
num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1})
def foo():
@ -370,8 +366,6 @@ def test_ray_options(shutdown_only):
time.sleep(0.1)
return ray.available_resources()
ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
without_options = ray.get(foo.remote())
with_options = ray.get(
foo.options(

View file

@ -537,7 +537,6 @@ def test_actor_recursive(ray_start_regular_shared):
assert result == [x * 2 for x in range(100)]
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
def test_actor_concurrent(ray_start_regular_shared):
@ray.remote
class Batcher:

View file

@ -35,6 +35,18 @@ message Arg {
Type type = 4;
}
// A message representing the valid options to modify a task execution.
//
// TODO(barakmich): In the longer term, if everything were a client,
// this message could be the actual standard for which options are
// allowed in the API. Today, however, it's a bit flexible and defined in the
// Python code. So for now, it's a stand-in message with a json field, but
// this is forwards-compatible with deprecating that field and instituting
// strongly defined and typed fields, without migrating the original ClientTask.
message TaskOptions {
// JSON-encoded dictionary mapping option names to values. The Python
// server treats an empty string as "no options".
string json_options = 1;
}
// Represents one unit of work to be executed by the server.
message ClientTask {
enum RemoteExecType {
@ -45,8 +57,8 @@ message ClientTask {
}
// Which type of work this request represents.
RemoteExecType type = 1;
// A name parameter, if the payload can be called in more than one way (like a method on
// a payload object).
// A name parameter, if the payload can be called in more than one way
// (like a method on a payload object).
string name = 2;
// A reference to the payload.
bytes payload_id = 3;
@ -54,16 +66,20 @@ message ClientTask {
repeated Arg args = 4;
// Keyword parameters to pass to this call.
map<string, Arg> kwargs = 5;
// The ID of the client namespace associated with the Datapath stream making this
// request.
// The ID of the client namespace associated with the Datapath stream
// making this request.
string client_id = 6;
// Options for modifying the remote task execution environment.
TaskOptions options = 7;
// Options passed to create the default remote task execution environment.
TaskOptions baseline_options = 8;
}
message ClientTaskTicket {
// Was the task successful?
bool valid = 1;
// A reference to the returned value from the execution.
bytes return_id = 2;
// A reference to the returned values from the execution.
repeated bytes return_ids = 2;
// If unsuccessful, an encoding of the error.
bytes error = 3;
}