diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py
index fcfcb2ef9..aaa03c8e1 100644
--- a/python/ray/_private/services.py
+++ b/python/ray/_private/services.py
@@ -1868,7 +1868,8 @@ def start_ray_client_server(redis_address,
                             stdout_file=None,
                             stderr_file=None,
                             redis_password=None,
-                            fate_share=None):
+                            fate_share=None,
+                            server_type="proxy"):
     """Run the server process of the Ray client.

     Args:
@@ -1878,6 +1879,7 @@ def start_ray_client_server(redis_address,
         stderr_file: A file handle opened for writing to redirect stderr
             to. If no redirection should happen, then this should be None.
         redis_password (str): The password of the redis server.
+        server_type (str): The server type ("proxy", "legacy", or "specific-server").

     Returns:
         ProcessInfo for the process that was started.
@@ -1885,7 +1887,7 @@ def start_ray_client_server(redis_address,
     command = [
         sys.executable, "-m", "ray.util.client.server",
         "--redis-address=" + str(redis_address),
-        "--port=" + str(ray_client_server_port)
+        "--port=" + str(ray_client_server_port), "--mode=" + server_type
     ]
     if redis_password:
         command.append("--redis-password=" + redis_password)
diff --git a/python/ray/serve/tests/test_ray_client.py b/python/ray/serve/tests/test_ray_client.py
index f17cd7006..684519205 100644
--- a/python/ray/serve/tests/test_ray_client.py
+++ b/python/ray/serve/tests/test_ray_client.py
@@ -1,5 +1,6 @@
 import random
 import subprocess
+import sys

 import pytest
 import requests
@@ -26,6 +27,7 @@ def ray_client_instance():
     subprocess.check_output(["ray", "stop", "--force"])


+@pytest.mark.skipif(sys.platform != "linux", reason="Buggy on MacOS + Windows")
 def test_ray_client(ray_client_instance):
     ray.util.connect(ray_client_instance)
diff --git a/python/ray/tests/test_client_terminate.py b/python/ray/tests/test_client_terminate.py
index 193bdf6bb..a73b6eacc 100644
--- a/python/ray/tests/test_client_terminate.py
+++ b/python/ray/tests/test_client_terminate.py
@@ -113,7 +113,7 @@ def test_kill_cancel_metadata(ray_start_regular):
         pass

     def mock_terminate(term, metadata):
-        raise MetadataIsCorrectlyPassedException(metadata[0][0])
+        raise MetadataIsCorrectlyPassedException(metadata[1][0])

     # Mock stub's Terminate method to raise an exception.
     ray.api.worker.server.Terminate = mock_terminate
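
The index change above (metadata[0][0] to metadata[1][0]) reflects a behavior
change in this diff: the client now prepends the ("client_id", ...) pair to all
request metadata (see the worker.py hunk below), so user-supplied entries shift
by one. A minimal standalone sketch with placeholder values:

    metadata = [("client_id", "client-uuid")] + [("password", "hunter2")]
    assert metadata[0][0] == "client_id"
    assert metadata[1][0] == "password"  # user metadata now starts at index 1
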
diff --git a/python/ray/tests/test_runtime_env.py b/python/ray/tests/test_runtime_env.py
index 6dbd9249c..6597a368e 100644
--- a/python/ray/tests/test_runtime_env.py
+++ b/python/ray/tests/test_runtime_env.py
@@ -523,11 +523,11 @@ sleep(600)
     # waiting it to be up
     sleep(5)
     runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
-    # Execute the second one which should trigger an error
+    # Execute the second one, which should now work with per-client servers.
    execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
     script = driver_script.format(**locals())
     out = run_string_as_driver(script, env)
-    assert out.strip().split()[-1] == "ERROR"
+    assert out.strip().split()[-1] == "1000"
     proc.kill()
     proc.wait()
@@ -573,14 +573,14 @@ sleep(600)
     sleep(5)
     runtime_env = f""" {{ "working_dir": test_module.__path__[0] }}"""  # noqa: F541
-    # Execute the following cmd in the second one which should
-    # fail
+    # Execute the following cmd in the second one and ensure that
+    # it is able to run.
     execute_statement = "print('OK')"
     script = driver_script.format(**locals())
     out = run_string_as_driver(script, env)
     proc.kill()
     proc.wait()
-    assert out.strip().split()[-1] == "ERROR"
+    assert out.strip().split()[-1] == "OK"


 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py
index b4f5fe9a2..1329acacd 100644
--- a/python/ray/util/client/__init__.py
+++ b/python/ray/util/client/__init__.py
@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)

 # This version string is incremented to indicate breaking changes in the
 # protocol that require upgrading the client version.
-CURRENT_PROTOCOL_VERSION = "2021-04-19"
+CURRENT_PROTOCOL_VERSION = "2021-05-18"


 class RayAPIStub:
diff --git a/python/ray/util/client/common.py b/python/ray/util/client/common.py
index 145f45ebb..ac42c2040 100644
--- a/python/ray/util/client/common.py
+++ b/python/ray/util/client/common.py
@@ -1,9 +1,12 @@
 import ray.core.generated.ray_client_pb2 as ray_client_pb2
+import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
 from ray.util.client import ray
 from ray.util.client.options import validate_options

 import asyncio
 import concurrent.futures
+from dataclasses import dataclass
+import grpc
 import os
 import uuid
 import inspect
@@ -421,3 +424,17 @@ def remote_decorator(options: Optional[Dict[str, Any]]):
             "either a function or to a class.")

     return decorator
+
+
+@dataclass
+class ClientServerHandle:
+    """Holds the handles to the registered gRPC servicers and their server."""
+    task_servicer: ray_client_pb2_grpc.RayletDriverServicer
+    data_servicer: ray_client_pb2_grpc.RayletDataStreamerServicer
+    logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
+    grpc_server: grpc.Server
+
+    # Add a hook for all the cases that previously
+    # expected simply a gRPC server
+    def __getattr__(self, attr):
+        return getattr(self.grpc_server, attr)
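
The ClientServerHandle dataclass above (moved here from server.py, see that
file's hunks below) relies on __getattr__ delegation so call sites that used to
hold a bare grpc.Server keep working. A minimal sketch of the pattern with a
stand-in server object (names here are illustrative, not from this diff):

    from dataclasses import dataclass

    @dataclass
    class Handle:
        grpc_server: object  # stand-in for grpc.Server

        # Only consulted when normal attribute lookup fails, so the
        # dataclass fields themselves still resolve normally.
        def __getattr__(self, attr):
            return getattr(self.grpc_server, attr)

    # Handle(grpc_server=server).stop(0) forwards to server.stop(0).
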
diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py
index 14192a4f8..898bdc11a 100644
--- a/python/ray/util/client/dataclient.py
+++ b/python/ray/util/client/dataclient.py
@@ -63,7 +63,7 @@ class DataClient:
             stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
             resp_stream = stub.Datapath(
                 iter(self.request_queue.get, None),
-                metadata=[("client_id", self._client_id)] + self._metadata,
+                metadata=self._metadata,
                 wait_for_ready=True)
             try:
                 for response in resp_stream:
diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py
index 550704986..c047fd2ef 100644
--- a/python/ray/util/client/server/dataservicer.py
+++ b/python/ray/util/client/server/dataservicer.py
@@ -49,7 +49,7 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):

     def Datapath(self, request_iterator, context):
         metadata = {k: v for k, v in context.invocation_metadata()}
-        client_id = metadata["client_id"]
+        client_id = metadata.get("client_id") or ""
         if client_id == "":
             logger.error("Client connecting with no client_id")
             return
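
The metadata.get("client_id") or "" pattern above (repeated by the new
proxier's _get_client_id_from_context below) normalizes both a missing key and
an empty value to "", which the subsequent check relies on. A quick
illustration:

    metadata = {}  # no client_id supplied
    client_id = metadata.get("client_id") or ""
    assert client_id == ""  # .get() alone would have returned None
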
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
new file mode 100644
index 000000000..ac6215446
--- /dev/null
+++ b/python/ray/util/client/server/proxier.py
@@ -0,0 +1,335 @@
+from concurrent import futures
+from dataclasses import dataclass
+import grpc
+import logging
+import json
+from queue import Queue
+import socket
+from threading import Thread, Lock
+import time
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+
+from ray.job_config import JobConfig
+import ray.core.generated.ray_client_pb2 as ray_client_pb2
+import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
+from ray.util.client.common import (ClientServerHandle,
+                                    CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS)
+from ray._private.services import ProcessInfo, start_ray_client_server
+from ray._private.utils import detect_fate_sharing_support
+
+logger = logging.getLogger(__name__)
+
+CHECK_PROCESS_INTERVAL_S = 30
+
+MIN_SPECIFIC_SERVER_PORT = 23000
+MAX_SPECIFIC_SERVER_PORT = 24000
+
+CHECK_CHANNEL_TIMEOUT_S = 5
+
+
+def _get_client_id_from_context(context: Any) -> str:
+    """
+    Get `client_id` from gRPC metadata. If the `client_id` is not present,
+    this function logs an error and sets the status_code.
+    """
+    metadata = {k: v for k, v in context.invocation_metadata()}
+    client_id = metadata.get("client_id") or ""
+    if client_id == "":
+        logger.error("Client connecting with no client_id")
+        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+    return client_id
+
+
+@dataclass
+class SpecificServer:
+    port: int
+    process_handle: ProcessInfo
+    channel: "grpc._channel.Channel"
+
+
+class ProxyManager():
+    def __init__(self, redis_address):
+        self.servers: Dict[str, SpecificServer] = dict()
+        self.server_lock = Lock()
+        self.redis_address = redis_address
+        self._free_ports: List[int] = list(
+            range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT))
+
+        self._check_thread = Thread(target=self._check_processes, daemon=True)
+        self._check_thread.start()
+
+        self.fate_share = bool(detect_fate_sharing_support())
+
+    def _get_unused_port(self) -> int:
+        """
+        Search for a port in _free_ports that is unused.
+        """
+        with self.server_lock:
+            num_ports = len(self._free_ports)
+            for _ in range(num_ports):
+                port = self._free_ports.pop(0)
+                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                try:
+                    s.bind(("", port))
+                except OSError:
+                    self._free_ports.append(port)
+                    continue
+                finally:
+                    s.close()
+                return port
+        raise RuntimeError("Unable to find an unused port.")
+
+    def start_specific_server(self, client_id) -> None:
+        """
+        Start up a RayClient Server for an incoming client to
+        communicate with.
+        """
+        port = self._get_unused_port()
+        specific_server = SpecificServer(
+            port=port,
+            process_handle=start_ray_client_server(
+                self.redis_address,
+                port,
+                fate_share=self.fate_share,
+                server_type="specific-server"),
+            channel=grpc.insecure_channel(
+                f"localhost:{port}", options=GRPC_OPTIONS))
+        with self.server_lock:
+            self.servers[client_id] = specific_server
+
+    def get_channel(self,
+                    client_id: str) -> Optional["grpc._channel.Channel"]:
+        """
+        Find the gRPC Channel for the given client_id.
+        """
+        client = None
+        with self.server_lock:
+            client = self.servers.get(client_id)
+            if client is None:
+                logger.error(f"Unable to find channel for client: {client_id}")
+                return None
+        try:
+            grpc.channel_ready_future(
+                client.channel).result(timeout=CHECK_CHANNEL_TIMEOUT_S)
+            return client.channel
+        except grpc.FutureTimeoutError:
+            return None
+
+    def _check_processes(self):
+        """
+        Keeps the internal servers dictionary up-to-date with running servers.
+        """
+        while True:
+            with self.server_lock:
+                for client_id, specific_server in list(self.servers.items()):
+                    poll_result = specific_server.process_handle.process.poll()
+                    if poll_result is not None:
+                        del self.servers[client_id]
+                        # Port is available to use again.
+                        self._free_ports.append(specific_server.port)
+
+            time.sleep(CHECK_PROCESS_INTERVAL_S)
+
+
+class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
+    def __init__(self, ray_connect_handler: Callable,
+                 proxy_manager: ProxyManager):
+        self.proxy_manager = proxy_manager
+        self.ray_connect_handler = ray_connect_handler
+
+    def _call_inner_function(self, request, context,
+                             method: str) -> Optional[Any]:
+        client_id = _get_client_id_from_context(context)
+        chan = self.proxy_manager.get_channel(client_id)
+        if not chan:
+            logger.error(f"Channel for Client: {client_id} not found!")
+            context.set_code(grpc.StatusCode.NOT_FOUND)
+            return None
+
+        stub = ray_client_pb2_grpc.RayletDriverStub(chan)
+        return getattr(stub, method)(
+            request, metadata=[("client_id", client_id)])
+
+    def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
+        return self._call_inner_function(request, context, "Init")
+
+    def PrepRuntimeEnv(self, request,
+                       context=None) -> ray_client_pb2.PrepRuntimeEnvResponse:
+        return self._call_inner_function(request, context, "PrepRuntimeEnv")
+
+    def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
+        return self._call_inner_function(request, context, "KVPut")
+
+    def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
+        return self._call_inner_function(request, context, "KVGet")
+
+    def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
+        return self._call_inner_function(request, context, "KVDel")
+
+    def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
+        return self._call_inner_function(request, context, "KVList")
+
+    def KVExists(self, request,
+                 context=None) -> ray_client_pb2.KVExistsResponse:
+        return self._call_inner_function(request, context, "KVExists")
+
+    def ClusterInfo(self, request,
+                    context=None) -> ray_client_pb2.ClusterInfoResponse:
+
+        # NOTE: We need to respond to the PING request here to allow the
+        # client to continue with connecting.
+        if request.type == ray_client_pb2.ClusterInfoType.PING:
+            resp = ray_client_pb2.ClusterInfoResponse(json=json.dumps({}))
+            return resp
+        return self._call_inner_function(request, context, "ClusterInfo")
+
+    def Terminate(self, req, context=None):
+        return self._call_inner_function(req, context, "Terminate")
+
+    def GetObject(self, request, context=None):
+        return self._call_inner_function(request, context, "GetObject")
+
+    def PutObject(self, request: ray_client_pb2.PutRequest,
+                  context=None) -> ray_client_pb2.PutResponse:
+        return self._call_inner_function(request, context, "PutObject")
+
+    def WaitObject(self, request,
+                   context=None) -> ray_client_pb2.WaitResponse:
+        return self._call_inner_function(request, context, "WaitObject")
+
+    def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
+        return self._call_inner_function(task, context, "Schedule")
+
+
+def forward_streaming_requests(grpc_input_generator: Iterator[Any],
+                               output_queue: "Queue") -> None:
+    """
+    Forwards streaming requests from the grpc_input_generator into the
+    output_queue.
+    """
+    try:
+        for req in grpc_input_generator:
+            output_queue.put(req)
+    except grpc.RpcError as e:
+        logger.debug("closing dataservicer reader thread "
+                     f"grpc error reading request_iterator: {e}")
+    finally:
+        # Set the sentinel value for the output_queue
+        output_queue.put(None)
+
+
+def prepare_runtime_init_req(req: ray_client_pb2.InitRequest
+                             ) -> Tuple[ray_client_pb2.InitRequest, JobConfig]:
+    """
+    Extract JobConfig and possibly mutate InitRequest before it is passed to
+    the specific RayClient Server.
+    """
+    job_config = JobConfig()
+    if req.job_config:
+        import pickle
+        job_config = pickle.loads(req.job_config)
+
+    return (req, job_config)
+
+
+class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
+    def __init__(self, proxy_manager: ProxyManager):
+        self.proxy_manager = proxy_manager
+
+    def Datapath(self, request_iterator, context):
+        client_id = _get_client_id_from_context(context)
+        if client_id == "":
+            return
+        logger.debug(f"New data connection from client {client_id}: ")
+
+        init_req = next(request_iterator)
+        init_type = init_req.WhichOneof("type")
+        assert init_type == "init", ("Received initial message of type "
+                                     f"{init_type}, not 'init'.")
+
+        modified_init_req, job_config = prepare_runtime_init_req(init_req.init)
+        init_req.init.CopyFrom(modified_init_req)
+        queue = Queue()
+        queue.put(init_req)
+
+        self.proxy_manager.start_specific_server(client_id)
+
+        channel = self.proxy_manager.get_channel(client_id)
+        if channel is None:
+            context.set_code(grpc.StatusCode.NOT_FOUND)
+            return None
+        stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
+        thread = Thread(
+            target=forward_streaming_requests,
+            args=(request_iterator, queue),
+            daemon=True)
+        thread.start()
+
+        resp_stream = stub.Datapath(
+            iter(queue.get, None), metadata=[("client_id", client_id)])
+        for resp in resp_stream:
+            yield resp
+
+
+class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
+    def __init__(self, proxy_manager: ProxyManager):
+        super().__init__()
+        self.proxy_manager = proxy_manager
+
+    def Logstream(self, request_iterator, context):
+        client_id = _get_client_id_from_context(context)
+        if client_id == "":
+            return
+        logger.debug(f"New logstream connection from client {client_id}: ")
+
+        channel = None
+        for i in range(10):
+            # TODO(ilr) Ensure LogClient starts after startup has happened.
+            # This will remove the need for retries here.
+            channel = self.proxy_manager.get_channel(client_id)
+
+            if channel is not None:
+                break
+            logger.warning(
+                f"Retrying Logstream connection. {i+1} attempts failed.")
+            time.sleep(5)
+
+        if channel is None:
+            context.set_code(grpc.StatusCode.NOT_FOUND)
+            return None
+
+        stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
+        queue = Queue()
+        thread = Thread(
+            target=forward_streaming_requests,
+            args=(request_iterator, queue),
+            daemon=True)
+        thread.start()
+
+        resp_stream = stub.Logstream(
+            iter(queue.get, None), metadata=[("client_id", client_id)])
+        for resp in resp_stream:
+            yield resp
+
+
+def serve_proxier(connection_str: str, redis_address: str):
+    server = grpc.server(
+        futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS),
+        options=GRPC_OPTIONS)
+    proxy_manager = ProxyManager(redis_address)
+    task_servicer = RayletServicerProxy(None, proxy_manager)
+    data_servicer = DataServicerProxy(proxy_manager)
+    logs_servicer = LogstreamServicerProxy(proxy_manager)
+    ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
+        task_servicer, server)
+    ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
+        data_servicer, server)
+    ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
+        logs_servicer, server)
+    server.add_insecure_port(connection_str)
+    server.start()
+    return ClientServerHandle(
+        task_servicer=task_servicer,
+        data_servicer=data_servicer,
+        logs_servicer=logs_servicer,
+        grpc_server=server,
+    )
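
Both proxied streams above are driven by iter(queue.get, None): the
two-argument form of iter() keeps calling queue.get until it returns the None
sentinel that forward_streaming_requests enqueues when the client stream ends.
A standalone illustration:

    from queue import Queue

    q = Queue()
    for item in ("a", "b", None):  # None is the shutdown sentinel
        q.put(item)
    assert list(iter(q.get, None)) == ["a", "b"]
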
+ """ + try: + for req in grpc_input_generator: + output_queue.put(req) + except grpc.RpcError as e: + logger.debug("closing dataservicer reader thread " + f"grpc error reading request_iterator: {e}") + finally: + # Set the sentinel value for the output_queue + output_queue.put(None) + + +def prepare_runtime_init_req(req: ray_client_pb2.InitRequest + ) -> Tuple[ray_client_pb2.InitRequest, JobConfig]: + """ + Extract JobConfig and possibly mutate InitRequest before it is passed to + the specific RayClient Server. + """ + job_config = JobConfig() + if req.job_config: + import pickle + job_config = pickle.loads(req.job_config) + + return (req, job_config) + + +class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer): + def __init__(self, proxy_manager: ProxyManager): + self.proxy_manager = proxy_manager + + def Datapath(self, request_iterator, context): + client_id = _get_client_id_from_context(context) + if client_id == "": + return + logger.debug(f"New data connection from client {client_id}: ") + + init_req = next(request_iterator) + init_type = init_req.WhichOneof("type") + assert init_type == "init", ("Received initial message of type " + f"{init_type}, not 'init'.") + + modified_init_req, job_config = prepare_runtime_init_req(init_req.init) + init_req.init.CopyFrom(modified_init_req) + queue = Queue() + queue.put(init_req) + + self.proxy_manager.start_specific_server(client_id) + + channel = self.proxy_manager.get_channel(client_id) + if channel is None: + context.set_code(grpc.StatusCode.NOT_FOUND) + return None + stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel) + thread = Thread( + target=forward_streaming_requests, + args=(request_iterator, queue), + daemon=True) + thread.start() + + resp_stream = stub.Datapath( + iter(queue.get, None), metadata=[("client_id", client_id)]) + for resp in resp_stream: + yield resp + + +class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer): + def __init__(self, proxy_manager: ProxyManager): + super().__init__() + self.proxy_manager = proxy_manager + + def Logstream(self, request_iterator, context): + client_id = _get_client_id_from_context(context) + if client_id == "": + return + logger.debug(f"New data connection from client {client_id}: ") + + channel = None + for i in range(10): + # TODO(ilr) Ensure LogClient starts after startup has happened. + # This will remove the need for retries here. + channel = self.proxy_manager.get_channel(client_id) + + if channel is not None: + break + logger.warning( + f"Retrying Logstream connection. 
diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 9389b95e8..04254efe7 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -77,11 +77,12 @@ class Worker:
             at least once. For infinite retries, catch the ConnectionError
             exception.
         """
-        self.metadata = metadata if metadata else []
+        self._client_id = make_client_id()
+        self.metadata = [("client_id", self._client_id)] + (metadata if
+                                                            metadata else [])
         self.channel = None
         self.server = None
         self._conn_state = grpc.ChannelConnectivity.IDLE
-        self._client_id = make_client_id()
         self._converted: Dict[str, ClientStub] = {}

         if secure:
@@ -439,8 +440,7 @@ class Worker:
         """
         if self.server is not None:
             logger.debug("Pinging server.")
-            result = self.get_cluster_info(
-                ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
+            result = self.get_cluster_info(ray_client_pb2.ClusterInfoType.PING)
             return result is not None
         return False
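
With the worker change above, every RPC (including the initial ping, which now
uses ClusterInfoType.PING instead of IS_INITIALIZED) carries the client_id the
proxy uses for routing. A hypothetical smoke test against a proxy listening on
server.py's default port:

    import ray

    ray.util.connect("127.0.0.1:50051")  # connect() pings the proxy first
    assert ray.util.client.ray.is_connected()
    ray.util.disconnect()
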
default="0.0.0.0", help="Host IP to bind to") parser.add_argument( "-p", "--port", type=int, default=50051, help="Port to bind to") + parser.add_argument( + "--mode", + type=str, + choices=["proxy", "legacy", "specific-server"], + default="proxy") parser.add_argument( "--redis-address", required=False, @@ -689,8 +685,13 @@ def main(): hostport = "%s:%d" % (args.host, args.port) logger.info(f"Starting Ray Client server on {hostport}") - server = serve(hostport, ray_connect_handler) + if args.mode == "proxy": + server = serve_proxier(hostport, args.redis_address) + else: + server = serve(hostport, ray_connect_handler) + try: + idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S while True: health_report = { "time": time.time(), @@ -706,6 +707,18 @@ def main(): logger.exception(e) time.sleep(1) + if args.mode == "specific-server": + if server.data_servicer.num_clients > 0: + idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S + else: + idle_checks_remaining -= 1 + if idle_checks_remaining == 0: + raise KeyboardInterrupt() + if (idle_checks_remaining % 5 == 0 and idle_checks_remaining != + TIMEOUT_FOR_SPECIFIC_SERVER_S): + logger.info( + f"{idle_checks_remaining} idle checks before shutdown." + ) except KeyboardInterrupt: server.stop(0) diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 9389b95e8..04254efe7 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -77,11 +77,12 @@ class Worker: at least once. For infinite retries, catch the ConnectionError exception. """ - self.metadata = metadata if metadata else [] + self._client_id = make_client_id() + self.metadata = [("client_id", self._client_id)] + (metadata if + metadata else []) self.channel = None self.server = None self._conn_state = grpc.ChannelConnectivity.IDLE - self._client_id = make_client_id() self._converted: Dict[str, ClientStub] = {} if secure: @@ -439,8 +440,7 @@ class Worker: """ if self.server is not None: logger.debug("Pinging server.") - result = self.get_cluster_info( - ray_client_pb2.ClusterInfoType.IS_INITIALIZED) + result = self.get_cluster_info(ray_client_pb2.ClusterInfoType.PING) return result is not None return False diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index 22ec99733..41b3f7d29 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -157,6 +157,7 @@ message ClusterInfoType { AVAILABLE_RESOURCES = 3; RUNTIME_CONTEXT = 4; TIMELINE = 5; + PING = 6; } }