mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[ray client] ray wait() implementation (#12072)
This commit is contained in:
parent
2b60c5774b
commit
eef624750c
9 changed files with 163 additions and 1 deletions
|
@ -30,6 +30,12 @@ def put(*args, **kwargs):
|
||||||
return _client_api.put(*args, **kwargs)
|
return _client_api.put(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def wait(*args, **kwargs):
|
||||||
|
global _client_api
|
||||||
|
check_client_api()
|
||||||
|
return _client_api.wait(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def remote(*args, **kwargs):
|
def remote(*args, **kwargs):
|
||||||
global _client_api
|
global _client_api
|
||||||
check_client_api()
|
check_client_api()
|
||||||
|
|
|
@ -22,6 +22,10 @@ class APIImpl(ABC):
|
||||||
def put(self, *args, **kwargs):
|
def put(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def wait(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
@ -45,6 +49,9 @@ class ClientAPI(APIImpl):
|
||||||
def put(self, *args, **kwargs):
|
def put(self, *args, **kwargs):
|
||||||
return self.worker.put(*args, **kwargs)
|
return self.worker.put(*args, **kwargs)
|
||||||
|
|
||||||
|
def wait(self, *args, **kwargs):
|
||||||
|
return self.worker.wait(*args, **kwargs)
|
||||||
|
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
return self.worker.remote(*args, **kwargs)
|
return self.worker.remote(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -43,3 +43,18 @@ print(ray.get(ref3))
|
||||||
ref4 = fact.remote(5)
|
ref4 = fact.remote(5)
|
||||||
# `120`
|
# `120`
|
||||||
print(ray.get(ref4))
|
print(ray.get(ref4))
|
||||||
|
|
||||||
|
ref5 = fact.remote(10)
|
||||||
|
|
||||||
|
print([ref2, ref3, ref4, ref5])
|
||||||
|
# should return ref2, ref3, ref4
|
||||||
|
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
||||||
|
print(res)
|
||||||
|
assert [ref2, ref3, ref4] == res[0]
|
||||||
|
assert [ref5] == res[1]
|
||||||
|
|
||||||
|
# should return ref2, ref3, ref4, ref5
|
||||||
|
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
||||||
|
print(res)
|
||||||
|
assert [ref2, ref3, ref4, ref5] == res[0]
|
||||||
|
assert [] == res[1]
|
||||||
|
|
|
@ -11,6 +11,9 @@ class ClientObjectRef:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "ClientObjectRef(%s)" % self.id.hex()
|
return "ClientObjectRef(%s)" % self.id.hex()
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.id == other.id
|
||||||
|
|
||||||
|
|
||||||
class ClientRemoteFunc:
|
class ClientRemoteFunc:
|
||||||
def __init__(self, f):
|
def __init__(self, f):
|
||||||
|
|
|
@ -19,6 +19,9 @@ class CoreRayAPI(APIImpl):
|
||||||
def put(self, *args, **kwargs):
|
def put(self, *args, **kwargs):
|
||||||
return ray.put(*args, **kwargs)
|
return ray.put(*args, **kwargs)
|
||||||
|
|
||||||
|
def wait(self, *args, **kwargs):
|
||||||
|
return ray.wait(*args, **kwargs)
|
||||||
|
|
||||||
def remote(self, *args, **kwargs):
|
def remote(self, *args, **kwargs):
|
||||||
return ray.remote(*args, **kwargs)
|
return ray.remote(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,37 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||||
logger.info("put: %s" % objectref)
|
logger.info("put: %s" % objectref)
|
||||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||||
|
|
||||||
|
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
||||||
|
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
|
||||||
|
num_returns = request.num_returns
|
||||||
|
timeout = request.timeout
|
||||||
|
object_refs_ids = []
|
||||||
|
for object_ref in object_refs:
|
||||||
|
if object_ref.id not in self.object_refs:
|
||||||
|
return ray_client_pb2.WaitResponse(valid=False)
|
||||||
|
object_refs_ids.append(self.object_refs[object_ref.id])
|
||||||
|
try:
|
||||||
|
ready_object_refs, remaining_object_refs = ray.wait(
|
||||||
|
object_refs_ids,
|
||||||
|
num_returns=num_returns,
|
||||||
|
timeout=timeout if timeout != -1 else None)
|
||||||
|
except Exception:
|
||||||
|
# TODO(ameer): improve exception messages.
|
||||||
|
return ray_client_pb2.WaitResponse(valid=False)
|
||||||
|
logger.info("wait: %s %s" % (str(ready_object_refs),
|
||||||
|
str(remaining_object_refs)))
|
||||||
|
ready_object_ids = [
|
||||||
|
ready_object_ref.binary() for ready_object_ref in ready_object_refs
|
||||||
|
]
|
||||||
|
remaining_object_ids = [
|
||||||
|
remaining_object_ref.binary()
|
||||||
|
for remaining_object_ref in remaining_object_refs
|
||||||
|
]
|
||||||
|
return ray_client_pb2.WaitResponse(
|
||||||
|
valid=True,
|
||||||
|
ready_object_ids=ready_object_ids,
|
||||||
|
remaining_object_ids=remaining_object_ids)
|
||||||
|
|
||||||
def Schedule(self, task, context=None):
|
def Schedule(self, task, context=None):
|
||||||
logger.info("schedule: %s" % task)
|
logger.info("schedule: %s" % task)
|
||||||
if task.payload_id not in self.function_refs:
|
if task.payload_id not in self.function_refs:
|
||||||
|
|
|
@ -1,5 +1,12 @@
|
||||||
from ray import cloudpickle
|
"""This file includes the Worker class which sits on the client side.
|
||||||
|
It implements the Ray API functions that are forwarded through grpc calls
|
||||||
|
to the server.
|
||||||
|
"""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
|
from ray import cloudpickle
|
||||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
from ray.experimental.client.common import convert_to_arg
|
from ray.experimental.client.common import convert_to_arg
|
||||||
|
@ -59,6 +66,36 @@ class Worker:
|
||||||
resp = self.server.PutObject(req)
|
resp = self.server.PutObject(req)
|
||||||
return ClientObjectRef(resp.id)
|
return ClientObjectRef(resp.id)
|
||||||
|
|
||||||
|
def wait(self,
|
||||||
|
object_refs: List[ClientObjectRef],
|
||||||
|
*,
|
||||||
|
num_returns: int = 1,
|
||||||
|
timeout: float = None
|
||||||
|
) -> (List[ClientObjectRef], List[ClientObjectRef]):
|
||||||
|
assert isinstance(object_refs, list)
|
||||||
|
for ref in object_refs:
|
||||||
|
assert isinstance(ref, ClientObjectRef)
|
||||||
|
data = {
|
||||||
|
"object_refs": [
|
||||||
|
cloudpickle.dumps(object_ref) for object_ref in object_refs
|
||||||
|
],
|
||||||
|
"num_returns": num_returns,
|
||||||
|
"timeout": timeout if timeout else -1
|
||||||
|
}
|
||||||
|
req = ray_client_pb2.WaitRequest(**data)
|
||||||
|
resp = self.server.WaitObject(req)
|
||||||
|
if not resp.valid:
|
||||||
|
# TODO(ameer): improve error/exceptions messages.
|
||||||
|
raise Exception("Client Wait request failed. Reference invalid?")
|
||||||
|
client_ready_object_ids = [
|
||||||
|
ClientObjectRef(id) for id in resp.ready_object_ids
|
||||||
|
]
|
||||||
|
client_remaining_object_ids = [
|
||||||
|
ClientObjectRef(id) for id in resp.remaining_object_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
return (client_ready_object_ids, client_remaining_object_ids)
|
||||||
|
|
||||||
def remote(self, func):
|
def remote(self, func):
|
||||||
return ClientRemoteFunc(func)
|
return ClientRemoteFunc(func)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import ray.experimental.client.server as ray_client_server
|
import ray.experimental.client.server as ray_client_server
|
||||||
import ray.experimental.client as ray
|
import ray.experimental.client as ray
|
||||||
|
from ray.experimental.client.common import ClientObjectRef
|
||||||
|
|
||||||
|
|
||||||
def test_put_get(ray_start_regular_shared):
|
def test_put_get(ray_start_regular_shared):
|
||||||
|
@ -16,6 +17,39 @@ def test_put_get(ray_start_regular_shared):
|
||||||
server.stop(0)
|
server.stop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wait(ray_start_regular_shared):
|
||||||
|
server = ray_client_server.serve("localhost:50051")
|
||||||
|
ray.connect("localhost:50051")
|
||||||
|
|
||||||
|
objectref = ray.put("hello world")
|
||||||
|
ready, remaining = ray.wait([objectref])
|
||||||
|
assert remaining == []
|
||||||
|
retval = ray.get(ready[0])
|
||||||
|
assert retval == "hello world"
|
||||||
|
|
||||||
|
objectref2 = ray.put(5)
|
||||||
|
ready, remaining = ray.wait([objectref, objectref2])
|
||||||
|
assert (ready, remaining) == ([objectref], [objectref2]) or \
|
||||||
|
(ready, remaining) == ([objectref2], [objectref])
|
||||||
|
ready_retval = ray.get(ready[0])
|
||||||
|
remaining_retval = ray.get(remaining[0])
|
||||||
|
assert (ready_retval, remaining_retval) == ("hello world", 5) \
|
||||||
|
or (ready_retval, remaining_retval) == (5, "hello world")
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
# Reference not in the object store.
|
||||||
|
ray.wait([ClientObjectRef("blabla")])
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
ray.wait("blabla")
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
ray.wait(ClientObjectRef("blabla"))
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
ray.wait(["blabla"])
|
||||||
|
|
||||||
|
ray.disconnect()
|
||||||
|
server.stop(0)
|
||||||
|
|
||||||
|
|
||||||
def test_remote_functions(ray_start_regular_shared):
|
def test_remote_functions(ray_start_regular_shared):
|
||||||
server = ray_client_server.serve("localhost:50051")
|
server = ray_client_server.serve("localhost:50051")
|
||||||
|
|
||||||
|
@ -45,6 +79,19 @@ def test_remote_functions(ray_start_regular_shared):
|
||||||
ref4 = fact.remote(5)
|
ref4 = fact.remote(5)
|
||||||
assert ray.get(ref4) == 120
|
assert ray.get(ref4) == 120
|
||||||
|
|
||||||
|
# Test ray.wait()
|
||||||
|
ref5 = fact.remote(10)
|
||||||
|
# should return ref2, ref3, ref4
|
||||||
|
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
||||||
|
assert [ref2, ref3, ref4] == res[0]
|
||||||
|
assert [ref5] == res[1]
|
||||||
|
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
|
||||||
|
# should return ref2, ref3, ref4, ref5
|
||||||
|
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
||||||
|
assert [ref2, ref3, ref4, ref5] == res[0]
|
||||||
|
assert [] == res[1]
|
||||||
|
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120, 3628800]
|
||||||
|
|
||||||
ray.disconnect()
|
ray.disconnect()
|
||||||
server.stop(0)
|
server.stop(0)
|
||||||
|
|
||||||
|
|
|
@ -56,12 +56,25 @@ message GetResponse {
|
||||||
bool valid = 1;
|
bool valid = 1;
|
||||||
bytes data = 2;
|
bytes data = 2;
|
||||||
}
|
}
|
||||||
|
message WaitRequest {
|
||||||
|
repeated bytes object_refs = 1;
|
||||||
|
int64 num_returns = 2;
|
||||||
|
double timeout = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message WaitResponse {
|
||||||
|
bool valid = 1;
|
||||||
|
repeated bytes ready_object_ids = 2;
|
||||||
|
repeated bytes remaining_object_ids = 3;
|
||||||
|
}
|
||||||
|
|
||||||
service RayletDriver {
|
service RayletDriver {
|
||||||
rpc GetObject(GetRequest) returns (GetResponse) {
|
rpc GetObject(GetRequest) returns (GetResponse) {
|
||||||
}
|
}
|
||||||
rpc PutObject(PutRequest) returns (PutResponse) {
|
rpc PutObject(PutRequest) returns (PutResponse) {
|
||||||
}
|
}
|
||||||
|
rpc WaitObject(WaitRequest) returns (WaitResponse) {
|
||||||
|
}
|
||||||
rpc Schedule(ClientTask) returns (ClientTaskTicket) {
|
rpc Schedule(ClientTask) returns (ClientTaskTicket) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue