mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Client][Proxy] Track Num Clients in the proxy (#16038)
This commit is contained in:
parent
7c3874b38e
commit
3dbdd4eb46
3 changed files with 61 additions and 7 deletions
|
@ -22,6 +22,7 @@ class ClientInfo:
|
|||
ray_version: str
|
||||
ray_commit: str
|
||||
protocol_version: str
|
||||
_num_clients: int
|
||||
|
||||
|
||||
class ClientBuilder:
|
||||
|
@ -57,7 +58,8 @@ class ClientBuilder:
|
|||
python_version=client_info_dict["python_version"],
|
||||
ray_version=client_info_dict["ray_version"],
|
||||
ray_commit=client_info_dict["ray_commit"],
|
||||
protocol_version=client_info_dict["protocol_version"])
|
||||
protocol_version=client_info_dict["protocol_version"],
|
||||
_num_clients=client_info_dict["num_clients"])
|
||||
|
||||
|
||||
class _LocalClientBuilder(ClientBuilder):
|
||||
|
|
|
@ -8,6 +8,7 @@ import grpc
|
|||
import ray
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.job_config import JobConfig
|
||||
from ray.test_utils import run_string_as_driver
|
||||
import ray.util.client.server.proxier as proxier
|
||||
|
||||
|
||||
|
@ -97,6 +98,32 @@ def test_multiple_clients_use_different_drivers(call_ray_start):
|
|||
assert namespace_one != namespace_two
|
||||
|
||||
|
||||
check_we_are_second = """
|
||||
import ray
|
||||
info = ray.client('localhost:25005').connect()
|
||||
assert info._num_clients == {num_clients}
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "win32",
|
||||
reason="PSUtil does not work the same on windows.")
|
||||
@pytest.mark.parametrize(
|
||||
"call_ray_start",
|
||||
["ray start --head --ray-client-server-port 25005 --port 0"],
|
||||
indirect=True)
|
||||
def test_correct_num_clients(call_ray_start):
|
||||
"""
|
||||
Checks that the returned value of `num_clients` correctly tracks clients
|
||||
connecting and disconnecting.
|
||||
"""
|
||||
info = ray.client("localhost:25005").connect()
|
||||
assert info._num_clients == 1
|
||||
run_string_as_driver(check_we_are_second.format(num_clients=2))
|
||||
ray.util.disconnect()
|
||||
run_string_as_driver(check_we_are_second.format(num_clients=1))
|
||||
|
||||
|
||||
def test_prepare_runtime_init_req_fails():
|
||||
"""
|
||||
Check that a connection that is initiated with a non-Init request
|
||||
|
|
|
@ -8,7 +8,7 @@ import json
|
|||
import psutil
|
||||
import socket
|
||||
import sys
|
||||
from threading import Thread, RLock
|
||||
from threading import Lock, Thread, RLock
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
|
@ -314,8 +314,26 @@ def prepare_runtime_init_req(iterator: Iterator[ray_client_pb2.DataRequest]
|
|||
|
||||
class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
def __init__(self, proxy_manager: ProxyManager):
|
||||
self.num_clients = 0
|
||||
self.clients_lock = Lock()
|
||||
self.proxy_manager = proxy_manager
|
||||
|
||||
def modify_connection_info_resp(self,
|
||||
init_resp: ray_client_pb2.DataResponse
|
||||
) -> ray_client_pb2.DataResponse:
|
||||
"""
|
||||
Modify the `num_clients` returned the ConnectionInfoResponse because
|
||||
individual SpecificServers only have **one** client.
|
||||
"""
|
||||
init_type = init_resp.WhichOneof("type")
|
||||
if init_type != "connection_info":
|
||||
return init_resp
|
||||
modified_resp = ray_client_pb2.DataResponse()
|
||||
modified_resp.CopyFrom(init_resp)
|
||||
with self.clients_lock:
|
||||
modified_resp.connection_info.num_clients = self.num_clients
|
||||
return modified_resp
|
||||
|
||||
def Datapath(self, request_iterator, context):
|
||||
client_id = _get_client_id_from_context(context)
|
||||
if client_id == "":
|
||||
|
@ -337,11 +355,18 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
|||
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||
return None
|
||||
stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
|
||||
new_iter = chain([modified_init_req], request_iterator)
|
||||
resp_stream = stub.Datapath(
|
||||
new_iter, metadata=[("client_id", client_id)])
|
||||
for resp in resp_stream:
|
||||
yield resp
|
||||
try:
|
||||
with self.clients_lock:
|
||||
self.num_clients += 1
|
||||
new_iter = chain([modified_init_req], request_iterator)
|
||||
resp_stream = stub.Datapath(
|
||||
new_iter, metadata=[("client_id", client_id)])
|
||||
for resp in resp_stream:
|
||||
yield self.modify_connection_info_resp(resp)
|
||||
finally:
|
||||
with self.clients_lock:
|
||||
logger.debug(f"Client detached: {client_id}")
|
||||
self.num_clients -= 1
|
||||
|
||||
|
||||
class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
|
||||
|
|
Loading…
Add table
Reference in a new issue