mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[ray_client] Integrate with test_basic, test_basic_2 and test_actor (#12964)
This commit is contained in:
parent
bf6577c8f4
commit
7ab9164f1b
21 changed files with 375 additions and 174 deletions
|
@ -46,6 +46,7 @@ matrix:
|
|||
script:
|
||||
# bazel python tests for medium size tests. Used for parallelization.
|
||||
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi
|
||||
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_TEST_CLIENT_MODE=1 python/ray/tests/...; fi
|
||||
|
||||
- os: linux
|
||||
env:
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
# py_test_module_list creates a py_test target for each
|
||||
# Python file in `files`
|
||||
def py_test_module_list(files, size, deps, extra_srcs, name_suffix="", **kwargs):
    """Declare one `py_test` target per Python file in `files`.

    Args:
        files: test file names; the last three characters (assumed to be
            ".py") are stripped to form the target name.
        size: Bazel test size applied to every generated target.
        deps: dependencies forwarded to every generated target.
        extra_srcs: additional sources prepended to each target's `srcs`.
        name_suffix: appended to the derived target name, so the same file
            can be declared more than once under different names.
        **kwargs: passed through unchanged to `native.py_test`.
    """
    for file in files:
        # Target name is the file name without its extension, plus suffix.
        name = file[:-3] + name_suffix
        native.py_test(
            name = name,
            size = size,
            main = file,
            srcs = extra_srcs + [file],
            # `deps` used to be accepted but never handed to the rule.
            deps = deps,
            **kwargs
        )
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Optional, List, Tuple
|
|||
from contextlib import contextmanager
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -43,9 +44,11 @@ def stash_api_for_tests(in_test: bool):
|
|||
is_server = _is_server
|
||||
if in_test:
|
||||
_is_server = True
|
||||
yield _server_api
|
||||
if in_test:
|
||||
_is_server = is_server
|
||||
try:
|
||||
yield _server_api
|
||||
finally:
|
||||
if in_test:
|
||||
_is_server = is_server
|
||||
|
||||
|
||||
def _set_client_api(val: Optional[APIImpl]):
|
||||
|
@ -77,18 +80,7 @@ def reset_api():
|
|||
|
||||
def _get_client_api() -> APIImpl:
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
api = None
|
||||
if _is_server:
|
||||
api = _server_api
|
||||
else:
|
||||
api = _client_api
|
||||
if api is None:
|
||||
# We're inside a raylet worker
|
||||
from ray.experimental.client.server.core_ray_api import CoreRayAPI
|
||||
return CoreRayAPI()
|
||||
return api
|
||||
return _client_api
|
||||
|
||||
|
||||
def _get_server_instance():
|
||||
|
@ -124,9 +116,33 @@ class RayAPIStub:
|
|||
global _client_api
|
||||
return _client_api is not None
|
||||
|
||||
def init(self, *args, **kwargs):
    """Start a local test server, connect to it and return its address info.

    Only supported when ``_is_client_test_env()`` is true. In that case a
    server is started through ``init_and_serve`` (positional and keyword
    arguments are forwarded to it), the handle is stored in the module-level
    ``_test_server`` and this stub connects to the server.

    Returns:
        The ``address_info`` value produced by ``init_and_serve``.

    Raises:
        NotImplementedError: when not running in the client test
            environment.
    """
    global _test_server
    if not _is_client_test_env():
        raise NotImplementedError(
            "Please call ray.connect() in client mode")

    import ray.experimental.client.server.server as ray_client_server

    address = "localhost:50051"
    # Pass test_mode positionally: as a keyword it collides with the first
    # element of *args, which binds to the same parameter.
    _test_server, address_info = ray_client_server.init_and_serve(
        address, True, *args, **kwargs)
    self.connect(address)
    return address_info
|
||||
|
||||
|
||||
ray = RayAPIStub()
|
||||
|
||||
_test_server = None
|
||||
|
||||
|
||||
def _stop_test_server(*args):
    """Stop the module-level test server, if one is running.

    All arguments are forwarded to the server's ``stop`` method. Calling
    this when no server has been started (or after it was already stopped)
    is a no-op, so teardown code can call it unconditionally.
    """
    global _test_server
    if _test_server is None:
        return
    # Clear the handle first so a failing stop() does not leave a stale
    # reference behind for the next teardown.
    server, _test_server = _test_server, None
    server.stop(*args)
|
||||
|
||||
|
||||
def _is_client_test_env() -> bool:
    """Return True when RAY_TEST_CLIENT_MODE is set to exactly "1"."""
    mode = os.getenv("RAY_TEST_CLIENT_MODE")
    return mode == "1"
|
||||
|
||||
|
||||
# Someday we might add methods in this module so that someone who
|
||||
# tries to `import ray_client as ray` -- as a module, instead of
|
||||
# `from ray_client import ray` -- as the API stub
|
||||
|
|
|
@ -28,11 +28,15 @@ import sys
|
|||
|
||||
from typing import NamedTuple
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from ray.experimental.client import RayAPIStub
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientRemoteMethod
|
||||
from ray.experimental.client.common import SelfReferenceSentinel
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
|
@ -44,8 +48,11 @@ if sys.version_info < (3, 8):
|
|||
else:
|
||||
import pickle # noqa: F401
|
||||
|
||||
PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str),
|
||||
("ref_id", bytes)])
|
||||
# NOTE(barakmich): These PickleStubs are really close to
|
||||
# the data for an execution, with no arguments. Combine the two?
|
||||
PickleStub = NamedTuple("PickleStub",
|
||||
[("type", str), ("client_id", str), ("ref_id", bytes),
|
||||
("name", Optional[str])])
|
||||
|
||||
|
||||
class ClientPickler(cloudpickle.CloudPickler):
|
||||
|
@ -54,17 +61,26 @@ class ClientPickler(cloudpickle.CloudPickler):
|
|||
self.client_id = client_id
|
||||
|
||||
def persistent_id(self, obj):
|
||||
if isinstance(obj, ClientObjectRef):
|
||||
if isinstance(obj, RayAPIStub):
|
||||
return PickleStub(
|
||||
type="Ray",
|
||||
client_id=self.client_id,
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
)
|
||||
elif isinstance(obj, ClientObjectRef):
|
||||
return PickleStub(
|
||||
type="Object",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj.id,
|
||||
name=None,
|
||||
)
|
||||
elif isinstance(obj, ClientActorHandle):
|
||||
return PickleStub(
|
||||
type="Actor",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._actor_id,
|
||||
name=None,
|
||||
)
|
||||
elif isinstance(obj, ClientRemoteFunc):
|
||||
# TODO(barakmich): This is going to have trouble with mutually
|
||||
|
@ -77,11 +93,39 @@ class ClientPickler(cloudpickle.CloudPickler):
|
|||
return PickleStub(
|
||||
type="RemoteFuncSelfReference",
|
||||
client_id=self.client_id,
|
||||
ref_id=b"")
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
)
|
||||
return PickleStub(
|
||||
type="RemoteFunc",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._ref.id)
|
||||
ref_id=obj._ref.id,
|
||||
name=None,
|
||||
)
|
||||
elif isinstance(obj, ClientActorClass):
|
||||
# TODO(barakmich): Mutual recursion, as above.
|
||||
if obj._ref is None:
|
||||
obj._ensure_ref()
|
||||
if type(obj._ref) == SelfReferenceSentinel:
|
||||
return PickleStub(
|
||||
type="RemoteActorSelfReference",
|
||||
client_id=self.client_id,
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
)
|
||||
return PickleStub(
|
||||
type="RemoteActor",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._ref.id,
|
||||
name=None,
|
||||
)
|
||||
elif isinstance(obj, ClientRemoteMethod):
|
||||
return PickleStub(
|
||||
type="RemoteMethod",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj.actor_handle.actor_ref.id,
|
||||
name=obj.method_name,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.experimental.client import ray
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class ClientBaseRef:
|
||||
|
@ -8,17 +7,20 @@ class ClientBaseRef:
|
|||
self.id: bytes = id
|
||||
ray.call_retain(id)
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (
|
||||
type(self).__name__,
|
||||
self.id.hex(),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def __del__(self):
|
||||
if ray.is_connected():
|
||||
|
@ -107,18 +109,13 @@ class ClientActorClass(ClientStub):
|
|||
raise TypeError(f"Remote actor cannot be instantiated directly. "
|
||||
"Use {self._name}.remote() instead")
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_cls": self.actor_cls,
|
||||
"_name": self._name,
|
||||
"_ref": self._ref,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_cls = state["actor_cls"]
|
||||
self._name = state["_name"]
|
||||
self._ref = state["_ref"]
|
||||
def _ensure_ref(self):
    """Populate ``self._ref`` by putting ``self.actor_cls``, once.

    Does nothing when a ref is already held.
    """
    if self._ref is not None:
        return
    # Mark the ref as an in-progress self reference before the put, so the
    # encoding can detect that state and handle it correctly.
    self._ref = SelfReferenceSentinel()
    self._ref = ray.put(self.actor_cls)
|
||||
|
||||
def remote(self, *args, **kwargs) -> "ClientActorHandle":
|
||||
# Actually instantiate the actor
|
||||
|
@ -126,7 +123,7 @@ class ClientActorClass(ClientStub):
|
|||
return ClientActorHandle(ClientActorRef(ref_id), self)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
|
||||
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self.__dict__:
|
||||
|
@ -134,8 +131,7 @@ class ClientActorClass(ClientStub):
|
|||
raise NotImplementedError("static methods")
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self.actor_cls)
|
||||
self._ensure_ref()
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||
task.name = self._name
|
||||
|
|
|
@ -53,24 +53,33 @@ class DataClient:
|
|||
resp_stream = stub.Datapath(
|
||||
iter(self.request_queue.get, None),
|
||||
metadata=(("client_id", self._client_id), ))
|
||||
for response in resp_stream:
|
||||
if response.req_id == 0:
|
||||
# This is not being waited for.
|
||||
logger.debug(f"Got unawaited response {response}")
|
||||
continue
|
||||
with self.cv:
|
||||
self.ready_data[response.req_id] = response
|
||||
self.cv.notify_all()
|
||||
try:
|
||||
for response in resp_stream:
|
||||
if response.req_id == 0:
|
||||
# This is not being waited for.
|
||||
logger.debug(f"Got unawaited response {response}")
|
||||
continue
|
||||
with self.cv:
|
||||
self.ready_data[response.req_id] = response
|
||||
self.cv.notify_all()
|
||||
except grpc.RpcError as e:
|
||||
if grpc.StatusCode.CANCELLED == e.code():
|
||||
# Gracefully shutting down
|
||||
logger.info("Cancelling data channel")
|
||||
else:
|
||||
logger.error(
|
||||
f"Got Error from rpc channel -- shutting down: {e}")
|
||||
raise e
|
||||
|
||||
def close(self, close_channel: bool = False) -> None:
|
||||
if self.request_queue is not None:
|
||||
self.request_queue.put(None)
|
||||
self.request_queue = None
|
||||
if close_channel:
|
||||
self.channel.close()
|
||||
if self.data_thread is not None:
|
||||
self.data_thread.join()
|
||||
self.data_thread = None
|
||||
if close_channel:
|
||||
self.channel.close()
|
||||
|
||||
def _blocking_send(self, req: ray_client_pb2.DataRequest
|
||||
) -> ray_client_pb2.DataResponse:
|
||||
|
|
|
@ -79,8 +79,3 @@ class RayServerAPI(CoreRayAPI):
|
|||
|
||||
def __init__(self, server_instance):
|
||||
self.server = server_instance
|
||||
|
||||
def call_remote(self, instance: ClientStub, *args, **kwargs) -> bytes:
|
||||
task = instance._prepare_client_task()
|
||||
ticket = self.server.Schedule(task, prepared_args=args)
|
||||
return ticket.return_id
|
||||
|
|
|
@ -21,7 +21,7 @@ from ray.experimental.client.server.server_pickler import dumps_from_server
|
|||
from ray.experimental.client.server.server_pickler import loads_from_client
|
||||
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||
from ray.experimental.client.server.dataservicer import DataServicer
|
||||
from ray.experimental.client.server.server_stubs import current_func
|
||||
from ray.experimental.client.server.server_stubs import current_remote
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -205,82 +205,75 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
ready_object_ids=ready_object_ids,
|
||||
remaining_object_ids=remaining_object_ids)
|
||||
|
||||
def Schedule(self, task, context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
logger.info("schedule: %s %s" %
|
||||
(task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
return self._schedule_function(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
return self._schedule_actor(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
return self._schedule_method(task, context, prepared_args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
||||
try:
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
result = self._schedule_function(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
result = self._schedule_actor(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
result = self._schedule_method(task, context)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(
|
||||
task.type))
|
||||
result.valid = True
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Caught schedule exception {e}")
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
|
||||
def _schedule_method(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def _schedule_method(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
actor_handle = self.actor_refs.get(task.payload_id)
|
||||
if actor_handle is None:
|
||||
raise Exception(
|
||||
"Can't run an actor the server doesn't have a handle for")
|
||||
arglist = self._convert_args(task.args, prepared_args)
|
||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
output = getattr(actor_handle, task.name).remote(*arglist, **kwargs)
|
||||
self.object_refs[task.client_id][output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
def _schedule_actor(self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
if task.payload_id not in self.registered_actor_classes:
|
||||
actor_class_ref = \
|
||||
self.object_refs[task.client_id][task.payload_id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a class.")
|
||||
reg_class = ray.remote(actor_class)
|
||||
self.registered_actor_classes[task.payload_id] = reg_class
|
||||
remote_class = self.registered_actor_classes[task.payload_id]
|
||||
arglist = self._convert_args(task.args, prepared_args)
|
||||
actor = remote_class.remote(*arglist)
|
||||
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
remote_class = self.lookup_or_register_actor(task.payload_id,
|
||||
task.client_id)
|
||||
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
with current_remote(remote_class):
|
||||
actor = remote_class.remote(*arglist, **kwargs)
|
||||
self.actor_refs[actor._actor_id.binary()] = actor
|
||||
self.actor_owners[task.client_id].add(actor._actor_id.binary())
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_id=actor._actor_id.binary())
|
||||
|
||||
def _schedule_function(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def _schedule_function(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
remote_func = self.lookup_or_register_func(task.payload_id,
|
||||
task.client_id)
|
||||
arglist = self._convert_args(task.args, prepared_args)
|
||||
# Prepare call if we're in a test
|
||||
with current_func(remote_func):
|
||||
output = remote_func.remote(*arglist)
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
with current_remote(remote_func):
|
||||
output = remote_func.remote(*arglist, **kwargs)
|
||||
if output.binary() in self.object_refs[task.client_id]:
|
||||
raise Exception("already found it")
|
||||
self.object_refs[task.client_id][output.binary()] = output
|
||||
return ray_client_pb2.ClientTaskTicket(return_id=output.binary())
|
||||
|
||||
def _convert_args(self, arg_list, prepared_args=None):
|
||||
if prepared_args is not None:
|
||||
return prepared_args
|
||||
out = []
|
||||
def _convert_args(self, arg_list, kwarg_map):
|
||||
argout = []
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg, self)
|
||||
out.append(t)
|
||||
return out
|
||||
argout.append(t)
|
||||
kwargout = {}
|
||||
for k in kwarg_map:
|
||||
kwargout[k] = convert_from_arg(kwarg_map[k], self)
|
||||
return argout, kwargout
|
||||
|
||||
def lookup_or_register_func(self, id: bytes, client_id: str
|
||||
) -> ray.remote_function.RemoteFunction:
|
||||
|
@ -293,6 +286,17 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
self.function_refs[id] = ray.remote(func)
|
||||
return self.function_refs[id]
|
||||
|
||||
def lookup_or_register_actor(self, id: bytes, client_id: str):
    """Return the remote actor class registered under ``id``.

    On first use the class object is fetched from the client's object refs
    with ``ray.get``, wrapped with ``ray.remote`` and cached in
    ``self.registered_actor_classes``.

    Raises:
        Exception: if the fetched object is not a class.
    """
    registry = self.registered_actor_classes
    if id in registry:
        return registry[id]

    actor_class = ray.get(self.object_refs[client_id][id])
    if not inspect.isclass(actor_class):
        raise Exception("Attempting to schedule actor that "
                        "isn't a class.")
    registry[id] = ray.remote(actor_class)
    return registry[id]
|
||||
|
||||
|
||||
def return_exception_in_context(err, context):
|
||||
if context is not None:
|
||||
|
@ -319,6 +323,12 @@ def serve(connection_str, test_mode=False):
|
|||
return server
|
||||
|
||||
|
||||
def init_and_serve(connection_str, test_mode=False, *args, **kwargs):
    """Call ``ray.init`` and then start serving on ``connection_str``.

    Extra positional and keyword arguments go to ``ray.init``.

    Returns:
        A ``(server, info)`` tuple: the object returned by ``serve`` and the
        value returned by ``ray.init``.
    """
    # ray.init is called before serve, as in the original ordering.
    ray_info = ray.init(*args, **kwargs)
    return serve(connection_str, test_mode), ray_info
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level="INFO")
|
||||
# TODO(barakmich): Perhaps wrap ray init
|
||||
|
|
|
@ -21,7 +21,8 @@ from typing import Any
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from ray.experimental.client.client_pickler import PickleStub
|
||||
from ray.experimental.client.server.server_stubs import ServerFunctionSentinel
|
||||
from ray.experimental.client.server.server_stubs import (
|
||||
ServerSelfReferenceSentinel)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.server.server import RayletServicer
|
||||
|
@ -54,6 +55,7 @@ class ServerPickler(cloudpickle.CloudPickler):
|
|||
type="Object",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj_id,
|
||||
name=None,
|
||||
)
|
||||
elif isinstance(obj, ray.actor.ActorHandle):
|
||||
actor_id = obj._actor_id.binary()
|
||||
|
@ -66,6 +68,7 @@ class ServerPickler(cloudpickle.CloudPickler):
|
|||
type="Actor",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._actor_id.binary(),
|
||||
name=None,
|
||||
)
|
||||
return None
|
||||
|
||||
|
@ -77,15 +80,25 @@ class ClientUnpickler(pickle.Unpickler):
|
|||
|
||||
def persistent_load(self, pid):
|
||||
assert isinstance(pid, PickleStub)
|
||||
if pid.type == "Object":
|
||||
if pid.type == "Ray":
|
||||
return ray
|
||||
elif pid.type == "Object":
|
||||
return self.server.object_refs[pid.client_id][pid.ref_id]
|
||||
elif pid.type == "Actor":
|
||||
return self.server.actor_refs[pid.ref_id]
|
||||
elif pid.type == "RemoteFuncSelfReference":
|
||||
return ServerFunctionSentinel()
|
||||
return ServerSelfReferenceSentinel()
|
||||
elif pid.type == "RemoteFunc":
|
||||
return self.server.lookup_or_register_func(pid.ref_id,
|
||||
pid.client_id)
|
||||
elif pid.type == "RemoteActorSelfReference":
|
||||
return ServerSelfReferenceSentinel()
|
||||
elif pid.type == "RemoteActor":
|
||||
return self.server.lookup_or_register_actor(
|
||||
pid.ref_id, pid.client_id)
|
||||
elif pid.type == "RemoteMethod":
|
||||
actor = self.server.actor_refs[pid.ref_id]
|
||||
return getattr(actor, pid.name)
|
||||
else:
|
||||
raise NotImplementedError("Uncovered client data type")
|
||||
|
||||
|
|
|
@ -1,28 +1,28 @@
|
|||
from contextlib import contextmanager
|
||||
|
||||
_current_remote_func = None
|
||||
_current_remote_obj = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def current_func(f):
|
||||
global _current_remote_func
|
||||
remote_func = _current_remote_func
|
||||
_current_remote_func = f
|
||||
def current_remote(r):
|
||||
global _current_remote_obj
|
||||
remote = _current_remote_obj
|
||||
_current_remote_obj = r
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_current_remote_func = remote_func
|
||||
_current_remote_obj = remote
|
||||
|
||||
|
||||
class ServerFunctionSentinel:
|
||||
class ServerSelfReferenceSentinel:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __reduce__(self):
|
||||
global _current_remote_func
|
||||
if _current_remote_func is None:
|
||||
return (ServerFunctionSentinel, tuple())
|
||||
return (identity, (_current_remote_func, ))
|
||||
global _current_remote_obj
|
||||
if _current_remote_obj is None:
|
||||
return (ServerSelfReferenceSentinel, tuple())
|
||||
return (identity, (_current_remote_obj, ))
|
||||
|
||||
|
||||
def identity(x):
|
||||
|
|
|
@ -109,9 +109,13 @@ class Worker:
|
|||
num_returns: int = 1,
|
||||
timeout: float = None
|
||||
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
|
||||
assert isinstance(object_refs, list)
|
||||
if not isinstance(object_refs, list):
|
||||
raise TypeError("wait() expected a list of ClientObjectRef, "
|
||||
f"got {type(object_refs)}")
|
||||
for ref in object_refs:
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
if not isinstance(ref, ClientObjectRef):
|
||||
raise TypeError("wait() expected a list of ClientObjectRef, "
|
||||
f"got list containing {type(ref)}")
|
||||
data = {
|
||||
"object_ids": [object_ref.id for object_ref in object_refs],
|
||||
"num_returns": num_returns,
|
||||
|
@ -149,9 +153,16 @@ class Worker:
|
|||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg, self._client_id)
|
||||
task.args.append(pb_arg)
|
||||
for k, v in kwargs.items():
|
||||
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
|
||||
task.client_id = self._client_id
|
||||
logger.debug("Scheduling %s" % task)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
try:
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
except grpc.RpcError as e:
|
||||
raise e.details()
|
||||
if not ticket.valid:
|
||||
raise cloudpickle.loads(ticket.error)
|
||||
return ticket.return_id
|
||||
|
||||
def call_release(self, id: bytes) -> None:
|
||||
|
@ -171,10 +182,9 @@ class Worker:
|
|||
self.reference_count[id] += 1
|
||||
|
||||
def close(self):
|
||||
self.data_client.close()
|
||||
self.data_client.close(close_channel=True)
|
||||
self.server = None
|
||||
if self.channel:
|
||||
self.channel.close()
|
||||
self.channel = None
|
||||
|
||||
def terminate_actor(self, actor: ClientActorHandle,
|
||||
|
|
|
@ -443,3 +443,7 @@ def format_web_url(url):
|
|||
|
||||
def new_scheduler_enabled():
    """Return True unless RAY_ENABLE_NEW_SCHEDULER is set to something other than "1"."""
    flag = os.getenv("RAY_ENABLE_NEW_SCHEDULER", "1")
    return flag == "1"
|
||||
|
||||
|
||||
def client_test_enabled() -> bool:
    """Return True when RAY_TEST_CLIENT_MODE is set to exactly "1"."""
    mode = os.getenv("RAY_TEST_CLIENT_MODE")
    return mode == "1"
|
||||
|
|
|
@ -153,3 +153,20 @@ py_test(
|
|||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
|
||||
py_test_module_list(
|
||||
files = [
|
||||
"test_actor.py",
|
||||
"test_basic.py",
|
||||
"test_basic_2.py",
|
||||
],
|
||||
size = "medium",
|
||||
extra_srcs = SRCS,
|
||||
name_suffix = "_client_mode",
|
||||
# TODO(barakmich): py_test will support env in Bazel 4.0.0...
|
||||
# Until then, we can use tags.
|
||||
#env = {"RAY_TEST_CLIENT_MODE": "1"},
|
||||
tags = ["exclusive", "client_tests"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
|
20
python/ray/tests/client_test_utils.py
Normal file
20
python/ray/tests/client_test_utils.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import asyncio
|
||||
|
||||
|
||||
def create_remote_signal_actor(ray):
    """Build a ``SignalActor`` class wrapped with ``ray.remote``.

    Args:
        ray: the API object whose ``remote`` is applied to the class.

    Returns:
        Whatever ``ray.remote`` returns for the ``SignalActor`` class.
    """

    # TODO(barakmich): num_cpus=0
    class SignalActor:
        def __init__(self):
            self.ready_event = asyncio.Event()

        def send(self, clear=False):
            # Set the event; optionally reset it straight away.
            event = self.ready_event
            event.set()
            if clear:
                event.clear()

        async def wait(self, should_wait=True):
            # Return immediately unless asked to block on the event.
            if not should_wait:
                return
            await self.ready_event.wait()

    return ray.remote(SignalActor)
|
|
@ -9,12 +9,18 @@ import subprocess
|
|||
import ray
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray.test_utils import init_error_pubsub
|
||||
from ray.test_utils import client_test_enabled
|
||||
import ray.experimental.client as ray_client
|
||||
|
||||
|
||||
@pytest.fixture
def shutdown_only():
    """Fixture that yields ``None`` and shuts ray down after the test.

    In client test mode the client is disconnected, the test server is
    stopped and the client API is reset before ``ray.shutdown()`` runs.
    """
    yield None
    # The code after the yield will run as teardown code.
    try:
        if client_test_enabled():
            ray_client.ray.disconnect()
            ray_client._stop_test_server(1)
            ray_client.reset_api()
    finally:
        # Shut ray down even if the client-mode teardown raised, so a
        # failed teardown does not leak a running instance into later tests.
        ray.shutdown()
|
||||
|
||||
|
||||
|
@ -43,9 +49,17 @@ def _ray_start(**kwargs):
|
|||
init_kwargs = get_default_fixture_ray_kwargs()
|
||||
init_kwargs.update(kwargs)
|
||||
# Start the Ray processes.
|
||||
address_info = ray.init(**init_kwargs)
|
||||
if client_test_enabled():
|
||||
address_info = ray_client.ray.init(**init_kwargs)
|
||||
else:
|
||||
address_info = ray.init(**init_kwargs)
|
||||
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
if client_test_enabled():
|
||||
ray_client.ray.disconnect()
|
||||
ray_client._stop_test_server(1)
|
||||
ray_client.reset_api()
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
|
@ -130,9 +144,16 @@ def _ray_start_cluster(**kwargs):
|
|||
# We assume driver will connect to the head (first node),
|
||||
# so ray init will be invoked if do_init is true
|
||||
if len(remote_nodes) == 1 and do_init:
|
||||
ray.init(address=cluster.address)
|
||||
if client_test_enabled():
|
||||
ray_client.ray.init(address=cluster.address)
|
||||
else:
|
||||
ray.init(address=cluster.address)
|
||||
yield cluster
|
||||
# The code after the yield will run as teardown code.
|
||||
if client_test_enabled():
|
||||
ray_client.ray.disconnect()
|
||||
ray_client._stop_test_server(1)
|
||||
ray_client.reset_api()
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
|
|
|
@ -11,15 +11,21 @@ import sys
|
|||
import tempfile
|
||||
import datetime
|
||||
|
||||
import ray
|
||||
import ray.test_utils
|
||||
import ray.cluster_utils
|
||||
from ray.test_utils import client_test_enabled
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.test_utils import wait_for_pid_to_exit
|
||||
from ray.tests.client_test_utils import create_remote_signal_actor
|
||||
|
||||
if client_test_enabled():
|
||||
from ray.experimental.client import ray
|
||||
else:
|
||||
import ray
|
||||
# NOTE: We have to import setproctitle after ray because we bundle setproctitle
|
||||
# with ray.
|
||||
import setproctitle
|
||||
import setproctitle # noqa
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="test setup order")
|
||||
def test_caching_actors(shutdown_only):
|
||||
# Test defining actors before ray.init() has been called.
|
||||
|
||||
|
@ -238,6 +244,7 @@ def test_actor_import_counter(ray_start_10_cpus):
|
|||
assert ray.get(g.remote()) == num_remote_functions - 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_actor_method_metadata_cache(ray_start_regular):
|
||||
class Actor(object):
|
||||
pass
|
||||
|
@ -257,6 +264,7 @@ def test_actor_method_metadata_cache(ray_start_regular):
|
|||
assert [id(x) for x in list(cache.items())[0]] == cached_data_id
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_actor_class_name(ray_start_regular):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
|
@ -556,6 +564,7 @@ def test_actor_static_attributes(ray_start_regular_shared):
|
|||
assert ray.get(t.g.remote()) == 3
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_decorator_args(ray_start_regular_shared):
|
||||
# This is an invalid way of using the actor decorator.
|
||||
with pytest.raises(Exception):
|
||||
|
@ -618,6 +627,8 @@ def test_random_id_generation(ray_start_regular_shared):
|
|||
assert f1._actor_id != f2._actor_id
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(), reason="differing inheritence structure")
|
||||
def test_actor_inheritance(ray_start_regular_shared):
|
||||
class NonActorBase:
|
||||
def __init__(self):
|
||||
|
@ -630,8 +641,7 @@ def test_actor_inheritance(ray_start_regular_shared):
|
|||
pass
|
||||
|
||||
# Test that you can't instantiate an actor class directly.
|
||||
with pytest.raises(
|
||||
Exception, match="Actors cannot be instantiated directly."):
|
||||
with pytest.raises(Exception, match="cannot be instantiated directly"):
|
||||
ActorBase()
|
||||
|
||||
# Test that you can't inherit from an actor class.
|
||||
|
@ -645,6 +655,7 @@ def test_actor_inheritance(ray_start_regular_shared):
|
|||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_multiple_return_values(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
|
@ -678,6 +689,7 @@ def test_multiple_return_values(ray_start_regular_shared):
|
|||
assert ray.get([id3a, id3b, id3c]) == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_options_num_returns(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
|
@ -693,6 +705,7 @@ def test_options_num_returns(ray_start_regular_shared):
|
|||
assert ray.get([obj1, obj2]) == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_options_name(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
|
@ -734,13 +747,13 @@ def test_actor_deletion(ray_start_regular_shared):
|
|||
a = Actor.remote()
|
||||
pid = ray.get(a.getpid.remote())
|
||||
a = None
|
||||
ray.test_utils.wait_for_pid_to_exit(pid)
|
||||
wait_for_pid_to_exit(pid)
|
||||
|
||||
actors = [Actor.remote() for _ in range(10)]
|
||||
pids = ray.get([a.getpid.remote() for a in actors])
|
||||
a = None
|
||||
actors = None
|
||||
[ray.test_utils.wait_for_pid_to_exit(pid) for pid in pids]
|
||||
[wait_for_pid_to_exit(pid) for pid in pids]
|
||||
|
||||
|
||||
def test_actor_method_deletion(ray_start_regular_shared):
|
||||
|
@ -769,7 +782,8 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
|
|||
ray.get(signal.wait.remote())
|
||||
return ray.get(actor.method.remote())
|
||||
|
||||
signal = ray.test_utils.SignalActor.remote()
|
||||
SignalActor = create_remote_signal_actor(ray)
|
||||
signal = SignalActor.remote()
|
||||
a = Actor.remote()
|
||||
pid = ray.get(a.getpid.remote())
|
||||
# Pass the handle to another task that cannot run yet.
|
||||
|
@ -780,7 +794,7 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
|
|||
# Once the task finishes, the actor process should get killed.
|
||||
ray.get(signal.send.remote())
|
||||
assert ray.get(x_id) == 1
|
||||
ray.test_utils.wait_for_pid_to_exit(pid)
|
||||
wait_for_pid_to_exit(pid)
|
||||
|
||||
|
||||
def test_multiple_actors(ray_start_regular_shared):
|
||||
|
@ -921,7 +935,7 @@ def test_atexit_handler(ray_start_regular_shared, exit_condition):
|
|||
if exit_condition == "ray.kill":
|
||||
assert not check_file_written()
|
||||
else:
|
||||
ray.test_utils.wait_for_condition(check_file_written)
|
||||
wait_for_condition(check_file_written)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -8,14 +8,23 @@ import time
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
import ray.cluster_utils
|
||||
import ray.test_utils
|
||||
from ray.test_utils import (
|
||||
client_test_enabled,
|
||||
dicts_equal,
|
||||
wait_for_pid_to_exit,
|
||||
)
|
||||
|
||||
if client_test_enabled():
|
||||
from ray.experimental.client import ray
|
||||
else:
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/6662
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_ignore_http_proxy(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
os.environ["http_proxy"] = "http://example.com"
|
||||
|
@ -29,6 +38,7 @@ def test_ignore_http_proxy(shutdown_only):
|
|||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/7263
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_grpc_message_size(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
|
@ -45,12 +55,14 @@ def test_grpc_message_size(shutdown_only):
|
|||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/7287
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_omp_threads_set(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
# Should have been auto set by ray init.
|
||||
assert os.environ["OMP_NUM_THREADS"] == "1"
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_submit_api(shutdown_only):
|
||||
ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1})
|
||||
|
||||
|
@ -109,6 +121,7 @@ def test_submit_api(shutdown_only):
|
|||
assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_invalid_arguments(shutdown_only):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
|
@ -163,6 +176,7 @@ def test_invalid_arguments(shutdown_only):
|
|||
x = 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_many_fractional_resources(shutdown_only):
|
||||
ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2})
|
||||
|
||||
|
@ -178,7 +192,7 @@ def test_many_fractional_resources(shutdown_only):
|
|||
}
|
||||
if block:
|
||||
ray.get(g.remote())
|
||||
return ray.test_utils.dicts_equal(true_resources, accepted_resources)
|
||||
return dicts_equal(true_resources, accepted_resources)
|
||||
|
||||
# Check that the resource are assigned correctly.
|
||||
result_ids = []
|
||||
|
@ -230,6 +244,7 @@ def test_many_fractional_resources(shutdown_only):
|
|||
assert False, "Did not get correct available resources."
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_background_tasks_with_max_calls(shutdown_only):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
|
@ -257,7 +272,7 @@ def test_background_tasks_with_max_calls(shutdown_only):
|
|||
pid, g_id = nested.pop(0)
|
||||
ray.get(g_id)
|
||||
del g_id
|
||||
ray.test_utils.wait_for_pid_to_exit(pid)
|
||||
wait_for_pid_to_exit(pid)
|
||||
|
||||
|
||||
def test_fair_queueing(shutdown_only):
|
||||
|
@ -327,6 +342,7 @@ def test_wait_timing(shutdown_only):
|
|||
assert len(not_ready) == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal _raylet")
|
||||
def test_function_descriptor():
|
||||
python_descriptor = ray._raylet.PythonFunctionDescriptor(
|
||||
"module_name", "function_name", "class_name", "function_hash")
|
||||
|
@ -344,6 +360,7 @@ def test_function_descriptor():
|
|||
assert d.get(python_descriptor2) == 123
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_ray_options(shutdown_only):
|
||||
@ray.remote(
|
||||
num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1})
|
||||
|
@ -371,6 +388,7 @@ def test_ray_options(shutdown_only):
|
|||
assert without_options != with_options
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_cluster_head", [{
|
||||
"num_cpus": 0,
|
||||
|
@ -438,8 +456,11 @@ def test_nested_functions(ray_start_shared_local_modes):
|
|||
assert ray.get(factorial.remote(4)) == 24
|
||||
assert ray.get(factorial.remote(5)) == 120
|
||||
|
||||
# Test remote functions that recursively call each other.
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(), reason="mutual recursion is a known issue")
|
||||
def test_mutually_recursive_functions(ray_start_shared_local_modes):
|
||||
# Test remote functions that recursively call each other.
|
||||
@ray.remote
|
||||
def factorial_even(n):
|
||||
assert n % 2 == 0
|
||||
|
@ -710,6 +731,7 @@ def test_args_stars_after(ray_start_shared_local_modes):
|
|||
ray.get(remote_test_function.remote(local_method, actor_method))
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_object_id_backward_compatibility(ray_start_shared_local_modes):
|
||||
# We've renamed Python's `ObjectID` to `ObjectRef`, and added a type
|
||||
# alias for backward compatibility.
|
||||
|
|
|
@ -9,10 +9,16 @@ import pytest
|
|||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import ray
|
||||
import ray.cluster_utils
|
||||
import ray.test_utils
|
||||
from ray.test_utils import client_test_enabled
|
||||
from ray.tests.client_test_utils import create_remote_signal_actor
|
||||
from ray.exceptions import GetTimeoutError
|
||||
from ray.exceptions import RayTaskError
|
||||
|
||||
if client_test_enabled():
|
||||
from ray.experimental.client import ray
|
||||
else:
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -25,6 +31,8 @@ logger = logging.getLogger(__name__)
|
|||
}],
|
||||
indirect=True)
|
||||
def test_variable_number_of_args(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
@ray.remote
|
||||
def varargs_fct1(*a):
|
||||
return " ".join(map(str, a))
|
||||
|
@ -33,8 +41,6 @@ def test_variable_number_of_args(shutdown_only):
|
|||
def varargs_fct2(a, *b):
|
||||
return " ".join(map(str, b))
|
||||
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
x = varargs_fct1.remote(0, 1, 2)
|
||||
assert ray.get(x) == "0 1 2"
|
||||
x = varargs_fct2.remote(0, 1, 2)
|
||||
|
@ -160,7 +166,7 @@ def test_redefining_remote_functions(shutdown_only):
|
|||
def g():
|
||||
return nonexistent()
|
||||
|
||||
with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"):
|
||||
with pytest.raises(RayTaskError, match="nonexistent"):
|
||||
ray.get(g.remote())
|
||||
|
||||
def nonexistent():
|
||||
|
@ -187,6 +193,7 @@ def test_redefining_remote_functions(shutdown_only):
|
|||
assert ray.get(ray.get(h.remote(i))) == i
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_call_matrix(shutdown_only):
|
||||
ray.init(object_store_memory=1000 * 1024 * 1024)
|
||||
|
||||
|
@ -312,6 +319,7 @@ def test_actor_pass_by_ref_order_optimization(shutdown_only):
|
|||
assert delta < 10, "did not skip slow value"
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_cluster", [{
|
||||
"num_cpus": 1,
|
||||
|
@ -332,6 +340,7 @@ def test_call_chain(ray_start_cluster):
|
|||
assert ray.get(x) == 100
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_system_config_when_connecting(ray_start_cluster):
|
||||
config = {"object_pinning_enabled": 0, "object_timeout_milliseconds": 200}
|
||||
cluster = ray.cluster_utils.Cluster()
|
||||
|
@ -368,7 +377,8 @@ def test_get_multiple(ray_start_regular_shared):
|
|||
|
||||
|
||||
def test_get_with_timeout(ray_start_regular_shared):
|
||||
signal = ray.test_utils.SignalActor.remote()
|
||||
SignalActor = create_remote_signal_actor(ray)
|
||||
signal = SignalActor.remote()
|
||||
|
||||
# Check that get() returns early if object is ready.
|
||||
start = time.time()
|
||||
|
@ -438,6 +448,7 @@ def test_inline_arg_memory_corruption(ray_start_regular_shared):
|
|||
ray.get(a.add.remote(f.remote()))
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_skip_plasma(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Actor:
|
||||
|
@ -454,6 +465,8 @@ def test_skip_plasma(ray_start_regular_shared):
|
|||
assert ray.get(obj_ref) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(), reason="internal api and message size")
|
||||
def test_actor_large_objects(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Actor:
|
||||
|
@ -524,6 +537,7 @@ def test_actor_recursive(ray_start_regular_shared):
|
|||
assert result == [x * 2 for x in range(100)]
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="remote args")
|
||||
def test_actor_concurrent(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Batcher:
|
||||
|
@ -626,6 +640,7 @@ def test_duplicate_args(ray_start_regular_shared):
|
|||
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_get_correct_node_ip():
|
||||
with patch("ray.worker") as worker_mock:
|
||||
node_mock = MagicMock()
|
||||
|
|
|
@ -81,11 +81,11 @@ def test_wait(ray_start_regular_shared):
|
|||
with pytest.raises(Exception):
|
||||
# Reference not in the object store.
|
||||
ray.wait([ClientObjectRef("blabla")])
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(TypeError):
|
||||
ray.wait("blabla")
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(TypeError):
|
||||
ray.wait(ClientObjectRef("blabla"))
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(TypeError):
|
||||
ray.wait(["blabla"])
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
from ray.tests.test_experimental_client import ray_start_client_server
|
||||
from ray.tests.client_test_utils import create_remote_signal_actor
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.exceptions import TaskCancelledError
|
||||
from ray.exceptions import RayTaskError
|
||||
|
@ -45,21 +45,7 @@ def test_kill_actor_immediately_after_creation(ray_start_regular):
|
|||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_cancel_chain(ray_start_regular, use_force):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class SignalActor:
|
||||
def __init__(self):
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
def send(self, clear=False):
|
||||
self.ready_event.set()
|
||||
if clear:
|
||||
self.ready_event.clear()
|
||||
|
||||
async def wait(self, should_wait=True):
|
||||
if should_wait:
|
||||
await self.ready_event.wait()
|
||||
|
||||
SignalActor = create_remote_signal_actor(ray)
|
||||
signaler = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
|
|
|
@ -50,16 +50,22 @@ message ClientTask {
|
|||
string name = 2;
|
||||
// A reference to the payload.
|
||||
bytes payload_id = 3;
|
||||
// The parameters to pass to this call.
|
||||
// Positional parameters to pass to this call.
|
||||
repeated Arg args = 4;
|
||||
// Keyword parameters to pass to this call.
|
||||
map<string, Arg> kwargs = 5;
|
||||
// The ID of the client namespace associated with the Datapath stream making this
|
||||
// request.
|
||||
string client_id = 5;
|
||||
string client_id = 6;
|
||||
}
|
||||
|
||||
message ClientTaskTicket {
|
||||
// Was the task successful?
|
||||
bool valid = 1;
|
||||
// A reference to the returned value from the execution.
|
||||
bytes return_id = 1;
|
||||
bytes return_id = 2;
|
||||
// If unsuccessful, an encoding of the error.
|
||||
bytes error = 3;
|
||||
}
|
||||
|
||||
// Delivers data to the server
|
||||
|
|
Loading…
Add table
Reference in a new issue