[client] One Driver per RayClient Server (#15875)

Ian Rodney 2021-05-19 09:03:09 -07:00 committed by GitHub
parent c636bc3065
commit 97d1414f23
12 changed files with 402 additions and 32 deletions


@ -1868,7 +1868,8 @@ def start_ray_client_server(redis_address,
stdout_file=None,
stderr_file=None,
redis_password=None,
- fate_share=None):
+ fate_share=None,
+ server_type="proxy"):
"""Run the server process of the Ray client.
Args:
@ -1878,6 +1879,7 @@ def start_ray_client_server(redis_address,
stderr_file: A file handle opened for writing to redirect stderr to. If
no redirection should happen, then this should be None.
redis_password (str): The password of the redis server.
server_type (str): The type of server to start ("proxy", "legacy", or "specific-server").
Returns:
ProcessInfo for the process that was started.
@ -1885,7 +1887,7 @@ def start_ray_client_server(redis_address,
command = [
sys.executable, "-m", "ray.util.client.server",
"--redis-address=" + str(redis_address),
"--port=" + str(ray_client_server_port)
"--port=" + str(ray_client_server_port), "--mode=" + server_type
]
if redis_password:
command.append("--redis-password=" + redis_password)
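
The net effect of the new parameter is a --mode flag forwarded to the server module. A hedged sketch of the command the helper now assembles for a specific server (address, port, and password values are illustrative):

    import sys

    command = [
        sys.executable, "-m", "ray.util.client.server",
        "--redis-address=127.0.0.1:6379",
        "--port=23001", "--mode=specific-server",
    ]
    # --redis-password is appended only when one is configured:
    command.append("--redis-password=5241590000000000")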


@ -1,5 +1,6 @@
import random
import subprocess
import sys
import pytest
import requests
@ -26,6 +27,7 @@ def ray_client_instance():
subprocess.check_output(["ray", "stop", "--force"])
@pytest.mark.skipif(sys.platform != "linux", reason="Buggy on MacOS + Windows")
def test_ray_client(ray_client_instance):
ray.util.connect(ray_client_instance)


@ -113,7 +113,7 @@ def test_kill_cancel_metadata(ray_start_regular):
pass
def mock_terminate(term, metadata):
- raise MetadataIsCorrectlyPassedException(metadata[0][0])
+ raise MetadataIsCorrectlyPassedException(metadata[1][0])
# Mock stub's Terminate method to raise an exception.
ray.api.worker.server.Terminate = mock_terminate


@ -523,11 +523,11 @@ sleep(600)
# waiting it to be up
sleep(5)
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
- # Execute the second one which should trigger an error
+ # Execute the second one, which should now work because each Ray
+ # Client connection gets its own server.
execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
- assert out.strip().split()[-1] == "ERROR"
+ assert out.strip().split()[-1] == "1000"
proc.kill()
proc.wait()
@ -573,14 +573,14 @@ sleep(600)
sleep(5)
runtime_env = f"""
{{ "working_dir": test_module.__path__[0] }}""" # noqa: F541
- # Execute the following cmd in the second one which should
- # fail
+ # Execute the following cmd in the second one and ensure that
+ # it is able to run.
execute_statement = "print('OK')"
script = driver_script.format(**locals())
out = run_string_as_driver(script, env)
proc.kill()
proc.wait()
- assert out.strip().split()[-1] == "ERROR"
+ assert out.strip().split()[-1] == "OK"
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")


@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
- CURRENT_PROTOCOL_VERSION = "2021-04-19"
+ CURRENT_PROTOCOL_VERSION = "2021-05-18"
class RayAPIStub:


@ -1,9 +1,12 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client import ray
from ray.util.client.options import validate_options
import asyncio
import concurrent.futures
from dataclasses import dataclass
import grpc
import os
import uuid
import inspect
@ -421,3 +424,17 @@ def remote_decorator(options: Optional[Dict[str, Any]]):
"either a function or to a class.")
return decorator
@dataclass
class ClientServerHandle:
"""Holds the handles to the registered gRPC servicers and their server."""
task_servicer: ray_client_pb2_grpc.RayletDriverServicer
data_servicer: ray_client_pb2_grpc.RayletDataStreamerServicer
logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
grpc_server: grpc.Server
# Add a hook for all the cases that previously
# expected simply a gRPC server
def __getattr__(self, attr):
return getattr(self.grpc_server, attr)
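
Because __getattr__ falls through to the wrapped server, call sites that previously held a bare grpc.Server keep working unchanged. A minimal sketch, assuming the servicers and server are constructed as in the serve_proxier function later in this diff:

    handle = ClientServerHandle(
        task_servicer=task_servicer,
        data_servicer=data_servicer,
        logs_servicer=logs_servicer,
        grpc_server=server,
    )
    handle.add_insecure_port("0.0.0.0:10001")  # delegated to grpc_server
    handle.start()                             # likewise delegated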


@ -63,7 +63,7 @@ class DataClient:
stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
resp_stream = stub.Datapath(
iter(self.request_queue.get, None),
- metadata=[("client_id", self._client_id)] + self._metadata,
+ metadata=self._metadata,
wait_for_ready=True)
try:
for response in resp_stream:


@ -49,7 +49,7 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
def Datapath(self, request_iterator, context):
metadata = {k: v for k, v in context.invocation_metadata()}
- client_id = metadata["client_id"]
+ client_id = metadata.get("client_id") or ""
if client_id == "":
logger.error("Client connecting with no client_id")
return
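
(Using .get means a request missing the client_id header now falls into the logged-error path instead of raising a KeyError inside the handler.)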


@ -0,0 +1,335 @@
from concurrent import futures
from dataclasses import dataclass
import grpc
import logging
import json
from queue import Queue
import socket
from threading import Thread, Lock
import time
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from ray.job_config import JobConfig
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client.common import (ClientServerHandle,
CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS)
from ray._private.services import ProcessInfo, start_ray_client_server
from ray._private.utils import detect_fate_sharing_support
logger = logging.getLogger(__name__)
CHECK_PROCESS_INTERVAL_S = 30
MIN_SPECIFIC_SERVER_PORT = 23000
MAX_SPECIFIC_SERVER_PORT = 24000
CHECK_CHANNEL_TIMEOUT_S = 5
def _get_client_id_from_context(context: Any) -> str:
"""
Get `client_id` from gRPC metadata. If the `client_id` is not present,
this function logs an error and sets the status_code.
"""
metadata = {k: v for k, v in context.invocation_metadata()}
client_id = metadata.get("client_id") or ""
if client_id == "":
logger.error("Client connecting with no client_id")
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
return client_id
@dataclass
class SpecificServer:
port: int
process_handle: ProcessInfo
channel: "grpc._channel.Channel"
class ProxyManager:
def __init__(self, redis_address):
self.servers: Dict[str, SpecificServer] = dict()
self.server_lock = Lock()
self.redis_address = redis_address
self._free_ports: List[int] = list(
range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT))
self._check_thread = Thread(target=self._check_processes, daemon=True)
self._check_thread.start()
self.fate_share = bool(detect_fate_sharing_support())
def _get_unused_port(self) -> int:
"""
Search for a port in _free_ports that is unused.
"""
with self.server_lock:
num_ports = len(self._free_ports)
for _ in range(num_ports):
port = self._free_ports.pop(0)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.bind(("", port))
except OSError:
self._free_ports.append(port)
continue
finally:
s.close()
return port
raise RuntimeError("Unable to find an unused port.")
def start_specific_server(self, client_id) -> None:
"""
Start up a RayClient Server for an incoming client to
communicate with.
"""
port = self._get_unused_port()
specific_server = SpecificServer(
port=port,
process_handle=start_ray_client_server(
self.redis_address,
port,
fate_share=self.fate_share,
server_type="specific-server"),
channel=grpc.insecure_channel(
f"localhost:{port}", options=GRPC_OPTIONS))
with self.server_lock:
self.servers[client_id] = specific_server
def get_channel(self, client_id: str) -> Optional["grpc._channel.Channel"]:
"""
Find the gRPC Channel for the given client_id
"""
client = None
with self.server_lock:
client = self.servers.get(client_id)
if client is None:
logger.error(f"Unable to find channel for client: {client_id}")
return None
try:
grpc.channel_ready_future(
client.channel).result(timeout=CHECK_CHANNEL_TIMEOUT_S)
return client.channel
except grpc.FutureTimeoutError:
return None
def _check_processes(self):
"""
Keeps the internal servers dictionary up-to-date with running servers.
"""
while True:
with self.server_lock:
for client_id, specific_server in list(self.servers.items()):
poll_result = specific_server.process_handle.process.poll()
if poll_result is not None:
del self.servers[client_id]
# Port is available to use again.
self._free_ports.append(specific_server.port)
time.sleep(CHECK_PROCESS_INTERVAL_S)
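# A minimal sketch of the check this reaper loop relies on:
# Popen.poll() returns None while the child is still running and its
# exit code once it has terminated, e.g.:
#   import subprocess, sys
#   proc = subprocess.Popen([sys.executable, "-c", "pass"])
#   proc.wait()
#   assert proc.poll() is not None  # exited; poll() returns 0 here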
class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self, ray_connect_handler: Callable,
proxy_manager: ProxyManager):
self.proxy_manager = proxy_manager
self.ray_connect_handler = ray_connect_handler
def _call_inner_function(
self, request, context,
method: str) -> Optional[ray_client_pb2_grpc.RayletDriverStub]:
client_id = _get_client_id_from_context(context)
chan = self.proxy_manager.get_channel(client_id)
if not chan:
logger.error(f"Channel for Client: {client_id} not found!")
context.set_code(grpc.StatusCode.NOT_FOUND)
return None
stub = ray_client_pb2_grpc.RayletDriverStub(chan)
return getattr(stub, method)(
request, metadata=[("client_id", client_id)])
def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
return self._call_inner_function(request, context, "Init")
def PrepRuntimeEnv(self, request,
context=None) -> ray_client_pb2.PrepRuntimeEnvResponse:
return self._call_inner_function(request, context, "PrepRuntimeEnv")
def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
return self._call_inner_function(request, context, "KVPut")
def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
return self._call_inner_function(request, context, "KVGet")
def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
return self._call_inner_function(request, context, "KVDel")
def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
return self._call_inner_function(request, context, "KVList")
def KVExists(self, request,
context=None) -> ray_client_pb2.KVExistsResponse:
return self._call_inner_function(request, context, "KVExists")
def ClusterInfo(self, request,
context=None) -> ray_client_pb2.ClusterInfoResponse:
# NOTE: We need to respond to the PING request here to allow the client
# to continue with connecting.
if request.type == ray_client_pb2.ClusterInfoType.PING:
resp = ray_client_pb2.ClusterInfoResponse(json=json.dumps({}))
return resp
return self._call_inner_function(request, context, "ClusterInfo")
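# (Ordering note: the client pings via ClusterInfo while connecting,
# before the data path has launched its specific server, so the proxy
# must answer the PING locally rather than forward it.)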
def Terminate(self, req, context=None):
return self._call_inner_function(req, context, "Terminate")
def GetObject(self, request, context=None):
return self._call_inner_function(request, context, "GetObject")
def PutObject(self, request: ray_client_pb2.PutRequest,
context=None) -> ray_client_pb2.PutResponse:
return self._call_inner_function(request, context, "PutObject")
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
return self._call_inner_function(request, context, "WaitObject")
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
return self._call_inner_function(task, context, "Schedule")
def forward_streaming_requests(grpc_input_generator: Iterator[Any],
output_queue: "Queue") -> None:
"""
Forwards streaming requests from the grpc_input_generator into the
output_queue.
"""
try:
for req in grpc_input_generator:
output_queue.put(req)
except grpc.RpcError as e:
logger.debug("closing dataservicer reader thread "
f"grpc error reading request_iterator: {e}")
finally:
# Set the sentinel value for the output_queue
output_queue.put(None)
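# The None sentinel pairs with the two-argument iter() used by the
# consumers below: iter(queue.get, None) keeps calling queue.get()
# until it returns None, which ends the forwarded stream, e.g.:
#   from queue import Queue
#   q = Queue()
#   q.put("req-1")
#   q.put(None)  # sentinel
#   assert list(iter(q.get, None)) == ["req-1"]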
def prepare_runtime_init_req(req: ray_client_pb2.InitRequest
) -> Tuple[ray_client_pb2.InitRequest, JobConfig]:
"""
Extract JobConfig and possibly mutate InitRequest before it is passed to
the specific RayClient Server.
"""
job_config = JobConfig()
if req.job_config:
import pickle
job_config = pickle.loads(req.job_config)
return (req, job_config)
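# A hedged sketch of the round trip this helper expects; the
# job_config field carries pickled bytes:
#   import pickle
#   init_req = ray_client_pb2.InitRequest(
#       job_config=pickle.dumps(JobConfig()))
#   same_req, job_config = prepare_runtime_init_req(init_req)
#   assert isinstance(job_config, JobConfig)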
class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, proxy_manager: ProxyManager):
self.proxy_manager = proxy_manager
def Datapath(self, request_iterator, context):
client_id = _get_client_id_from_context(context)
if client_id == "":
return
logger.debug(f"New data connection from client {client_id}: ")
init_req = next(request_iterator)
init_type = init_req.WhichOneof("type")
assert init_type == "init", ("Received initial message of type "
f"{init_type}, not 'init'.")
modified_init_req, job_config = prepare_runtime_init_req(init_req.init)
init_req.init.CopyFrom(modified_init_req)
queue = Queue()
queue.put(init_req)
self.proxy_manager.start_specific_server(client_id)
channel = self.proxy_manager.get_channel(client_id)
if channel is None:
context.set_code(grpc.StatusCode.NOT_FOUND)
return None
stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
thread = Thread(
target=forward_streaming_requests,
args=(request_iterator, queue),
daemon=True)
thread.start()
resp_stream = stub.Datapath(
iter(queue.get, None), metadata=[("client_id", client_id)])
for resp in resp_stream:
yield resp
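# The data path is thus spliced in both directions: the daemon thread
# pumps client requests into the queue feeding the specific server's
# Datapath stub, while this generator re-yields the specific server's
# responses back to the client.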
class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
def __init__(self, proxy_manager: ProxyManager):
super().__init__()
self.proxy_manager = proxy_manager
def Logstream(self, request_iterator, context):
client_id = _get_client_id_from_context(context)
if client_id == "":
return
logger.debug(f"New data connection from client {client_id}: ")
channel = None
for i in range(10):
# TODO(ilr) Ensure LogClient starts after startup has happened.
# This will remove the need for retries here.
channel = self.proxy_manager.get_channel(client_id)
if channel is not None:
break
logger.warning(
f"Retrying Logstream connection. {i+1} attempts failed.")
time.sleep(5)
if channel is None:
context.set_code(grpc.StatusCode.NOT_FOUND)
return None
stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
queue = Queue()
thread = Thread(
target=forward_streaming_requests,
args=(request_iterator, queue),
daemon=True)
thread.start()
resp_stream = stub.Logstream(
iter(queue.get, None), metadata=[("client_id", client_id)])
for resp in resp_stream:
yield resp
def serve_proxier(connection_str: str, redis_address: str):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS),
options=GRPC_OPTIONS)
proxy_manager = ProxyManager(redis_address)
task_servicer = RayletServicerProxy(None, proxy_manager)
data_servicer = DataServicerProxy(proxy_manager)
logs_servicer = LogstreamServicerProxy(proxy_manager)
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
data_servicer, server)
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
logs_servicer, server)
server.add_insecure_port(connection_str)
server.start()
return ClientServerHandle(
task_servicer=task_servicer,
data_servicer=data_servicer,
logs_servicer=logs_servicer,
grpc_server=server,
)
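
A minimal usage sketch (address values are illustrative; in practice the module's main(), shown below, calls this with the parsed --host/--port and --redis-address arguments):

    handle = serve_proxier("0.0.0.0:10001", "127.0.0.1:6379")
    handle.wait_for_termination()  # falls through to the grpc.Server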


@ -3,7 +3,6 @@ from concurrent import futures
import grpc
import base64
from collections import defaultdict
- from dataclasses import dataclass
import os
import queue
@ -22,7 +21,9 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
import inspect
import json
- from ray.util.client.common import (GRPC_OPTIONS, CLIENT_SERVER_MAX_THREADS)
+ from ray.util.client.common import (ClientServerHandle, GRPC_OPTIONS,
+ CLIENT_SERVER_MAX_THREADS)
+ from ray.util.client.server.proxier import serve_proxier
from ray.util.client.server.server_pickler import convert_from_arg
from ray.util.client.server.server_pickler import dumps_from_server
from ray.util.client.server.server_pickler import loads_from_client
@ -34,6 +35,8 @@ from ray._private.client_mode_hook import disable_client_hook
logger = logging.getLogger(__name__)
TIMEOUT_FOR_SPECIFIC_SERVER_S = 30
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self, ray_connect_handler: Callable):
@ -170,6 +173,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
data = ray.is_initialized()
elif request.type == ray_client_pb2.ClusterInfoType.TIMELINE:
data = ray.timeline()
elif request.type == ray_client_pb2.ClusterInfoType.PING:
data = {}
else:
raise TypeError("Unsupported cluster info type")
return json.dumps(data)
@ -560,20 +565,6 @@ def decode_options(
return opts
- @dataclass
- class ClientServerHandle:
- """Holds the handles to the registered gRPC servicers and their server."""
- task_servicer: RayletServicer
- data_servicer: DataServicer
- logs_servicer: LogstreamServicer
- grpc_server: grpc.Server
- # Add a hook for all the cases that previously
- # expected simply a gRPC server
- def __getattr__(self, attr):
- return getattr(self.grpc_server, attr)
def serve(connection_str, ray_connect_handler=None):
def default_connect_handler(job_config: JobConfig = None):
with disable_client_hook():
@ -666,6 +657,11 @@ def main():
"--host", type=str, default="0.0.0.0", help="Host IP to bind to")
parser.add_argument(
"-p", "--port", type=int, default=50051, help="Port to bind to")
parser.add_argument(
"--mode",
type=str,
choices=["proxy", "legacy", "specific-server"],
default="proxy")
parser.add_argument(
"--redis-address",
required=False,
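
For reference, example invocations of the three modes (addresses and ports are illustrative; the proxy launches specific servers itself):

    python -m ray.util.client.server --redis-address=127.0.0.1:6379 --port=10001 --mode=proxy
    python -m ray.util.client.server --redis-address=127.0.0.1:6379 --port=23001 --mode=specific-server
    python -m ray.util.client.server --redis-address=127.0.0.1:6379 --port=50051 --mode=legacy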
@ -689,8 +685,13 @@ def main():
hostport = "%s:%d" % (args.host, args.port)
logger.info(f"Starting Ray Client server on {hostport}")
- server = serve(hostport, ray_connect_handler)
+ if args.mode == "proxy":
+ server = serve_proxier(hostport, args.redis_address)
+ else:
+ server = serve(hostport, ray_connect_handler)
try:
idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
while True:
health_report = {
"time": time.time(),
@ -706,6 +707,18 @@ def main():
logger.exception(e)
time.sleep(1)
if args.mode == "specific-server":
if server.data_servicer.num_clients > 0:
idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
else:
idle_checks_remaining -= 1
if idle_checks_remaining == 0:
raise KeyboardInterrupt()
if (idle_checks_remaining % 5 == 0 and idle_checks_remaining !=
TIMEOUT_FOR_SPECIFIC_SERVER_S):
logger.info(
f"{idle_checks_remaining} idle checks before shutdown."
)
except KeyboardInterrupt:
server.stop(0)
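
With TIMEOUT_FOR_SPECIFIC_SERVER_S = 30 and one check per one-second loop iteration, a specific server that has had zero connected clients for 30 consecutive checks shuts itself down, logging a countdown every fifth check; the raised KeyboardInterrupt funnels into the same server.stop(0) cleanup as a manual interrupt.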


@ -77,11 +77,12 @@ class Worker:
at least once. For infinite retries, catch the ConnectionError
exception.
"""
- self.metadata = metadata if metadata else []
+ self._client_id = make_client_id()
+ self.metadata = [("client_id", self._client_id)] + (metadata if
+ metadata else [])
self.channel = None
self.server = None
self._conn_state = grpc.ChannelConnectivity.IDLE
- self._client_id = make_client_id()
self._converted: Dict[str, ClientStub] = {}
if secure:
@ -439,8 +440,7 @@ class Worker:
"""
if self.server is not None:
logger.debug("Pinging server.")
- result = self.get_cluster_info(
- ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
+ result = self.get_cluster_info(ray_client_pb2.ClusterInfoType.PING)
return result is not None
return False
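
Because client_id is now attached to self.metadata at construction time, it travels on every RPC automatically, which is why the DataClient hunk above no longer prepends it by hand.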


@ -157,6 +157,7 @@ message ClusterInfoType {
AVAILABLE_RESOURCES = 3;
RUNTIME_CONTEXT = 4;
TIMELINE = 5;
PING = 6;
}
}