[ray client] ray wait() implementation (#12072)

This commit is contained in:
Ameer Haj Ali 2020-11-18 21:33:57 +02:00 committed by GitHub
parent 2b60c5774b
commit eef624750c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 163 additions and 1 deletions

View file

@ -30,6 +30,12 @@ def put(*args, **kwargs):
return _client_api.put(*args, **kwargs) return _client_api.put(*args, **kwargs)
def wait(*args, **kwargs):
global _client_api
check_client_api()
return _client_api.wait(*args, **kwargs)
def remote(*args, **kwargs): def remote(*args, **kwargs):
global _client_api global _client_api
check_client_api() check_client_api()

View file

@ -22,6 +22,10 @@ class APIImpl(ABC):
def put(self, *args, **kwargs): def put(self, *args, **kwargs):
pass pass
@abstractmethod
def wait(self, *args, **kwargs):
pass
@abstractmethod @abstractmethod
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
pass pass
@ -45,6 +49,9 @@ class ClientAPI(APIImpl):
def put(self, *args, **kwargs): def put(self, *args, **kwargs):
return self.worker.put(*args, **kwargs) return self.worker.put(*args, **kwargs)
def wait(self, *args, **kwargs):
return self.worker.wait(*args, **kwargs)
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
return self.worker.remote(*args, **kwargs) return self.worker.remote(*args, **kwargs)

View file

@ -43,3 +43,18 @@ print(ray.get(ref3))
ref4 = fact.remote(5) ref4 = fact.remote(5)
# `120` # `120`
print(ray.get(ref4)) print(ray.get(ref4))
ref5 = fact.remote(10)
print([ref2, ref3, ref4, ref5])
# should return ref2, ref3, ref4
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
print(res)
assert [ref2, ref3, ref4] == res[0]
assert [ref5] == res[1]
# should return ref2, ref3, ref4, ref5
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
print(res)
assert [ref2, ref3, ref4, ref5] == res[0]
assert [] == res[1]

View file

@ -11,6 +11,9 @@ class ClientObjectRef:
def __repr__(self): def __repr__(self):
return "ClientObjectRef(%s)" % self.id.hex() return "ClientObjectRef(%s)" % self.id.hex()
def __eq__(self, other):
return self.id == other.id
class ClientRemoteFunc: class ClientRemoteFunc:
def __init__(self, f): def __init__(self, f):

View file

@ -19,6 +19,9 @@ class CoreRayAPI(APIImpl):
def put(self, *args, **kwargs): def put(self, *args, **kwargs):
return ray.put(*args, **kwargs) return ray.put(*args, **kwargs)
def wait(self, *args, **kwargs):
return ray.wait(*args, **kwargs)
def remote(self, *args, **kwargs): def remote(self, *args, **kwargs):
return ray.remote(*args, **kwargs) return ray.remote(*args, **kwargs)

View file

@ -35,6 +35,37 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
logger.info("put: %s" % objectref) logger.info("put: %s" % objectref)
return ray_client_pb2.PutResponse(id=objectref.binary()) return ray_client_pb2.PutResponse(id=objectref.binary())
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
object_refs = [cloudpickle.loads(o) for o in request.object_refs]
num_returns = request.num_returns
timeout = request.timeout
object_refs_ids = []
for object_ref in object_refs:
if object_ref.id not in self.object_refs:
return ray_client_pb2.WaitResponse(valid=False)
object_refs_ids.append(self.object_refs[object_ref.id])
try:
ready_object_refs, remaining_object_refs = ray.wait(
object_refs_ids,
num_returns=num_returns,
timeout=timeout if timeout != -1 else None)
except Exception:
# TODO(ameer): improve exception messages.
return ray_client_pb2.WaitResponse(valid=False)
logger.info("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
ready_object_ids = [
ready_object_ref.binary() for ready_object_ref in ready_object_refs
]
remaining_object_ids = [
remaining_object_ref.binary()
for remaining_object_ref in remaining_object_refs
]
return ray_client_pb2.WaitResponse(
valid=True,
ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None): def Schedule(self, task, context=None):
logger.info("schedule: %s" % task) logger.info("schedule: %s" % task)
if task.payload_id not in self.function_refs: if task.payload_id not in self.function_refs:

View file

@ -1,5 +1,12 @@
from ray import cloudpickle """This file includes the Worker class which sits on the client side.
It implements the Ray API functions that are forwarded through grpc calls
to the server.
"""
from typing import List
import grpc import grpc
from ray import cloudpickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg from ray.experimental.client.common import convert_to_arg
@ -59,6 +66,36 @@ class Worker:
resp = self.server.PutObject(req) resp = self.server.PutObject(req)
return ClientObjectRef(resp.id) return ClientObjectRef(resp.id)
def wait(self,
object_refs: List[ClientObjectRef],
*,
num_returns: int = 1,
timeout: float = None
) -> (List[ClientObjectRef], List[ClientObjectRef]):
assert isinstance(object_refs, list)
for ref in object_refs:
assert isinstance(ref, ClientObjectRef)
data = {
"object_refs": [
cloudpickle.dumps(object_ref) for object_ref in object_refs
],
"num_returns": num_returns,
"timeout": timeout if timeout else -1
}
req = ray_client_pb2.WaitRequest(**data)
resp = self.server.WaitObject(req)
if not resp.valid:
# TODO(ameer): improve error/exceptions messages.
raise Exception("Client Wait request failed. Reference invalid?")
client_ready_object_ids = [
ClientObjectRef(id) for id in resp.ready_object_ids
]
client_remaining_object_ids = [
ClientObjectRef(id) for id in resp.remaining_object_ids
]
return (client_ready_object_ids, client_remaining_object_ids)
def remote(self, func): def remote(self, func):
return ClientRemoteFunc(func) return ClientRemoteFunc(func)

View file

@ -1,6 +1,7 @@
import pytest import pytest
import ray.experimental.client.server as ray_client_server import ray.experimental.client.server as ray_client_server
import ray.experimental.client as ray import ray.experimental.client as ray
from ray.experimental.client.common import ClientObjectRef
def test_put_get(ray_start_regular_shared): def test_put_get(ray_start_regular_shared):
@ -16,6 +17,39 @@ def test_put_get(ray_start_regular_shared):
server.stop(0) server.stop(0)
def test_wait(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051")
ray.connect("localhost:50051")
objectref = ray.put("hello world")
ready, remaining = ray.wait([objectref])
assert remaining == []
retval = ray.get(ready[0])
assert retval == "hello world"
objectref2 = ray.put(5)
ready, remaining = ray.wait([objectref, objectref2])
assert (ready, remaining) == ([objectref], [objectref2]) or \
(ready, remaining) == ([objectref2], [objectref])
ready_retval = ray.get(ready[0])
remaining_retval = ray.get(remaining[0])
assert (ready_retval, remaining_retval) == ("hello world", 5) \
or (ready_retval, remaining_retval) == (5, "hello world")
with pytest.raises(Exception):
# Reference not in the object store.
ray.wait([ClientObjectRef("blabla")])
with pytest.raises(AssertionError):
ray.wait("blabla")
with pytest.raises(AssertionError):
ray.wait(ClientObjectRef("blabla"))
with pytest.raises(AssertionError):
ray.wait(["blabla"])
ray.disconnect()
server.stop(0)
def test_remote_functions(ray_start_regular_shared): def test_remote_functions(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50051") server = ray_client_server.serve("localhost:50051")
@ -45,6 +79,19 @@ def test_remote_functions(ray_start_regular_shared):
ref4 = fact.remote(5) ref4 = fact.remote(5)
assert ray.get(ref4) == 120 assert ray.get(ref4) == 120
# Test ray.wait()
ref5 = fact.remote(10)
# should return ref2, ref3, ref4
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
assert [ref2, ref3, ref4] == res[0]
assert [ref5] == res[1]
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
# should return ref2, ref3, ref4, ref5
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
assert [ref2, ref3, ref4, ref5] == res[0]
assert [] == res[1]
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120, 3628800]
ray.disconnect() ray.disconnect()
server.stop(0) server.stop(0)

View file

@ -56,12 +56,25 @@ message GetResponse {
bool valid = 1; bool valid = 1;
bytes data = 2; bytes data = 2;
} }
message WaitRequest {
repeated bytes object_refs = 1;
int64 num_returns = 2;
double timeout = 3;
}
message WaitResponse {
bool valid = 1;
repeated bytes ready_object_ids = 2;
repeated bytes remaining_object_ids = 3;
}
service RayletDriver { service RayletDriver {
rpc GetObject(GetRequest) returns (GetResponse) { rpc GetObject(GetRequest) returns (GetResponse) {
} }
rpc PutObject(PutRequest) returns (PutResponse) { rpc PutObject(PutRequest) returns (PutResponse) {
} }
rpc WaitObject(WaitRequest) returns (WaitResponse) {
}
rpc Schedule(ClientTask) returns (ClientTaskTicket) { rpc Schedule(ClientTask) returns (ClientTaskTicket) {
} }
} }