mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
This commit is contained in:
parent
af01a47d59
commit
0d0c2418b8
11 changed files with 1073 additions and 770 deletions
|
@ -69,6 +69,30 @@ def client_mode_should_convert():
|
|||
return client_mode_enabled and _client_hook_enabled
|
||||
|
||||
|
||||
def client_mode_wrap(func):
|
||||
"""Wraps a function called during client mode for execution as a remote
|
||||
task.
|
||||
|
||||
Can be used to implement public features of ray client which do not
|
||||
belong in the main ray API (`ray.*`), yet require server-side execution.
|
||||
An example is the creation of placement groups:
|
||||
`ray.util.placement_group.placement_group()`. When called on the client
|
||||
side, this function is wrapped in a task to facilitate interaction with
|
||||
the GCS.
|
||||
"""
|
||||
from ray.util.client import ray
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if client_mode_should_convert():
|
||||
f = ray.remote(num_cpus=0)(func)
|
||||
ref = f.remote(*args, **kwargs)
|
||||
return ray.get(ref)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
|
||||
"""Runs a preregistered ray RemoteFunction through the ray client.
|
||||
|
||||
|
@ -80,7 +104,10 @@ def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
|
|||
from ray.util.client import ray
|
||||
|
||||
key = getattr(func_cls, RAY_CLIENT_MODE_ATTR, None)
|
||||
if key is None:
|
||||
|
||||
# Second part of "or" is needed in case func_cls is reused between Ray
|
||||
# client sessions in one Python interpreter session.
|
||||
if (key is None) or (not ray._converted_key_exists(key)):
|
||||
key = ray._convert_function(func_cls)
|
||||
setattr(func_cls, RAY_CLIENT_MODE_ATTR, key)
|
||||
client_func = ray._get_converted(key)
|
||||
|
@ -98,7 +125,9 @@ def client_mode_convert_actor(actor_cls, in_args, in_kwargs, **kwargs):
|
|||
from ray.util.client import ray
|
||||
|
||||
key = getattr(actor_cls, RAY_CLIENT_MODE_ATTR, None)
|
||||
if key is None:
|
||||
# Second part of "or" is needed in case actor_cls is reused between Ray
|
||||
# client sessions in one Python interpreter session.
|
||||
if (key is None) or (not ray._converted_key_exists(key)):
|
||||
key = ray._convert_actor(actor_cls)
|
||||
setattr(actor_cls, RAY_CLIENT_MODE_ATTR, key)
|
||||
client_actor = ray._get_converted(key)
|
||||
|
|
|
@ -148,6 +148,13 @@ def ray_start_cluster(request):
|
|||
yield res
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_cluster_init(request):
|
||||
param = getattr(request, "param", {})
|
||||
with _ray_start_cluster(do_init=True, **param) as res:
|
||||
yield res
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_cluster_head(request):
|
||||
param = getattr(request, "param", {})
|
||||
|
|
|
@ -7,9 +7,12 @@ import threading
|
|||
import _thread
|
||||
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
from ray.tests.client_test_utils import create_remote_signal_actor
|
||||
from ray.util.client.common import ClientObjectRef
|
||||
from ray.util.client.ray_client_helpers import connect_to_client_or_not
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
from ray._private.client_mode_hook import _explicitly_enable_client_mode
|
||||
from ray._private.client_mode_hook import client_mode_should_convert
|
||||
from ray._private.client_mode_hook import enable_client_mode
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
|
||||
|
@ -179,6 +182,8 @@ def test_wait(ray_start_regular_shared):
|
|||
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
|
||||
def test_remote_functions(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
SignalActor = create_remote_signal_actor(ray)
|
||||
signaler = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
def plus2(x):
|
||||
|
@ -220,6 +225,18 @@ def test_remote_functions(ray_start_regular_shared):
|
|||
all_vals = ray.get(res[0])
|
||||
assert all_vals == [236, 2_432_902_008_176_640_000, 120, 3628800]
|
||||
|
||||
# Timeout 0 on ray.wait leads to immediate return
|
||||
# (not indefinite wait for first return as with timeout None):
|
||||
unready_ref = signaler.wait.remote()
|
||||
res = ray.wait([unready_ref], timeout=0)
|
||||
# Not ready.
|
||||
assert res[0] == [] and len(res[1]) == 1
|
||||
ray.get(signaler.send.remote())
|
||||
ready_ref = signaler.wait.remote()
|
||||
# Ready.
|
||||
res = ray.wait([ready_ref], timeout=10)
|
||||
assert len(res[0]) == 1 and res[1] == []
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
|
||||
def test_function_calling_function(ray_start_regular_shared):
|
||||
|
@ -523,16 +540,16 @@ def test_client_gpu_ids(call_ray_stop_only):
|
|||
import ray
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
_explicitly_enable_client_mode()
|
||||
# No client connection.
|
||||
with pytest.raises(Exception) as e:
|
||||
ray.get_gpu_ids()
|
||||
assert str(e.value) == "Ray Client is not connected."\
|
||||
" Please connect by calling `ray.connect`."
|
||||
with enable_client_mode():
|
||||
# No client connection.
|
||||
with pytest.raises(Exception) as e:
|
||||
ray.get_gpu_ids()
|
||||
assert str(e.value) == "Ray Client is not connected."\
|
||||
" Please connect by calling `ray.connect`."
|
||||
|
||||
with ray_start_client_server():
|
||||
# Now have a client connection.
|
||||
assert ray.get_gpu_ids() == []
|
||||
with ray_start_client_server():
|
||||
# Now have a client connection.
|
||||
assert ray.get_gpu_ids() == []
|
||||
|
||||
|
||||
def test_client_serialize_addon(call_ray_stop_only):
|
||||
|
@ -548,5 +565,19 @@ def test_client_serialize_addon(call_ray_stop_only):
|
|||
assert ray.get(ray.put(User(name="ray"))).name == "ray"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("connect_to_client", [False, True])
|
||||
def test_client_context_manager(ray_start_regular_shared, connect_to_client):
|
||||
import ray
|
||||
with connect_to_client_or_not(connect_to_client):
|
||||
if connect_to_client:
|
||||
# Client mode is on.
|
||||
assert client_mode_should_convert() is True
|
||||
# We're connected to Ray client.
|
||||
assert ray.util.client.ray.is_connected() is True
|
||||
else:
|
||||
assert client_mode_should_convert() is False
|
||||
assert ray.util.client.ray.is_connected() is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
|
|
@ -76,19 +76,26 @@ def test_validate_port():
|
|||
|
||||
|
||||
def test_basic_preregister(init_and_serve):
|
||||
"""Tests conversion of Ray actors and remote functions to client actors
|
||||
and client remote functions.
|
||||
|
||||
Checks that the conversion works when disconnecting and reconnecting client
|
||||
sessions.
|
||||
"""
|
||||
from ray.util.client import ray
|
||||
ray.connect("localhost:50051")
|
||||
val = ray.get(hello_world.remote())
|
||||
print(val)
|
||||
assert val >= 20
|
||||
assert val <= 200
|
||||
c = C.remote(3)
|
||||
x = c.double.remote()
|
||||
y = c.double.remote()
|
||||
ray.wait([x, y])
|
||||
val = ray.get(c.get.remote())
|
||||
assert val == 12
|
||||
ray.disconnect()
|
||||
for _ in range(2):
|
||||
ray.connect("localhost:50051")
|
||||
val = ray.get(hello_world.remote())
|
||||
print(val)
|
||||
assert val >= 20
|
||||
assert val <= 200
|
||||
c = C.remote(3)
|
||||
x = c.double.remote()
|
||||
y = c.double.remote()
|
||||
ray.wait([x, y])
|
||||
val = ray.get(c.get.remote())
|
||||
assert val == 12
|
||||
ray.disconnect()
|
||||
|
||||
|
||||
def test_num_clients(init_and_serve_lazy):
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -287,6 +287,10 @@ class ClientAPI:
|
|||
"""Given a UUID, return the converted object"""
|
||||
return self.worker._get_converted(key)
|
||||
|
||||
def _converted_key_exists(self, key: str) -> bool:
|
||||
"""Check if a key UUID is present in the store of converted objects."""
|
||||
return self.worker._converted_key_exists(key)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
if not key.startswith("_"):
|
||||
raise NotImplementedError(
|
||||
|
|
|
@ -307,6 +307,12 @@ def set_task_options(task: ray_client_pb2.ClientTask,
|
|||
if options is None:
|
||||
task.ClearField(field)
|
||||
return
|
||||
|
||||
# If there's a non-null "placement_group" in `options`, convert the
|
||||
# placement group to a dict so that `options` can be passed to json.dumps.
|
||||
if options.get("placement_group", None):
|
||||
options["placement_group"] = options["placement_group"].to_dict()
|
||||
|
||||
options_str = json.dumps(options)
|
||||
getattr(task, field).json_options = options_str
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import ray as real_ray
|
|||
import ray.cloudpickle as pickle
|
||||
import ray.util.client.server.server as ray_client_server
|
||||
from ray.util.client import ray
|
||||
from ray._private.client_mode_hook import enable_client_mode
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -62,3 +63,36 @@ class RayClientSerializationContext:
|
|||
|
||||
# construct a reducer
|
||||
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
|
||||
|
||||
|
||||
@contextmanager
|
||||
def connect_to_client_or_not(connect_to_client: bool):
|
||||
"""Utility for running test logic with and without a Ray client connection.
|
||||
|
||||
If connect_to_client is True, will connect to Ray client in context.
|
||||
If connect_to_client is False, does nothing.
|
||||
|
||||
How to use:
|
||||
Given a test of the following form:
|
||||
|
||||
def test_<name>(args):
|
||||
<initialize a ray cluster>
|
||||
<use the ray cluster>
|
||||
|
||||
Modify the test to
|
||||
|
||||
@pytest.mark.parametrize("connect_to_client", [False, True])
|
||||
def test_<name>(args, connect_to_client)
|
||||
<initialize a ray cluster>
|
||||
with connect_to_client_or_not(connect_to_client):
|
||||
<use the ray cluster>
|
||||
|
||||
Parameterize the argument connect_to_client over True, False to run the test with and
|
||||
without a Ray client connection.
|
||||
"""
|
||||
|
||||
if connect_to_client:
|
||||
with ray_start_client_server(), enable_client_mode():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
|
|
@ -31,6 +31,7 @@ from ray.util.client.server.server_pickler import loads_from_client
|
|||
from ray.util.client.server.dataservicer import DataServicer
|
||||
from ray.util.client.server.logservicer import LogstreamServicer
|
||||
from ray.util.client.server.server_stubs import current_server
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -510,6 +511,13 @@ def decode_options(
|
|||
return None
|
||||
opts = json.loads(options.json_options)
|
||||
assert isinstance(opts, dict)
|
||||
|
||||
if opts.get("placement_group", None):
|
||||
# Placement groups in Ray client options are serialized as dicts.
|
||||
# Convert the dict to a PlacementGroup.
|
||||
opts["placement_group"] = PlacementGroup.from_dict(
|
||||
opts["placement_group"])
|
||||
|
||||
return opts
|
||||
|
||||
|
||||
|
|
|
@ -274,7 +274,7 @@ class Worker:
|
|||
data = {
|
||||
"object_ids": [object_ref.id for object_ref in object_refs],
|
||||
"num_returns": num_returns,
|
||||
"timeout": timeout if timeout else -1,
|
||||
"timeout": timeout if (timeout is not None) else -1,
|
||||
"client_id": self._client_id,
|
||||
}
|
||||
req = ray_client_pb2.WaitRequest(**data)
|
||||
|
@ -512,6 +512,10 @@ class Worker:
|
|||
"""Given a UUID, return the converted object"""
|
||||
return self._converted[key]
|
||||
|
||||
def _converted_key_exists(self, key: str) -> bool:
|
||||
"""Check if a key UUID is present in the store of converted objects."""
|
||||
return key in self._converted
|
||||
|
||||
|
||||
def make_client_id() -> str:
|
||||
id = uuid.uuid4()
|
||||
|
|
|
@ -1,11 +1,20 @@
|
|||
import time
|
||||
|
||||
from typing import (List, Dict, Optional, Union)
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import ray
|
||||
from ray._raylet import PlacementGroupID, ObjectRef
|
||||
from ray._raylet import ObjectRef
|
||||
from ray._raylet import PlacementGroupID
|
||||
from ray._private.utils import hex_to_binary
|
||||
from ray.ray_constants import (to_memory_units, MEMORY_RESOURCE_UNIT_BYTES)
|
||||
from ray._private.client_mode_hook import client_mode_should_convert
|
||||
from ray._private.client_mode_hook import client_mode_wrap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.common import ClientObjectRef # noqa
|
||||
|
||||
bundle_reservation_check = None
|
||||
|
||||
|
@ -31,14 +40,16 @@ class PlacementGroup:
|
|||
"""A handle to a placement group."""
|
||||
|
||||
@staticmethod
|
||||
def empty():
|
||||
def empty() -> "PlacementGroup":
|
||||
return PlacementGroup(PlacementGroupID.nil())
|
||||
|
||||
def __init__(self, id: PlacementGroupID):
|
||||
def __init__(self,
|
||||
id: PlacementGroupID,
|
||||
bundle_cache: Optional[List[Dict]] = None):
|
||||
self.id = id
|
||||
self.bundle_cache = None
|
||||
self.bundle_cache = bundle_cache
|
||||
|
||||
def ready(self) -> ObjectRef:
|
||||
def ready(self) -> Union[ObjectRef, "ClientObjectRef"]:
|
||||
"""Returns an ObjectRef to check ready status.
|
||||
|
||||
This API runs a small dummy task to wait for placement group creation.
|
||||
|
@ -67,7 +78,7 @@ class PlacementGroup:
|
|||
bundle_index = 0
|
||||
bundle = self.bundle_cache[bundle_index]
|
||||
|
||||
resource_name, value = self._get_none_zero_resource(bundle)
|
||||
resource_name, value = self._get_a_non_zero_resource(bundle)
|
||||
num_cpus = 0
|
||||
num_gpus = 0
|
||||
memory = 0
|
||||
|
@ -96,11 +107,7 @@ class PlacementGroup:
|
|||
Return:
|
||||
True if the placement group is created. False otherwise.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
|
||||
return worker.core_worker.wait_placement_group_ready(
|
||||
self.id, timeout_seconds)
|
||||
return _call_placement_group_ready(self.id, timeout_seconds)
|
||||
|
||||
@property
|
||||
def bundle_specs(self) -> List[Dict]:
|
||||
|
@ -109,12 +116,54 @@ class PlacementGroup:
|
|||
return self.bundle_cache
|
||||
|
||||
@property
|
||||
def bundle_count(self):
|
||||
def bundle_count(self) -> int:
|
||||
self._fill_bundle_cache_if_needed()
|
||||
return len(self.bundle_cache)
|
||||
|
||||
def _get_none_zero_resource(self, bundle: List[Dict]):
|
||||
# Set a mock value to schedule a dummy task.
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert this placement group into a dict for purposes of json
|
||||
serialization.
|
||||
|
||||
Used when passing a placement group as an option to a Ray client remote
|
||||
function. See set_task_options in util/client/common.py.
|
||||
|
||||
Return:
|
||||
Dictionary with json-serializable keys representing the placement
|
||||
group.
|
||||
"""
|
||||
# Placement group id is converted to a hex /string/ to make it
|
||||
# serializable.
|
||||
return {"id": self.id.hex(), "bundle_cache": self.bundle_cache}
|
||||
|
||||
@staticmethod
|
||||
def from_dict(pg_dict: dict) -> "PlacementGroup":
|
||||
"""Instantiate and return a PlacementGroup from its json-serializable
|
||||
dict representation.
|
||||
|
||||
Used by Ray Client on server-side to deserialize placement group
|
||||
option. See decode_options in util/client/server/server.py.
|
||||
|
||||
Args:
|
||||
pg_dict(dict): Dictionary representing a placement group.
|
||||
Return:
|
||||
A placement group made from the data in the input dict.
|
||||
"""
|
||||
# Validate serialized dict
|
||||
assert isinstance(pg_dict, dict)
|
||||
assert pg_dict.keys() == {"id", "bundle_cache"}
|
||||
# The value associated to key "id" is a hex string.
|
||||
assert isinstance(pg_dict["id"], str)
|
||||
if pg_dict["bundle_cache"] is not None:
|
||||
assert isinstance(pg_dict["bundle_cache"], list)
|
||||
|
||||
# Deserialize and return a Placement Group.
|
||||
id_bytes = bytes.fromhex(pg_dict["id"])
|
||||
pg_id = PlacementGroupID(id_bytes)
|
||||
bundle_cache = pg_dict["bundle_cache"]
|
||||
return PlacementGroup(pg_id, bundle_cache)
|
||||
|
||||
def _get_a_non_zero_resource(self, bundle: Dict):
|
||||
# Any number between 0-1 should be fine.
|
||||
MOCK_VALUE = 0.001
|
||||
for key, value in bundle.items():
|
||||
if value > 0:
|
||||
|
@ -123,31 +172,46 @@ class PlacementGroup:
|
|||
return key, value
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
def _fill_bundle_cache_if_needed(self):
|
||||
def _fill_bundle_cache_if_needed(self) -> None:
|
||||
if not self.bundle_cache:
|
||||
# Since creating placement group is async, it is
|
||||
# possible table is not ready yet. To avoid the
|
||||
# problem, we should keep trying with timeout.
|
||||
TIMEOUT_SECOND = 30
|
||||
WAIT_INTERVAL = 0.05
|
||||
timeout_cnt = 0
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
|
||||
while timeout_cnt < int(TIMEOUT_SECOND / WAIT_INTERVAL):
|
||||
pg_info = ray.state.state.placement_group_table(self.id)
|
||||
if pg_info:
|
||||
self.bundle_cache = list(pg_info["bundles"].values())
|
||||
return
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
timeout_cnt += 1
|
||||
|
||||
raise RuntimeError(
|
||||
"Couldn't get the bundle information of placement group id "
|
||||
f"{self.id} in {TIMEOUT_SECOND} seconds. It is likely "
|
||||
"because GCS server is too busy.")
|
||||
self.bundle_cache = _get_bundle_cache(self.id)
|
||||
|
||||
|
||||
@client_mode_wrap
|
||||
def _call_placement_group_ready(pg_id: PlacementGroupID,
|
||||
timeout_seconds: int) -> bool:
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
|
||||
return worker.core_worker.wait_placement_group_ready(
|
||||
pg_id, timeout_seconds)
|
||||
|
||||
|
||||
@client_mode_wrap
def _get_bundle_cache(pg_id: PlacementGroupID) -> List[Dict]:
    """Fetch the bundle specifications of a placement group.

    Polls the placement group table until an entry for ``pg_id`` shows up,
    then returns the values of its ``"bundles"`` mapping as a list.

    Args:
        pg_id: Id of the placement group to look up.

    Return:
        List of the bundles recorded for the placement group.

    Raises:
        RuntimeError: If no table entry is found within the timeout.
    """
    # Since creating placement group is async, it is
    # possible table is not ready yet. To avoid the
    # problem, we should keep trying with timeout.
    TIMEOUT_SECOND = 30
    WAIT_INTERVAL = 0.05
    max_attempts = int(TIMEOUT_SECOND / WAIT_INTERVAL)

    worker = ray.worker.global_worker
    worker.check_connected()

    for _ in range(max_attempts):
        pg_info = ray.state.state.placement_group_table(pg_id)
        if pg_info:
            return list(pg_info["bundles"].values())
        time.sleep(WAIT_INTERVAL)

    # Report the placement group id that was looked up; the previous message
    # interpolated the builtin `id` function instead of `pg_id`.
    raise RuntimeError(
        "Couldn't get the bundle information of placement group id "
        f"{pg_id} in {TIMEOUT_SECOND} seconds. It is likely "
        "because GCS server is too busy.")
|
||||
|
||||
|
||||
@client_mode_wrap
|
||||
def placement_group(bundles: List[Dict[str, float]],
|
||||
strategy: str = "PACK",
|
||||
name: str = "",
|
||||
|
@ -208,7 +272,8 @@ def placement_group(bundles: List[Dict[str, float]],
|
|||
return PlacementGroup(placement_group_id)
|
||||
|
||||
|
||||
def remove_placement_group(placement_group: PlacementGroup):
|
||||
@client_mode_wrap
|
||||
def remove_placement_group(placement_group: PlacementGroup) -> None:
|
||||
"""Asynchronously remove placement group.
|
||||
|
||||
Args:
|
||||
|
@ -221,7 +286,8 @@ def remove_placement_group(placement_group: PlacementGroup):
|
|||
worker.core_worker.remove_placement_group(placement_group.id)
|
||||
|
||||
|
||||
def get_placement_group(placement_group_name: str):
|
||||
@client_mode_wrap
|
||||
def get_placement_group(placement_group_name: str) -> PlacementGroup:
|
||||
"""Get a placement group object with a global name.
|
||||
|
||||
Returns:
|
||||
|
@ -244,6 +310,7 @@ def get_placement_group(placement_group_name: str):
|
|||
hex_to_binary(placement_group_info["placement_group_id"])))
|
||||
|
||||
|
||||
@client_mode_wrap
|
||||
def placement_group_table(placement_group: PlacementGroup = None) -> dict:
|
||||
"""Get the state of the placement group from GCS.
|
||||
|
||||
|
@ -286,6 +353,9 @@ def get_current_placement_group() -> Optional[PlacementGroup]:
|
|||
None if the current task or actor wasn't
|
||||
created with any placement group.
|
||||
"""
|
||||
if client_mode_should_convert():
|
||||
# Client mode is only a driver.
|
||||
return None
|
||||
worker = ray.worker.global_worker
|
||||
worker.check_connected()
|
||||
pg_id = worker.placement_group_id
|
||||
|
@ -295,7 +365,7 @@ def get_current_placement_group() -> Optional[PlacementGroup]:
|
|||
|
||||
|
||||
def check_placement_group_index(placement_group: PlacementGroup,
|
||||
bundle_index: int):
|
||||
bundle_index: int) -> None:
|
||||
assert placement_group is not None
|
||||
if placement_group.id.is_nil():
|
||||
if bundle_index != -1:
|
||||
|
|
Loading…
Add table
Reference in a new issue