diff --git a/python/ray/client_builder.py b/python/ray/client_builder.py index 6553856e2..fdcef2cda 100644 --- a/python/ray/client_builder.py +++ b/python/ray/client_builder.py @@ -22,6 +22,7 @@ class ClientInfo: ray_version: str ray_commit: str protocol_version: str + _num_clients: int class ClientBuilder: @@ -57,7 +58,8 @@ class ClientBuilder: python_version=client_info_dict["python_version"], ray_version=client_info_dict["ray_version"], ray_commit=client_info_dict["ray_commit"], - protocol_version=client_info_dict["protocol_version"]) + protocol_version=client_info_dict["protocol_version"], + _num_clients=client_info_dict["num_clients"]) class _LocalClientBuilder(ClientBuilder): diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index 0c728a375..9f9fd512e 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -8,6 +8,7 @@ import grpc import ray import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.job_config import JobConfig +from ray.test_utils import run_string_as_driver import ray.util.client.server.proxier as proxier @@ -97,6 +98,32 @@ def test_multiple_clients_use_different_drivers(call_ray_start): assert namespace_one != namespace_two +check_we_are_second = """ +import ray +info = ray.client('localhost:25005').connect() +assert info._num_clients == {num_clients} +""" + + +@pytest.mark.skipif( + sys.platform == "win32", + reason="PSUtil does not work the same on windows.") +@pytest.mark.parametrize( + "call_ray_start", + ["ray start --head --ray-client-server-port 25005 --port 0"], + indirect=True) +def test_correct_num_clients(call_ray_start): + """ + Checks that the returned value of `num_clients` correctly tracks clients + connecting and disconnecting. + """ + info = ray.client("localhost:25005").connect() + assert info._num_clients == 1 + run_string_as_driver(check_we_are_second.format(num_clients=2)) + ray.util.disconnect() + run_string_as_driver(check_we_are_second.format(num_clients=1)) + + def test_prepare_runtime_init_req_fails(): """ Check that a connection that is initiated with a non-Init request diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 95d86f4fa..f668bf7ca 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -8,7 +8,7 @@ import json import psutil import socket import sys -from threading import Thread, RLock +from threading import Lock, Thread, RLock import time from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple @@ -314,8 +314,26 @@ def prepare_runtime_init_req(iterator: Iterator[ray_client_pb2.DataRequest] class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer): def __init__(self, proxy_manager: ProxyManager): + self.num_clients = 0 + self.clients_lock = Lock() self.proxy_manager = proxy_manager + def modify_connection_info_resp(self, + init_resp: ray_client_pb2.DataResponse + ) -> ray_client_pb2.DataResponse: + """ + Modify the `num_clients` returned the ConnectionInfoResponse because + individual SpecificServers only have **one** client. + """ + init_type = init_resp.WhichOneof("type") + if init_type != "connection_info": + return init_resp + modified_resp = ray_client_pb2.DataResponse() + modified_resp.CopyFrom(init_resp) + with self.clients_lock: + modified_resp.connection_info.num_clients = self.num_clients + return modified_resp + def Datapath(self, request_iterator, context): client_id = _get_client_id_from_context(context) if client_id == "": @@ -337,11 +355,18 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer): context.set_code(grpc.StatusCode.NOT_FOUND) return None stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel) - new_iter = chain([modified_init_req], request_iterator) - resp_stream = stub.Datapath( - new_iter, metadata=[("client_id", client_id)]) - for resp in resp_stream: - yield resp + try: + with self.clients_lock: + self.num_clients += 1 + new_iter = chain([modified_init_req], request_iterator) + resp_stream = stub.Datapath( + new_iter, metadata=[("client_id", client_id)]) + for resp in resp_stream: + yield self.modify_connection_info_resp(resp) + finally: + with self.clients_lock: + logger.debug(f"Client detached: {client_id}") + self.num_clients -= 1 class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):