mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Revert "[Client] Make Client{ObjectRef,ActorRef} subclasses of their server-side counterparts (#16110)" (#16196)
This reverts commit f14f197d42
.
This commit is contained in:
parent
611da62739
commit
9942505b63
9 changed files with 95 additions and 153 deletions
|
@ -99,7 +99,6 @@ sys.modules["pytorch_lightning"] = ChildClassMock()
|
|||
sys.modules["xgboost"] = ChildClassMock()
|
||||
sys.modules["xgboost.core"] = ChildClassMock()
|
||||
sys.modules["xgboost.callback"] = ChildClassMock()
|
||||
sys.modules["xgboost_ray"] = ChildClassMock()
|
||||
|
||||
|
||||
class SimpleClass(object):
|
||||
|
|
|
@ -7,8 +7,6 @@ import logging
|
|||
from typing import Callable, Any, Union
|
||||
|
||||
import ray
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.util.client as client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -135,59 +133,3 @@ cdef class ObjectRef(BaseID):
|
|||
core_worker = ray.worker.global_worker.core_worker
|
||||
core_worker.set_get_async_callback(self, py_callback)
|
||||
return self
|
||||
|
||||
|
||||
cdef class ClientObjectRef(ObjectRef):
|
||||
|
||||
def __init__(self, id: bytes):
|
||||
check_id(id)
|
||||
self.data = CObjectID.FromBinary(<c_string>id)
|
||||
client.ray.call_retain(id)
|
||||
self.in_core_worker = False
|
||||
|
||||
def __dealloc__(self):
|
||||
if client.ray.is_connected() and not self.data.IsNil():
|
||||
client.ray.call_release(self.id)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.binary()
|
||||
|
||||
def future(self) -> concurrent.futures.Future:
|
||||
fut = concurrent.futures.Future()
|
||||
|
||||
def set_value(data: Any) -> None:
|
||||
"""Schedules a callback to set the exception or result
|
||||
in the Future."""
|
||||
|
||||
if isinstance(data, Exception):
|
||||
fut.set_exception(data)
|
||||
else:
|
||||
fut.set_result(data)
|
||||
|
||||
self._on_completed(set_value)
|
||||
|
||||
# Prevent this object ref from being released.
|
||||
fut.object_ref = self
|
||||
return fut
|
||||
|
||||
def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
|
||||
"""Register a callback that will be called after Object is ready.
|
||||
If the ObjectRef is already ready, the callback will be called soon.
|
||||
The callback should take the result as the only argument. The result
|
||||
can be an exception object in case of task error.
|
||||
"""
|
||||
from ray.util.client.client_pickler import loads_from_server
|
||||
|
||||
def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None:
|
||||
"""Converts from a GetResponse proto to a python object."""
|
||||
obj = resp.get
|
||||
data = None
|
||||
if not obj.valid:
|
||||
data = loads_from_server(resp.get.error)
|
||||
else:
|
||||
data = loads_from_server(resp.get.data)
|
||||
|
||||
py_callback(data)
|
||||
|
||||
client.ray._register_callback(self, deserialize_obj)
|
||||
|
|
|
@ -259,7 +259,6 @@ cdef class WorkerID(UniqueID):
|
|||
return <CWorkerID>self.data
|
||||
|
||||
cdef class ActorID(BaseID):
|
||||
|
||||
def __init__(self, id):
|
||||
check_id(id, CActorID.Size())
|
||||
self.data = CActorID.FromBinary(<c_string>id)
|
||||
|
@ -303,22 +302,6 @@ cdef class ActorID(BaseID):
|
|||
return self.data.Hash()
|
||||
|
||||
|
||||
cdef class ClientActorRef(ActorID):
|
||||
|
||||
def __init__(self, id: bytes):
|
||||
check_id(id, CActorID.Size())
|
||||
self.data = CActorID.FromBinary(<c_string>id)
|
||||
client.ray.call_retain(id)
|
||||
|
||||
def __dealloc__(self):
|
||||
if client.ray.is_connected() and not self.data.IsNil():
|
||||
client.ray.call_release(self.id)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.binary()
|
||||
|
||||
|
||||
cdef class FunctionID(UniqueID):
|
||||
|
||||
def __init__(self, id):
|
||||
|
|
|
@ -1,79 +1,11 @@
|
|||
import sys
|
||||
|
||||
import pytest
|
||||
from ray.util.client import RayAPIStub
|
||||
from ray.util.client.common import ClientActorRef, ClientObjectRef
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
from ray.util.client.ray_client_helpers import (
|
||||
ray_start_client_server_pair, ray_start_cluster_client_server_pair)
|
||||
from ray.test_utils import wait_for_condition
|
||||
import ray as real_ray
|
||||
from ray.core.generated.gcs_pb2 import ActorTableData
|
||||
from ray._raylet import ActorID, ObjectRef
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Timing out on Windows.")
|
||||
def test_client_object_ref_basics(ray_start_regular):
|
||||
with ray_start_client_server_pair() as pair:
|
||||
ray, server = pair
|
||||
ref = ray.put("Hello World")
|
||||
# Make sure ClientObjectRef is a subclass of ObjectRef
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
assert isinstance(ref, ObjectRef)
|
||||
|
||||
# Invalid ref format.
|
||||
with pytest.raises(Exception):
|
||||
ClientObjectRef(b"\0")
|
||||
|
||||
# Test __eq__()
|
||||
id = b"\0" * 28
|
||||
assert ClientObjectRef(id) == ClientObjectRef(id)
|
||||
assert ClientObjectRef(id) != ref
|
||||
assert ClientObjectRef(id) != ObjectRef(id)
|
||||
|
||||
assert ClientObjectRef(id).__repr__() == f"ClientObjectRef({id.hex()})"
|
||||
assert ClientObjectRef(id).binary() == id
|
||||
assert ClientObjectRef(id).hex() == id.hex()
|
||||
assert not ClientObjectRef(id).is_nil()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Timing out on Windows.")
|
||||
def test_client_actor_ref_basics(ray_start_regular):
|
||||
with ray_start_client_server_pair() as pair:
|
||||
ray, server = pair
|
||||
|
||||
@ray.remote
|
||||
class Counter:
|
||||
def __init__(self):
|
||||
self.acc = 0
|
||||
|
||||
def inc(self):
|
||||
self.acc += 1
|
||||
|
||||
def get(self):
|
||||
return self.acc
|
||||
|
||||
counter = Counter.remote()
|
||||
ref = counter.actor_ref
|
||||
|
||||
# Make sure ClientActorRef is a subclass of ActorID
|
||||
assert isinstance(ref, ClientActorRef)
|
||||
assert isinstance(ref, ActorID)
|
||||
|
||||
# Invalid ref format.
|
||||
with pytest.raises(Exception):
|
||||
ClientActorRef(b"\0")
|
||||
|
||||
# Test __eq__()
|
||||
id = b"\0" * 16
|
||||
assert ClientActorRef(id) == ClientActorRef(id)
|
||||
assert ClientActorRef(id) != ref
|
||||
assert ClientActorRef(id) != ActorID(id)
|
||||
|
||||
assert ClientActorRef(id).__repr__() == f"ClientActorRef({id.hex()})"
|
||||
assert ClientActorRef(id).binary() == id
|
||||
assert ClientActorRef(id).hex() == id.hex()
|
||||
assert not ClientActorRef(id).is_nil()
|
||||
|
||||
|
||||
def server_object_ref_count(server, n):
|
||||
|
|
|
@ -19,7 +19,6 @@ For many of the objects in the root `ray` namespace, there is an equivalent clie
|
|||
These objects are client stand-ins for their server-side objects. For example:
|
||||
```
|
||||
ObjectRef <-> ClientObjectRef
|
||||
ActorID <-> ClientActorRef
|
||||
RemoteFunc <-> ClientRemoteFunc
|
||||
```
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import ray._raylet as raylet
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.util.client import ray
|
||||
from ray.util.client.options import validate_options
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from dataclasses import dataclass
|
||||
import grpc
|
||||
import os
|
||||
|
@ -13,6 +14,7 @@ from ray.util.inspect import is_cython
|
|||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
@ -50,9 +52,87 @@ GRPC_OPTIONS = [
|
|||
CLIENT_SERVER_MAX_THREADS = float(
|
||||
os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))
|
||||
|
||||
# Aliases for compatibility.
|
||||
ClientObjectRef = raylet.ClientObjectRef
|
||||
ClientActorRef = raylet.ClientActorRef
|
||||
|
||||
class ClientBaseRef:
|
||||
def __init__(self, id: bytes):
|
||||
self.id = None
|
||||
if not isinstance(id, bytes):
|
||||
raise TypeError("ClientRefs must be created with bytes IDs")
|
||||
self.id: bytes = id
|
||||
ray.call_retain(id)
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
|
||||
def hex(self):
|
||||
return self.id.hex()
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ClientBaseRef) and self.id == other.id
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (
|
||||
type(self).__name__,
|
||||
self.id.hex(),
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def __del__(self):
|
||||
if ray.is_connected() and self.id is not None:
|
||||
ray.call_release(self.id)
|
||||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
|
||||
def __await__(self):
|
||||
return self.as_future().__await__()
|
||||
|
||||
def as_future(self) -> asyncio.Future:
|
||||
return asyncio.wrap_future(self.future())
|
||||
|
||||
def future(self) -> concurrent.futures.Future:
|
||||
fut = concurrent.futures.Future()
|
||||
|
||||
def set_value(data: Any) -> None:
|
||||
"""Schedules a callback to set the exception or result
|
||||
in the Future."""
|
||||
|
||||
if isinstance(data, Exception):
|
||||
fut.set_exception(data)
|
||||
else:
|
||||
fut.set_result(data)
|
||||
|
||||
self._on_completed(set_value)
|
||||
|
||||
# Prevent this object ref from being released.
|
||||
fut.object_ref = self
|
||||
return fut
|
||||
|
||||
def _on_completed(self, py_callback: Callable[[Any], None]) -> None:
|
||||
"""Register a callback that will be called after Object is ready.
|
||||
If the ObjectRef is already ready, the callback will be called soon.
|
||||
The callback should take the result as the only argument. The result
|
||||
can be an exception object in case of task error.
|
||||
"""
|
||||
from ray.util.client.client_pickler import loads_from_server
|
||||
|
||||
def deserialize_obj(resp: ray_client_pb2.DataResponse) -> None:
|
||||
"""Converts from a GetResponse proto to a python object."""
|
||||
obj = resp.get
|
||||
data = None
|
||||
if not obj.valid:
|
||||
data = loads_from_server(resp.get.error)
|
||||
else:
|
||||
data = loads_from_server(resp.get.data)
|
||||
|
||||
py_callback(data)
|
||||
|
||||
ray._register_callback(self, deserialize_obj)
|
||||
|
||||
|
||||
class ClientActorRef(ClientBaseRef):
|
||||
pass
|
||||
|
||||
|
||||
class ClientStub:
|
||||
|
|
|
@ -4,6 +4,7 @@ from operator import getitem
|
|||
import uuid
|
||||
|
||||
import ray
|
||||
from ray.util.client.common import ClientObjectRef
|
||||
|
||||
from dask.base import quote
|
||||
from dask.core import get as get_sync
|
||||
|
@ -46,7 +47,7 @@ def unpack_object_refs(*args):
|
|||
object_refs_token = uuid.uuid4().hex
|
||||
|
||||
def _unpack(expr):
|
||||
if isinstance(expr, ray.ObjectRef):
|
||||
if isinstance(expr, (ray.ObjectRef, ClientObjectRef)):
|
||||
token = expr.hex()
|
||||
repack_dsk[token] = (getitem, object_refs_token, len(object_refs))
|
||||
object_refs.append(expr)
|
||||
|
|
|
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
|||
from multiprocessing.pool import ThreadPool
|
||||
|
||||
import ray
|
||||
from ray.util.client.common import ClientObjectRef
|
||||
|
||||
from dask.core import istask, ishashable, _execute_task
|
||||
from dask.system import CPU_COUNT
|
||||
|
@ -369,7 +370,8 @@ def ray_get_unpack(object_refs):
|
|||
if isinstance(object_refs, tuple):
|
||||
object_refs = list(object_refs)
|
||||
|
||||
if isinstance(object_refs, list) and any(not isinstance(x, ray.ObjectRef)
|
||||
if isinstance(object_refs, list) and any(
|
||||
not isinstance(x, (ray.ObjectRef, ClientObjectRef))
|
||||
for x in object_refs):
|
||||
# We flatten the object references before calling ray.get(), since Dask
|
||||
# loves to nest collections in nested tuples and Ray expects a flat
|
||||
|
|
|
@ -3,6 +3,7 @@ from typing import Dict
|
|||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import ray
|
||||
from ray._raylet import ObjectRef
|
||||
|
@ -12,6 +13,9 @@ from ray.ray_constants import (to_memory_units, MEMORY_RESOURCE_UNIT_BYTES)
|
|||
from ray._private.client_mode_hook import client_mode_should_convert
|
||||
from ray._private.client_mode_hook import client_mode_wrap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.common import ClientObjectRef # noqa
|
||||
|
||||
bundle_reservation_check = None
|
||||
|
||||
|
||||
|
@ -45,7 +49,7 @@ class PlacementGroup:
|
|||
self.id = id
|
||||
self.bundle_cache = bundle_cache
|
||||
|
||||
def ready(self) -> ObjectRef:
|
||||
def ready(self) -> Union[ObjectRef, "ClientObjectRef"]:
|
||||
"""Returns an ObjectRef to check ready status.
|
||||
|
||||
This API runs a small dummy task to wait for placement group creation.
|
||||
|
|
Loading…
Add table
Reference in a new issue