mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[ray client] ray wait() implementation (#12072)
This commit is contained in:
parent
2b60c5774b
commit
eef624750c
9 changed files with 163 additions and 1 deletions
|
@ -30,6 +30,12 @@ def put(*args, **kwargs):
|
|||
return _client_api.put(*args, **kwargs)
|
||||
|
||||
|
||||
def wait(*args, **kwargs):
|
||||
global _client_api
|
||||
check_client_api()
|
||||
return _client_api.wait(*args, **kwargs)
|
||||
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
global _client_api
|
||||
check_client_api()
|
||||
|
|
|
@ -22,6 +22,10 @@ class APIImpl(ABC):
|
|||
def put(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remote(self, *args, **kwargs):
|
||||
pass
|
||||
|
@ -45,6 +49,9 @@ class ClientAPI(APIImpl):
|
|||
def put(self, *args, **kwargs):
|
||||
return self.worker.put(*args, **kwargs)
|
||||
|
||||
def wait(self, *args, **kwargs):
|
||||
return self.worker.wait(*args, **kwargs)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return self.worker.remote(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -43,3 +43,18 @@ print(ray.get(ref3))
|
|||
ref4 = fact.remote(5)
|
||||
# `120`
|
||||
print(ray.get(ref4))
|
||||
|
||||
ref5 = fact.remote(10)
|
||||
|
||||
print([ref2, ref3, ref4, ref5])
|
||||
# should return ref2, ref3, ref4
|
||||
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
||||
print(res)
|
||||
assert [ref2, ref3, ref4] == res[0]
|
||||
assert [ref5] == res[1]
|
||||
|
||||
# should return ref2, ref3, ref4, ref5
|
||||
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
||||
print(res)
|
||||
assert [ref2, ref3, ref4, ref5] == res[0]
|
||||
assert [] == res[1]
|
||||
|
|
|
@ -11,6 +11,9 @@ class ClientObjectRef:
|
|||
def __repr__(self):
|
||||
return "ClientObjectRef(%s)" % self.id.hex()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
|
||||
|
||||
class ClientRemoteFunc:
|
||||
def __init__(self, f):
|
||||
|
|
|
@ -19,6 +19,9 @@ class CoreRayAPI(APIImpl):
|
|||
def put(self, *args, **kwargs):
|
||||
return ray.put(*args, **kwargs)
|
||||
|
||||
def wait(self, *args, **kwargs):
|
||||
return ray.wait(*args, **kwargs)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.remote(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -35,6 +35,37 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
|||
logger.info("put: %s" % objectref)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
|
||||
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
||||
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
|
||||
num_returns = request.num_returns
|
||||
timeout = request.timeout
|
||||
object_refs_ids = []
|
||||
for object_ref in object_refs:
|
||||
if object_ref.id not in self.object_refs:
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
object_refs_ids.append(self.object_refs[object_ref.id])
|
||||
try:
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs_ids,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None)
|
||||
except Exception:
|
||||
# TODO(ameer): improve exception messages.
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
logger.info("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
ready_object_ids = [
|
||||
ready_object_ref.binary() for ready_object_ref in ready_object_refs
|
||||
]
|
||||
remaining_object_ids = [
|
||||
remaining_object_ref.binary()
|
||||
for remaining_object_ref in remaining_object_refs
|
||||
]
|
||||
return ray_client_pb2.WaitResponse(
|
||||
valid=True,
|
||||
ready_object_ids=ready_object_ids,
|
||||
remaining_object_ids=remaining_object_ids)
|
||||
|
||||
def Schedule(self, task, context=None):
|
||||
logger.info("schedule: %s" % task)
|
||||
if task.payload_id not in self.function_refs:
|
||||
|
|
|
@ -1,5 +1,12 @@
|
|||
from ray import cloudpickle
|
||||
"""This file includes the Worker class which sits on the client side.
|
||||
It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import grpc
|
||||
|
||||
from ray import cloudpickle
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
|
@ -59,6 +66,36 @@ class Worker:
|
|||
resp = self.server.PutObject(req)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def wait(self,
|
||||
object_refs: List[ClientObjectRef],
|
||||
*,
|
||||
num_returns: int = 1,
|
||||
timeout: float = None
|
||||
) -> (List[ClientObjectRef], List[ClientObjectRef]):
|
||||
assert isinstance(object_refs, list)
|
||||
for ref in object_refs:
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
data = {
|
||||
"object_refs": [
|
||||
cloudpickle.dumps(object_ref) for object_ref in object_refs
|
||||
],
|
||||
"num_returns": num_returns,
|
||||
"timeout": timeout if timeout else -1
|
||||
}
|
||||
req = ray_client_pb2.WaitRequest(**data)
|
||||
resp = self.server.WaitObject(req)
|
||||
if not resp.valid:
|
||||
# TODO(ameer): improve error/exceptions messages.
|
||||
raise Exception("Client Wait request failed. Reference invalid?")
|
||||
client_ready_object_ids = [
|
||||
ClientObjectRef(id) for id in resp.ready_object_ids
|
||||
]
|
||||
client_remaining_object_ids = [
|
||||
ClientObjectRef(id) for id in resp.remaining_object_ids
|
||||
]
|
||||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
|
||||
def remote(self, func):
|
||||
return ClientRemoteFunc(func)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
import ray.experimental.client.server as ray_client_server
|
||||
import ray.experimental.client as ray
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
|
||||
|
||||
def test_put_get(ray_start_regular_shared):
|
||||
|
@ -16,6 +17,39 @@ def test_put_get(ray_start_regular_shared):
|
|||
server.stop(0)
|
||||
|
||||
|
||||
def test_wait(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
ready, remaining = ray.wait([objectref])
|
||||
assert remaining == []
|
||||
retval = ray.get(ready[0])
|
||||
assert retval == "hello world"
|
||||
|
||||
objectref2 = ray.put(5)
|
||||
ready, remaining = ray.wait([objectref, objectref2])
|
||||
assert (ready, remaining) == ([objectref], [objectref2]) or \
|
||||
(ready, remaining) == ([objectref2], [objectref])
|
||||
ready_retval = ray.get(ready[0])
|
||||
remaining_retval = ray.get(remaining[0])
|
||||
assert (ready_retval, remaining_retval) == ("hello world", 5) \
|
||||
or (ready_retval, remaining_retval) == (5, "hello world")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# Reference not in the object store.
|
||||
ray.wait([ClientObjectRef("blabla")])
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait("blabla")
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait(ClientObjectRef("blabla"))
|
||||
with pytest.raises(AssertionError):
|
||||
ray.wait(["blabla"])
|
||||
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_remote_functions(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50051")
|
||||
|
||||
|
@ -45,6 +79,19 @@ def test_remote_functions(ray_start_regular_shared):
|
|||
ref4 = fact.remote(5)
|
||||
assert ray.get(ref4) == 120
|
||||
|
||||
# Test ray.wait()
|
||||
ref5 = fact.remote(10)
|
||||
# should return ref2, ref3, ref4
|
||||
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
||||
assert [ref2, ref3, ref4] == res[0]
|
||||
assert [ref5] == res[1]
|
||||
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
|
||||
# should return ref2, ref3, ref4, ref5
|
||||
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
||||
assert [ref2, ref3, ref4, ref5] == res[0]
|
||||
assert [] == res[1]
|
||||
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120, 3628800]
|
||||
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
|
||||
|
|
|
@ -56,12 +56,25 @@ message GetResponse {
|
|||
bool valid = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
message WaitRequest {
|
||||
repeated bytes object_refs = 1;
|
||||
int64 num_returns = 2;
|
||||
double timeout = 3;
|
||||
}
|
||||
|
||||
message WaitResponse {
|
||||
bool valid = 1;
|
||||
repeated bytes ready_object_ids = 2;
|
||||
repeated bytes remaining_object_ids = 3;
|
||||
}
|
||||
|
||||
service RayletDriver {
|
||||
rpc GetObject(GetRequest) returns (GetResponse) {
|
||||
}
|
||||
rpc PutObject(PutRequest) returns (PutResponse) {
|
||||
}
|
||||
rpc WaitObject(WaitRequest) returns (WaitResponse) {
|
||||
}
|
||||
rpc Schedule(ClientTask) returns (ClientTaskTicket) {
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue