[client][placement groups] Client placement group hooks, attempt #3 (#15382)

This commit is contained in:
Dmitri Gekhtman 2021-04-22 20:18:55 -04:00 committed by GitHub
parent af01a47d59
commit 0d0c2418b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 1073 additions and 770 deletions

View file

@ -69,6 +69,30 @@ def client_mode_should_convert():
return client_mode_enabled and _client_hook_enabled
def client_mode_wrap(func):
    """Decorator that runs ``func`` as a remote task when in client mode.

    Intended for public Ray client features that are not part of the main
    ray API (`ray.*`) but still need to execute on the server side. One
    example is placement group creation,
    `ray.util.placement_group.placement_group()`: when invoked from the
    client, the call is shipped as a task so that it can interact with the
    GCS.

    When `client_mode_should_convert()` is false, the wrapped function is
    simply called in-process with the given arguments.
    """
    from ray.util.client import ray

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not client_mode_should_convert():
            # Regular (non-client) execution path: call directly.
            return func(*args, **kwargs)
        # Client mode: wrap the function as a zero-CPU remote task on each
        # call, submit it, and block until its result is available.
        remote_func = ray.remote(num_cpus=0)(func)
        return ray.get(remote_func.remote(*args, **kwargs))

    return wrapper
def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
"""Runs a preregistered ray RemoteFunction through the ray client.
@ -80,7 +104,10 @@ def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
from ray.util.client import ray
key = getattr(func_cls, RAY_CLIENT_MODE_ATTR, None)
if key is None:
# Second part of "or" is needed in case func_cls is reused between Ray
# client sessions in one Python interpreter session.
if (key is None) or (not ray._converted_key_exists(key)):
key = ray._convert_function(func_cls)
setattr(func_cls, RAY_CLIENT_MODE_ATTR, key)
client_func = ray._get_converted(key)
@ -98,7 +125,9 @@ def client_mode_convert_actor(actor_cls, in_args, in_kwargs, **kwargs):
from ray.util.client import ray
key = getattr(actor_cls, RAY_CLIENT_MODE_ATTR, None)
if key is None:
# Second part of "or" is needed in case actor_cls is reused between Ray
# client sessions in one Python interpreter session.
if (key is None) or (not ray._converted_key_exists(key)):
key = ray._convert_actor(actor_cls)
setattr(actor_cls, RAY_CLIENT_MODE_ATTR, key)
client_actor = ray._get_converted(key)

View file

@ -148,6 +148,13 @@ def ray_start_cluster(request):
yield res
@pytest.fixture
def ray_start_cluster_init(request):
    """Fixture yielding the result of `_ray_start_cluster(do_init=True)`.

    Extra keyword arguments for `_ray_start_cluster` are taken from
    `request.param` when the test is parametrized; otherwise none are passed.
    """
    cluster_kwargs = getattr(request, "param", {})
    with _ray_start_cluster(do_init=True, **cluster_kwargs) as cluster:
        yield cluster
@pytest.fixture
def ray_start_cluster_head(request):
param = getattr(request, "param", {})

View file

@ -7,9 +7,12 @@ import threading
import _thread
import ray.util.client.server.server as ray_client_server
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.util.client.common import ClientObjectRef
from ray.util.client.ray_client_helpers import connect_to_client_or_not
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray._private.client_mode_hook import _explicitly_enable_client_mode
from ray._private.client_mode_hook import client_mode_should_convert
from ray._private.client_mode_hook import enable_client_mode
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
@ -179,6 +182,8 @@ def test_wait(ray_start_regular_shared):
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_remote_functions(ray_start_regular_shared):
with ray_start_client_server() as ray:
SignalActor = create_remote_signal_actor(ray)
signaler = SignalActor.remote()
@ray.remote
def plus2(x):
@ -220,6 +225,18 @@ def test_remote_functions(ray_start_regular_shared):
all_vals = ray.get(res[0])
assert all_vals == [236, 2_432_902_008_176_640_000, 120, 3628800]
# Timeout 0 on ray.wait leads to immediate return
# (not indefinite wait for first return as with timeout None):
unready_ref = signaler.wait.remote()
res = ray.wait([unready_ref], timeout=0)
# Not ready.
assert res[0] == [] and len(res[1]) == 1
ray.get(signaler.send.remote())
ready_ref = signaler.wait.remote()
# Ready.
res = ray.wait([ready_ref], timeout=10)
assert len(res[0]) == 1 and res[1] == []
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_function_calling_function(ray_start_regular_shared):
@ -523,16 +540,16 @@ def test_client_gpu_ids(call_ray_stop_only):
import ray
ray.init(num_cpus=2)
_explicitly_enable_client_mode()
# No client connection.
with pytest.raises(Exception) as e:
ray.get_gpu_ids()
assert str(e.value) == "Ray Client is not connected."\
" Please connect by calling `ray.connect`."
with enable_client_mode():
# No client connection.
with pytest.raises(Exception) as e:
ray.get_gpu_ids()
assert str(e.value) == "Ray Client is not connected."\
" Please connect by calling `ray.connect`."
with ray_start_client_server():
# Now have a client connection.
assert ray.get_gpu_ids() == []
with ray_start_client_server():
# Now have a client connection.
assert ray.get_gpu_ids() == []
def test_client_serialize_addon(call_ray_stop_only):
@ -548,5 +565,19 @@ def test_client_serialize_addon(call_ray_stop_only):
assert ray.get(ray.put(User(name="ray"))).name == "ray"
@pytest.mark.parametrize("connect_to_client", [False, True])
def test_client_context_manager(ray_start_regular_shared, connect_to_client):
    """Check client-mode state with and without a Ray client connection."""
    import ray

    # Inside the context, both "client mode is on" and "client is connected"
    # must be exactly True when we connected, and exactly False when not.
    expected = True if connect_to_client else False
    with connect_to_client_or_not(connect_to_client):
        assert client_mode_should_convert() is expected
        assert ray.util.client.ray.is_connected() is expected
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))

View file

@ -76,19 +76,26 @@ def test_validate_port():
def test_basic_preregister(init_and_serve):
"""Tests conversion of Ray actors and remote functions to client actors
and client remote functions.
Checks that the conversion works when disconnecting and reconnecting client
sessions.
"""
from ray.util.client import ray
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
ray.disconnect()
for _ in range(2):
ray.connect("localhost:50051")
val = ray.get(hello_world.remote())
print(val)
assert val >= 20
assert val <= 200
c = C.remote(3)
x = c.double.remote()
y = c.double.remote()
ray.wait([x, y])
val = ray.get(c.get.remote())
assert val == 12
ray.disconnect()
def test_num_clients(init_and_serve_lazy):

File diff suppressed because it is too large Load diff

View file

@ -287,6 +287,10 @@ class ClientAPI:
"""Given a UUID, return the converted object"""
return self.worker._get_converted(key)
def _converted_key_exists(self, key: str) -> bool:
    """Check if a key UUID is present in the store of converted objects.

    Thin delegation to `self.worker._converted_key_exists`.

    Args:
        key: Key to look up.

    Returns:
        Whatever the worker reports for ``key``.
    """
    return self.worker._converted_key_exists(key)
def __getattr__(self, key: str):
if not key.startswith("_"):
raise NotImplementedError(

View file

@ -307,6 +307,12 @@ def set_task_options(task: ray_client_pb2.ClientTask,
if options is None:
task.ClearField(field)
return
# If there's a non-null "placement_group" in `options`, convert the
# placement group to a dict so that `options` can be passed to json.dumps.
if options.get("placement_group", None):
options["placement_group"] = options["placement_group"].to_dict()
options_str = json.dumps(options)
getattr(task, field).json_options = options_str

View file

@ -4,6 +4,7 @@ import ray as real_ray
import ray.cloudpickle as pickle
import ray.util.client.server.server as ray_client_server
from ray.util.client import ray
from ray._private.client_mode_hook import enable_client_mode
@contextmanager
@ -62,3 +63,36 @@ class RayClientSerializationContext:
# construct a reducer
pickle.CloudPickler.dispatch[cls] = _CloudPicklerReducer
@contextmanager
def connect_to_client_or_not(connect_to_client: bool):
    """Context manager for running test logic with or without Ray client.

    If connect_to_client is True, the body runs inside both
    `ray_start_client_server()` and `enable_client_mode()`.
    If connect_to_client is False, the body runs unchanged.

    How to use:
    Given a test of the following form:

    def test_<name>(args):
        <initialize a ray cluster>
        <use the ray cluster>

    Modify the test to

    @pytest.mark.parametrize("connect_to_client", [False, True])
    def test_<name>(args, connect_to_client)
        <initialize a ray cluster>
        with connect_to_client_or_not(connect_to_client):
            <use the ray cluster>

    Parametrizing connect_to_client over True and False runs the test both
    with and without a Ray client connection.
    """
    if not connect_to_client:
        # Nothing to set up: hand control straight back to the caller.
        yield
        return
    with ray_start_client_server(), enable_client_mode():
        yield

View file

@ -31,6 +31,7 @@ from ray.util.client.server.server_pickler import loads_from_client
from ray.util.client.server.dataservicer import DataServicer
from ray.util.client.server.logservicer import LogstreamServicer
from ray.util.client.server.server_stubs import current_server
from ray.util.placement_group import PlacementGroup
from ray._private.client_mode_hook import disable_client_hook
logger = logging.getLogger(__name__)
@ -510,6 +511,13 @@ def decode_options(
return None
opts = json.loads(options.json_options)
assert isinstance(opts, dict)
if opts.get("placement_group", None):
# Placement groups in Ray client options are serialized as dicts.
# Convert the dict to a PlacementGroup.
opts["placement_group"] = PlacementGroup.from_dict(
opts["placement_group"])
return opts

View file

@ -274,7 +274,7 @@ class Worker:
data = {
"object_ids": [object_ref.id for object_ref in object_refs],
"num_returns": num_returns,
"timeout": timeout if timeout else -1,
"timeout": timeout if (timeout is not None) else -1,
"client_id": self._client_id,
}
req = ray_client_pb2.WaitRequest(**data)
@ -512,6 +512,10 @@ class Worker:
"""Given a UUID, return the converted object"""
return self._converted[key]
def _converted_key_exists(self, key: str) -> bool:
    """Check if a key UUID is present in the store of converted objects.

    Args:
        key: Key to look up in `self._converted`.

    Returns:
        True if ``key`` is in `self._converted`, False otherwise.
    """
    return key in self._converted
def make_client_id() -> str:
id = uuid.uuid4()

View file

@ -1,11 +1,20 @@
import time
from typing import (List, Dict, Optional, Union)
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from typing import TYPE_CHECKING
import ray
from ray._raylet import PlacementGroupID, ObjectRef
from ray._raylet import ObjectRef
from ray._raylet import PlacementGroupID
from ray._private.utils import hex_to_binary
from ray.ray_constants import (to_memory_units, MEMORY_RESOURCE_UNIT_BYTES)
from ray._private.client_mode_hook import client_mode_should_convert
from ray._private.client_mode_hook import client_mode_wrap
if TYPE_CHECKING:
from ray.util.common import ClientObjectRef # noqa
bundle_reservation_check = None
@ -31,14 +40,16 @@ class PlacementGroup:
"""A handle to a placement group."""
@staticmethod
def empty():
def empty() -> "PlacementGroup":
return PlacementGroup(PlacementGroupID.nil())
def __init__(self, id: PlacementGroupID):
def __init__(self,
id: PlacementGroupID,
bundle_cache: Optional[List[Dict]] = None):
self.id = id
self.bundle_cache = None
self.bundle_cache = bundle_cache
def ready(self) -> ObjectRef:
def ready(self) -> Union[ObjectRef, "ClientObjectRef"]:
"""Returns an ObjectRef to check ready status.
This API runs a small dummy task to wait for placement group creation.
@ -67,7 +78,7 @@ class PlacementGroup:
bundle_index = 0
bundle = self.bundle_cache[bundle_index]
resource_name, value = self._get_none_zero_resource(bundle)
resource_name, value = self._get_a_non_zero_resource(bundle)
num_cpus = 0
num_gpus = 0
memory = 0
@ -96,11 +107,7 @@ class PlacementGroup:
Return:
True if the placement group is created. False otherwise.
"""
worker = ray.worker.global_worker
worker.check_connected()
return worker.core_worker.wait_placement_group_ready(
self.id, timeout_seconds)
return _call_placement_group_ready(self.id, timeout_seconds)
@property
def bundle_specs(self) -> List[Dict]:
@ -109,12 +116,54 @@ class PlacementGroup:
return self.bundle_cache
@property
def bundle_count(self):
def bundle_count(self) -> int:
self._fill_bundle_cache_if_needed()
return len(self.bundle_cache)
def _get_none_zero_resource(self, bundle: List[Dict]):
# Set a mock value to schedule a dummy task.
def to_dict(self) -> dict:
    """Return a json-serializable dict describing this placement group.

    Used when passing a placement group as an option to a Ray client remote
    function. See set_task_options in util/client/common.py.

    Return:
        Dictionary with keys "id" (the placement group id as a hex string)
        and "bundle_cache" (this handle's `bundle_cache` attribute, as is).
    """
    serialized = {}
    # The id object itself is not json-serializable, so store its hex
    # /string/ form instead.
    serialized["id"] = self.id.hex()
    serialized["bundle_cache"] = self.bundle_cache
    return serialized
@staticmethod
def from_dict(pg_dict: dict) -> "PlacementGroup":
    """Build a PlacementGroup from its json-serializable dict form.

    Used by Ray Client on server-side to deserialize placement group
    option. See decode_options in util/client/server/server.py.

    Args:
        pg_dict(dict): Dictionary representing a placement group. Must have
            exactly the keys "id" (a hex string) and "bundle_cache" (a list
            or None).

    Return:
        A placement group made from the data in the input dict.
    """
    # Validate the overall shape of the serialized dict first.
    assert isinstance(pg_dict, dict)
    assert pg_dict.keys() == {"id", "bundle_cache"}

    hex_id = pg_dict["id"]
    bundles = pg_dict["bundle_cache"]
    # The id travels as a hex string; the bundle cache is optional.
    assert isinstance(hex_id, str)
    assert bundles is None or isinstance(bundles, list)

    # Rebuild the binary id and wrap everything in a handle.
    return PlacementGroup(PlacementGroupID(bytes.fromhex(hex_id)), bundles)
def _get_a_non_zero_resource(self, bundle: Dict):
# Any number between 0-1 should be fine.
MOCK_VALUE = 0.001
for key, value in bundle.items():
if value > 0:
@ -123,31 +172,46 @@ class PlacementGroup:
return key, value
assert False, "This code should be unreachable."
def _fill_bundle_cache_if_needed(self):
def _fill_bundle_cache_if_needed(self) -> None:
if not self.bundle_cache:
# Since creating placement group is async, it is
# possible table is not ready yet. To avoid the
# problem, we should keep trying with timeout.
TIMEOUT_SECOND = 30
WAIT_INTERVAL = 0.05
timeout_cnt = 0
worker = ray.worker.global_worker
worker.check_connected()
while timeout_cnt < int(TIMEOUT_SECOND / WAIT_INTERVAL):
pg_info = ray.state.state.placement_group_table(self.id)
if pg_info:
self.bundle_cache = list(pg_info["bundles"].values())
return
time.sleep(WAIT_INTERVAL)
timeout_cnt += 1
raise RuntimeError(
"Couldn't get the bundle information of placement group id "
f"{self.id} in {TIMEOUT_SECOND} seconds. It is likely "
"because GCS server is too busy.")
self.bundle_cache = _get_bundle_cache(self.id)
@client_mode_wrap
def _call_placement_group_ready(pg_id: PlacementGroupID,
                                timeout_seconds: int) -> bool:
    """Ask the core worker to wait for a placement group to be ready.

    Args:
        pg_id: Id of the placement group to wait on.
        timeout_seconds: Timeout forwarded to the core worker.

    Return:
        The result of `core_worker.wait_placement_group_ready`.
    """
    global_worker = ray.worker.global_worker
    # Fail early if this process is not connected to Ray.
    global_worker.check_connected()
    is_ready = global_worker.core_worker.wait_placement_group_ready(
        pg_id, timeout_seconds)
    return is_ready
@client_mode_wrap
def _get_bundle_cache(pg_id: PlacementGroupID) -> List[Dict]:
    """Fetch the bundle list of a placement group, polling until available.

    Args:
        pg_id: Id of the placement group to look up.

    Return:
        The values of the "bundles" entry of the placement group table, as
        a list.

    Raises:
        RuntimeError: If no placement group info is returned within
            TIMEOUT_SECOND seconds of polling.
    """
    # Since creating placement group is async, it is
    # possible table is not ready yet. To avoid the
    # problem, we should keep trying with timeout.
    TIMEOUT_SECOND = 30
    WAIT_INTERVAL = 0.05
    max_attempts = int(TIMEOUT_SECOND / WAIT_INTERVAL)

    worker = ray.worker.global_worker
    worker.check_connected()

    for _ in range(max_attempts):
        pg_info = ray.state.state.placement_group_table(pg_id)
        if pg_info:
            return list(pg_info["bundles"].values())
        time.sleep(WAIT_INTERVAL)

    # Report the placement group id that was looked up. The previous message
    # interpolated `id`, i.e. the builtin function, not the argument.
    raise RuntimeError(
        "Couldn't get the bundle information of placement group id "
        f"{pg_id} in {TIMEOUT_SECOND} seconds. It is likely "
        "because GCS server is too busy.")
@client_mode_wrap
def placement_group(bundles: List[Dict[str, float]],
strategy: str = "PACK",
name: str = "",
@ -208,7 +272,8 @@ def placement_group(bundles: List[Dict[str, float]],
return PlacementGroup(placement_group_id)
def remove_placement_group(placement_group: PlacementGroup):
@client_mode_wrap
def remove_placement_group(placement_group: PlacementGroup) -> None:
"""Asynchronously remove placement group.
Args:
@ -221,7 +286,8 @@ def remove_placement_group(placement_group: PlacementGroup):
worker.core_worker.remove_placement_group(placement_group.id)
def get_placement_group(placement_group_name: str):
@client_mode_wrap
def get_placement_group(placement_group_name: str) -> PlacementGroup:
"""Get a placement group object with a global name.
Returns:
@ -244,6 +310,7 @@ def get_placement_group(placement_group_name: str):
hex_to_binary(placement_group_info["placement_group_id"])))
@client_mode_wrap
def placement_group_table(placement_group: PlacementGroup = None) -> dict:
"""Get the state of the placement group from GCS.
@ -286,6 +353,9 @@ def get_current_placement_group() -> Optional[PlacementGroup]:
None if the current task or actor wasn't
created with any placement group.
"""
if client_mode_should_convert():
# Client mode is only a driver.
return None
worker = ray.worker.global_worker
worker.check_connected()
pg_id = worker.placement_group_id
@ -295,7 +365,7 @@ def get_current_placement_group() -> Optional[PlacementGroup]:
def check_placement_group_index(placement_group: PlacementGroup,
bundle_index: int):
bundle_index: int) -> None:
assert placement_group is not None
if placement_group.id.is_nil():
if bundle_index != -1: