Revert "[Client] Make Client{ObjectRef,ActorRef} subclasses of their server-side counterparts (#16110)" (#16196)

This reverts commit f14f197d42.
This commit is contained in:
Alex Wu 2021-06-02 10:31:01 -07:00 committed by GitHub
parent 611da62739
commit 9942505b63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 95 additions and 153 deletions

View file

@ -99,7 +99,6 @@ sys.modules["pytorch_lightning"] = ChildClassMock()
sys.modules["xgboost"] = ChildClassMock()
sys.modules["xgboost.core"] = ChildClassMock()
sys.modules["xgboost.callback"] = ChildClassMock()
sys.modules["xgboost_ray"] = ChildClassMock()
class SimpleClass(object):

View file

@ -7,8 +7,6 @@ import logging
from typing import Callable, Any, Union
import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.util.client as client
logger = logging.getLogger(__name__)
@ -135,59 +133,3 @@ cdef class ObjectRef(BaseID):
core_worker = ray.worker.global_worker.core_worker
core_worker.set_get_async_callback(self, py_callback)
return self
cdef class ClientObjectRef(ObjectRef):
def __init__(self, id: bytes):
check_id(id)
self.data = CObjectID.FromBinary(<c_string>id)
client.ray.call_retain(id)
self.in_core_worker = False
def __dealloc__(self):
if client.ray.is_connected() and not self.data.IsNil():
client.ray.call_release(self.id)
@property
def id(self):
return self.binary()
def future(self) -> concurrent.futures.Future:
fut = concurrent.futures.Future()
def set_value(data: Any) -> None:
"""Schedules a callback to set the exception or result
in the Future."""
if isinstance(data, Exception):
fut.set_exception(data)
else:
fut.set_result(data)
self._on_completed(set_value)
# Prevent this object ref from being released.
fut.object_ref = self
return fut
def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
"""Register a callback that will be called after Object is ready.
If the ObjectRef is already ready, the callback will be called soon.
The callback should take the result as the only argument. The result
can be an exception object in case of task error.
"""
from ray.util.client.client_pickler import loads_from_server
def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None:
"""Converts from a GetResponse proto to a python object."""
obj = resp.get
data = None
if not obj.valid:
data = loads_from_server(resp.get.error)
else:
data = loads_from_server(resp.get.data)
py_callback(data)
client.ray._register_callback(self, deserialize_obj)

View file

@ -259,7 +259,6 @@ cdef class WorkerID(UniqueID):
return <CWorkerID>self.data
cdef class ActorID(BaseID):
def __init__(self, id):
check_id(id, CActorID.Size())
self.data = CActorID.FromBinary(<c_string>id)
@ -303,22 +302,6 @@ cdef class ActorID(BaseID):
return self.data.Hash()
cdef class ClientActorRef(ActorID):
def __init__(self, id: bytes):
check_id(id, CActorID.Size())
self.data = CActorID.FromBinary(<c_string>id)
client.ray.call_retain(id)
def __dealloc__(self):
if client.ray.is_connected() and not self.data.IsNil():
client.ray.call_release(self.id)
@property
def id(self):
return self.binary()
cdef class FunctionID(UniqueID):
def __init__(self, id):

View file

@ -1,79 +1,11 @@
import sys
import pytest
from ray.util.client import RayAPIStub
from ray.util.client.common import ClientActorRef, ClientObjectRef
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray.util.client.ray_client_helpers import (
ray_start_client_server_pair, ray_start_cluster_client_server_pair)
from ray.test_utils import wait_for_condition
import ray as real_ray
from ray.core.generated.gcs_pb2 import ActorTableData
from ray._raylet import ActorID, ObjectRef
@pytest.mark.skipif(sys.platform == "win32", reason="Timing out on Windows.")
def test_client_object_ref_basics(ray_start_regular):
with ray_start_client_server_pair() as pair:
ray, server = pair
ref = ray.put("Hello World")
# Make sure ClientObjectRef is a subclass of ObjectRef
assert isinstance(ref, ClientObjectRef)
assert isinstance(ref, ObjectRef)
# Invalid ref format.
with pytest.raises(Exception):
ClientObjectRef(b"\0")
# Test __eq__()
id = b"\0" * 28
assert ClientObjectRef(id) == ClientObjectRef(id)
assert ClientObjectRef(id) != ref
assert ClientObjectRef(id) != ObjectRef(id)
assert ClientObjectRef(id).__repr__() == f"ClientObjectRef({id.hex()})"
assert ClientObjectRef(id).binary() == id
assert ClientObjectRef(id).hex() == id.hex()
assert not ClientObjectRef(id).is_nil()
@pytest.mark.skipif(sys.platform == "win32", reason="Timing out on Windows.")
def test_client_actor_ref_basics(ray_start_regular):
with ray_start_client_server_pair() as pair:
ray, server = pair
@ray.remote
class Counter:
def __init__(self):
self.acc = 0
def inc(self):
self.acc += 1
def get(self):
return self.acc
counter = Counter.remote()
ref = counter.actor_ref
# Make sure ClientActorRef is a subclass of ActorID
assert isinstance(ref, ClientActorRef)
assert isinstance(ref, ActorID)
# Invalid ref format.
with pytest.raises(Exception):
ClientActorRef(b"\0")
# Test __eq__()
id = b"\0" * 16
assert ClientActorRef(id) == ClientActorRef(id)
assert ClientActorRef(id) != ref
assert ClientActorRef(id) != ActorID(id)
assert ClientActorRef(id).__repr__() == f"ClientActorRef({id.hex()})"
assert ClientActorRef(id).binary() == id
assert ClientActorRef(id).hex() == id.hex()
assert not ClientActorRef(id).is_nil()
def server_object_ref_count(server, n):

View file

@ -19,7 +19,6 @@ For many of the objects in the root `ray` namespace, there is an equivalent clie
These objects are client stand-ins for their server-side objects. For example:
```
ObjectRef <-> ClientObjectRef
ActorID <-> ClientActorRef
RemoteFunc <-> ClientRemoteFunc
```

View file

@ -1,9 +1,10 @@
import ray._raylet as raylet
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client import ray
from ray.util.client.options import validate_options
import asyncio
import concurrent.futures
from dataclasses import dataclass
import grpc
import os
@ -13,6 +14,7 @@ from ray.util.inspect import is_cython
import json
import threading
from typing import Any
from typing import Callable
from typing import List
from typing import Dict
from typing import Optional
@ -50,9 +52,87 @@ GRPC_OPTIONS = [
CLIENT_SERVER_MAX_THREADS = float(
os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))
# Aliases for compatibility.
ClientObjectRef = raylet.ClientObjectRef
ClientActorRef = raylet.ClientActorRef
class ClientBaseRef:
def __init__(self, id: bytes):
self.id = None
if not isinstance(id, bytes):
raise TypeError("ClientRefs must be created with bytes IDs")
self.id: bytes = id
ray.call_retain(id)
def binary(self):
return self.id
def hex(self):
return self.id.hex()
def __eq__(self, other):
return isinstance(other, ClientBaseRef) and self.id == other.id
def __repr__(self):
return "%s(%s)" % (
type(self).__name__,
self.id.hex(),
)
def __hash__(self):
return hash(self.id)
def __del__(self):
if ray.is_connected() and self.id is not None:
ray.call_release(self.id)
class ClientObjectRef(ClientBaseRef):
def __await__(self):
return self.as_future().__await__()
def as_future(self) -> asyncio.Future:
return asyncio.wrap_future(self.future())
def future(self) -> concurrent.futures.Future:
fut = concurrent.futures.Future()
def set_value(data: Any) -> None:
"""Schedules a callback to set the exception or result
in the Future."""
if isinstance(data, Exception):
fut.set_exception(data)
else:
fut.set_result(data)
self._on_completed(set_value)
# Prevent this object ref from being released.
fut.object_ref = self
return fut
def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
"""Register a callback that will be called after Object is ready.
If the ObjectRef is already ready, the callback will be called soon.
The callback should take the result as the only argument. The result
can be an exception object in case of task error.
"""
from ray.util.client.client_pickler import loads_from_server
def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None:
"""Converts from a GetResponse proto to a python object."""
obj = resp.get
data = None
if not obj.valid:
data = loads_from_server(resp.get.error)
else:
data = loads_from_server(resp.get.data)
py_callback(data)
ray._register_callback(self, deserialize_obj)
class ClientActorRef(ClientBaseRef):
pass
class ClientStub:

View file

@ -4,6 +4,7 @@ from operator import getitem
import uuid
import ray
from ray.util.client.common import ClientObjectRef
from dask.base import quote
from dask.core import get as get_sync
@ -46,7 +47,7 @@ def unpack_object_refs(*args):
object_refs_token = uuid.uuid4().hex
def _unpack(expr):
if isinstance(expr, ray.ObjectRef):
if isinstance(expr, (ray.ObjectRef, ClientObjectRef)):
token = expr.hex()
repack_dsk[token] = (getitem, object_refs_token, len(object_refs))
object_refs.append(expr)

View file

@ -5,6 +5,7 @@ from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
import ray
from ray.util.client.common import ClientObjectRef
from dask.core import istask, ishashable, _execute_task
from dask.system import CPU_COUNT
@ -369,7 +370,8 @@ def ray_get_unpack(object_refs):
if isinstance(object_refs, tuple):
object_refs = list(object_refs)
if isinstance(object_refs, list) and any(not isinstance(x, ray.ObjectRef)
if isinstance(object_refs, list) and any(
not isinstance(x, (ray.ObjectRef, ClientObjectRef))
for x in object_refs):
# We flatten the object references before calling ray.get(), since Dask
# loves to nest collections in nested tuples and Ray expects a flat

View file

@ -3,6 +3,7 @@ from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from typing import TYPE_CHECKING
import ray
from ray._raylet import ObjectRef
@ -12,6 +13,9 @@ from ray.ray_constants import (to_memory_units, MEMORY_RESOURCE_UNIT_BYTES)
from ray._private.client_mode_hook import client_mode_should_convert
from ray._private.client_mode_hook import client_mode_wrap
if TYPE_CHECKING:
from ray.util.common import ClientObjectRef # noqa
bundle_reservation_check = None
@ -45,7 +49,7 @@ class PlacementGroup:
self.id = id
self.bundle_cache = bundle_cache
def ready(self) -> ObjectRef:
def ready(self) -> Union[ObjectRef, "ClientObjectRef"]:
"""Returns an ObjectRef to check ready status.
This API runs a small dummy task to wait for placement group creation.