diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index 0f70fd77f..6a0fd00a3 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -32,6 +32,7 @@ def test_proxy_manager_lifecycle(shutdown_only): pm._free_ports = [45000, 45001] client = "client1" + pm.create_specific_server(client) assert pm.start_specific_server(client, JobConfig()) # Channel should be ready and corresponding to an existing server grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5) @@ -39,7 +40,7 @@ def test_proxy_manager_lifecycle(shutdown_only): proc = pm._get_server_for_client(client) assert proc.port == 45000 - proc.process_handle().process.wait(10) + proc.process_handle_future.result().process.wait(10) # Wait for reconcile loop time.sleep(2) @@ -63,6 +64,7 @@ def test_proxy_manager_bad_startup(shutdown_only): pm._free_ports = [46000, 46001] client = "client1" + pm.create_specific_server(client) assert not pm.start_specific_server( client, JobConfig( @@ -122,6 +124,40 @@ def test_correct_num_clients(call_ray_start): run_string_as_driver(check_we_are_second.format(num_clients=1)) +check_connection = """ +import ray +ray.client("localhost:25010").connect() +assert ray.util.client.ray.worker.log_client.log_thread.is_alive() +""" + + +@pytest.mark.skipif( + sys.platform == "win32", + reason="PSUtil does not work the same on windows.") +def test_delay_in_rewriting_environment(shutdown_only): + """ + Check that a delay in `ray_client_server_env_prep` does not break + a Client connecting. + """ + proxier.LOGSTREAM_RETRIES = 3 + proxier.LOGSTREAM_RETRY_INTERVAL_SEC = 1 + ray_instance = ray.init() + + def delay_in_rewrite(input: JobConfig): + import time + time.sleep(6) + return input + + proxier.ray_client_server_env_prep = delay_in_rewrite + + server = proxier.serve_proxier("localhost:25010", + ray_instance["redis_address"], + ray_instance["session_dir"]) + + run_string_as_driver(check_connection) + server.stop(0) + + def test_prepare_runtime_init_req_fails(): """ Check that a connection that is initiated with a non-Init request diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py index 8c36a0484..e2ae4d8cf 100644 --- a/python/ray/util/client/server/proxier.py +++ b/python/ray/util/client/server/proxier.py @@ -31,6 +31,9 @@ MAX_SPECIFIC_SERVER_PORT = 24000 CHECK_CHANNEL_TIMEOUT_S = 10 +LOGSTREAM_RETRIES = 5 +LOGSTREAM_RETRY_INTERVAL_SEC = 2 + def _get_client_id_from_context(context: Any) -> str: """ @@ -55,11 +58,34 @@ class SpecificServer: """ Wait for the server to actually start up. """ - self.process_handle_future.result(timeout=timeout) - return + res = self.process_handle_future.result(timeout=timeout) + if res is None: + # This is only set to none when server creation specifically fails. + raise RuntimeError("Server startup failed.") - def process_handle(self) -> ProcessInfo: - return self.process_handle_future.result() + def poll(self) -> Optional[int]: + """Check if the process has exited.""" + try: + proc = self.process_handle_future.result(timeout=0.1) + if proc is not None: + return proc.process.poll() + except futures.TimeoutError: + return + + def kill(self) -> None: + """Try to send a KILL signal to the process.""" + try: + proc = self.process_handle_future.result(timeout=0.1) + if proc is not None: + proc.process.kill() + except futures.TimeoutError: + # Server has not been started yet. + pass + + def set_result(self, proc: Optional[ProcessInfo]) -> None: + """Set the result of the internal future if it is currently unset.""" + if not self.process_handle_future.done(): + self.process_handle_future.set_result(proc) def _match_running_client_server(command: List[str]) -> bool: @@ -136,27 +162,37 @@ class ProxyManager(): self._session_dir = connection_tuple["session_dir"] return self._session_dir + def create_specific_server(self, client_id: str) -> SpecificServer: + """ + Create, but not start a SpecificServer for a given client. This + method must be called once per client. + """ + with self.server_lock: + assert self.servers.get(client_id) is None, ( + f"Server already created for Client: {client_id}") + port = self._get_unused_port() + server = SpecificServer( + port=port, + process_handle_future=futures.Future(), + channel=grpc.insecure_channel( + f"localhost:{port}", options=GRPC_OPTIONS)) + self.servers[client_id] = server + return server + def start_specific_server(self, client_id: str, job_config: JobConfig) -> bool: """ Start up a RayClient Server for an incoming client to communicate with. Returns whether creation was successful. """ - with self.server_lock: - port = self._get_unused_port() - handle_ready = futures.Future() - specific_server = SpecificServer( - port=port, - process_handle_future=handle_ready, - channel=grpc.insecure_channel( - f"localhost:{port}", options=GRPC_OPTIONS)) - self.servers[client_id] = specific_server + specific_server = self._get_server_for_client(client_id) + assert specific_server, f"Server has not been created for: {client_id}" serialized_runtime_env = job_config.get_serialized_runtime_env() proc = start_ray_client_server( self._get_redis_address(), - port, + specific_server.port, fate_share=self.fate_share, server_type="specific-server", serialized_runtime_env=serialized_runtime_env, @@ -181,9 +217,9 @@ class ProxyManager(): logger.debug( "Waiting for Process to reach the actual client server.") time.sleep(0.5) - handle_ready.set_result(proc) - logger.info(f"SpecificServer started on port: {port} with PID: {pid} " - f"for client: {client_id}") + specific_server.set_result(proc) + logger.info(f"SpecificServer started on port: {specific_server.port} " + f"with PID: {pid} for client: {client_id}") return proc.process.poll() is None def _get_server_for_client(self, @@ -205,6 +241,7 @@ class ProxyManager(): server = self._get_server_for_client(client_id) if server is None: return None + # Wait for the SpecificServer to become ready. server.wait_ready() try: grpc.channel_ready_future( @@ -221,9 +258,7 @@ class ProxyManager(): while True: with self.server_lock: for client_id, specific_server in list(self.servers.items()): - poll_result = specific_server.process_handle( - ).process.poll() - if poll_result is not None: + if specific_server.poll() is not None: del self.servers[client_id] # Port is available to use again. self._free_ports.append(specific_server.port) @@ -236,12 +271,7 @@ class ProxyManager(): for platforms where fate sharing is not supported. """ for server in self.servers.values(): - try: - server.wait_ready(0.1) - server.process_handle().process.kill() - except TimeoutError: - # Server has not been started yet. - pass + server.kill() class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer): @@ -367,31 +397,37 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer): if client_id == "": return - logger.info(f"New data connection from client {client_id}: ") - modified_init_req, job_config = prepare_runtime_init_req( - request_iterator) - - if not self.proxy_manager.start_specific_server(client_id, job_config): - logger.error(f"Server startup failed for client: {client_id}, " - f"using JobConfig: {job_config}!") - context.set_code(grpc.StatusCode.ABORTED) - return None - - channel = self.proxy_manager.get_channel(client_id) - if channel is None: - logger.error(f"Channel not found for {client_id}") - context.set_code(grpc.StatusCode.NOT_FOUND) - return None - stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel) + # Create Placeholder *before* reading the first request. + server = self.proxy_manager.create_specific_server(client_id) try: with self.clients_lock: self.num_clients += 1 + + logger.info(f"New data connection from client {client_id}: ") + modified_init_req, job_config = prepare_runtime_init_req( + request_iterator) + + if not self.proxy_manager.start_specific_server( + client_id, job_config): + logger.error(f"Server startup failed for client: {client_id}, " + f"using JobConfig: {job_config}!") + context.set_code(grpc.StatusCode.ABORTED) + return None + + channel = self.proxy_manager.get_channel(client_id) + if channel is None: + logger.error(f"Channel not found for {client_id}") + context.set_code(grpc.StatusCode.NOT_FOUND) + return None + stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel) + new_iter = chain([modified_init_req], request_iterator) resp_stream = stub.Datapath( new_iter, metadata=[("client_id", client_id)]) for resp in resp_stream: yield self.modify_connection_info_resp(resp) finally: + server.set_result(None) with self.clients_lock: logger.debug(f"Client detached: {client_id}") self.num_clients -= 1 @@ -406,22 +442,22 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer): client_id = _get_client_id_from_context(context) if client_id == "": return - logger.debug(f"New data connection from client {client_id}: ") + logger.debug(f"New logstream connection from client {client_id}: ") channel = None # We need to retry a few times because the LogClient *may* connect # Before the DataClient has finished connecting. - for i in range(5): + for i in range(LOGSTREAM_RETRIES): channel = self.proxy_manager.get_channel(client_id) if channel is not None: break logger.warning( f"Retrying Logstream connection. {i+1} attempts failed.") - time.sleep(2) + time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC) if channel is None: - context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_code(grpc.StatusCode.UNAVAILABLE) return None stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel) @@ -432,11 +468,13 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer): yield resp -def serve_proxier(connection_str: str, redis_address: str): +def serve_proxier(connection_str: str, + redis_address: str, + session_dir: Optional[str] = None): server = grpc.server( futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS), options=GRPC_OPTIONS) - proxy_manager = ProxyManager(redis_address) + proxy_manager = ProxyManager(redis_address, session_dir) task_servicer = RayletServicerProxy(None, proxy_manager) data_servicer = DataServicerProxy(proxy_manager) logs_servicer = LogstreamServicerProxy(proxy_manager) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 04254efe7..9a9fd2e8a 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -341,8 +341,8 @@ class Worker: self.reference_count[id] += 1 def close(self): - self.log_client.close() self.data_client.close() + self.log_client.close() if self.channel: self.channel.close() self.channel = None