mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
[client] One Driver per RayClient Server (#15875)
This commit is contained in:
parent
c636bc3065
commit
97d1414f23
12 changed files with 402 additions and 32 deletions
|
@ -1868,7 +1868,8 @@ def start_ray_client_server(redis_address,
|
||||||
stdout_file=None,
|
stdout_file=None,
|
||||||
stderr_file=None,
|
stderr_file=None,
|
||||||
redis_password=None,
|
redis_password=None,
|
||||||
fate_share=None):
|
fate_share=None,
|
||||||
|
server_type="proxy"):
|
||||||
"""Run the server process of the Ray client.
|
"""Run the server process of the Ray client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1878,6 +1879,7 @@ def start_ray_client_server(redis_address,
|
||||||
stderr_file: A file handle opened for writing to redirect stderr to. If
|
stderr_file: A file handle opened for writing to redirect stderr to. If
|
||||||
no redirection should happen, then this should be None.
|
no redirection should happen, then this should be None.
|
||||||
redis_password (str): The password of the redis server.
|
redis_password (str): The password of the redis server.
|
||||||
|
server_type (str): Whether to start the proxy version of Ray Client.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ProcessInfo for the process that was started.
|
ProcessInfo for the process that was started.
|
||||||
|
@ -1885,7 +1887,7 @@ def start_ray_client_server(redis_address,
|
||||||
command = [
|
command = [
|
||||||
sys.executable, "-m", "ray.util.client.server",
|
sys.executable, "-m", "ray.util.client.server",
|
||||||
"--redis-address=" + str(redis_address),
|
"--redis-address=" + str(redis_address),
|
||||||
"--port=" + str(ray_client_server_port)
|
"--port=" + str(ray_client_server_port), "--mode=" + server_type
|
||||||
]
|
]
|
||||||
if redis_password:
|
if redis_password:
|
||||||
command.append("--redis-password=" + redis_password)
|
command.append("--redis-password=" + redis_password)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import random
|
import random
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
@ -26,6 +27,7 @@ def ray_client_instance():
|
||||||
subprocess.check_output(["ray", "stop", "--force"])
|
subprocess.check_output(["ray", "stop", "--force"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform != "linux", reason="Buggy on MacOS + Windows")
|
||||||
def test_ray_client(ray_client_instance):
|
def test_ray_client(ray_client_instance):
|
||||||
ray.util.connect(ray_client_instance)
|
ray.util.connect(ray_client_instance)
|
||||||
|
|
||||||
|
|
|
@ -113,7 +113,7 @@ def test_kill_cancel_metadata(ray_start_regular):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mock_terminate(term, metadata):
|
def mock_terminate(term, metadata):
|
||||||
raise MetadataIsCorrectlyPassedException(metadata[0][0])
|
raise MetadataIsCorrectlyPassedException(metadata[1][0])
|
||||||
|
|
||||||
# Mock stub's Terminate method to raise an exception.
|
# Mock stub's Terminate method to raise an exception.
|
||||||
ray.api.worker.server.Terminate = mock_terminate
|
ray.api.worker.server.Terminate = mock_terminate
|
||||||
|
|
|
@ -523,11 +523,11 @@ sleep(600)
|
||||||
# waiting it to be up
|
# waiting it to be up
|
||||||
sleep(5)
|
sleep(5)
|
||||||
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
|
runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
|
||||||
# Execute the second one which should trigger an error
|
# Execute the second one, which should work because each client now gets its own Ray Client server.
|
||||||
execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
|
execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
|
||||||
script = driver_script.format(**locals())
|
script = driver_script.format(**locals())
|
||||||
out = run_string_as_driver(script, env)
|
out = run_string_as_driver(script, env)
|
||||||
assert out.strip().split()[-1] == "ERROR"
|
assert out.strip().split()[-1] == "1000"
|
||||||
proc.kill()
|
proc.kill()
|
||||||
proc.wait()
|
proc.wait()
|
||||||
|
|
||||||
|
@ -573,14 +573,14 @@ sleep(600)
|
||||||
sleep(5)
|
sleep(5)
|
||||||
runtime_env = f"""
|
runtime_env = f"""
|
||||||
{{ "working_dir": test_module.__path__[0] }}""" # noqa: F541
|
{{ "working_dir": test_module.__path__[0] }}""" # noqa: F541
|
||||||
# Execute the following cmd in the second one which should
|
# Execute the following cmd in the second one and ensure that
|
||||||
# fail
|
# it is able to run.
|
||||||
execute_statement = "print('OK')"
|
execute_statement = "print('OK')"
|
||||||
script = driver_script.format(**locals())
|
script = driver_script.format(**locals())
|
||||||
out = run_string_as_driver(script, env)
|
out = run_string_as_driver(script, env)
|
||||||
proc.kill()
|
proc.kill()
|
||||||
proc.wait()
|
proc.wait()
|
||||||
assert out.strip().split()[-1] == "ERROR"
|
assert out.strip().split()[-1] == "OK"
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
|
@unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
|
||||||
|
|
|
@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# This version string is incremented to indicate breaking changes in the
|
# This version string is incremented to indicate breaking changes in the
|
||||||
# protocol that require upgrading the client version.
|
# protocol that require upgrading the client version.
|
||||||
CURRENT_PROTOCOL_VERSION = "2021-04-19"
|
CURRENT_PROTOCOL_VERSION = "2021-05-18"
|
||||||
|
|
||||||
|
|
||||||
class RayAPIStub:
|
class RayAPIStub:
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||||
|
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
from ray.util.client import ray
|
from ray.util.client import ray
|
||||||
from ray.util.client.options import validate_options
|
from ray.util.client.options import validate_options
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import grpc
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
import inspect
|
import inspect
|
||||||
|
@ -421,3 +424,17 @@ def remote_decorator(options: Optional[Dict[str, Any]]):
|
||||||
"either a function or to a class.")
|
"either a function or to a class.")
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ClientServerHandle:
    """Holds the handles to the registered gRPC servicers and their server.

    Attribute lookups that the dataclass fields do not satisfy are forwarded
    to ``grpc_server``, so code that previously expected a bare gRPC server
    keeps working when handed this object.
    """
    task_servicer: ray_client_pb2_grpc.RayletDriverServicer
    data_servicer: ray_client_pb2_grpc.RayletDataStreamerServicer
    logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
    grpc_server: grpc.Server

    # Add a hook for all the cases that previously
    # expected simply a gRPC server
    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup has failed. If the
        # missing attribute is ``grpc_server`` itself (e.g. the instance was
        # created without __init__ by copy/pickle), delegating through
        # ``self.grpc_server`` would re-enter this method forever.
        if attr == "grpc_server":
            raise AttributeError(attr)
        return getattr(self.grpc_server, attr)
|
||||||
|
|
|
@ -63,7 +63,7 @@ class DataClient:
|
||||||
stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
|
stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
|
||||||
resp_stream = stub.Datapath(
|
resp_stream = stub.Datapath(
|
||||||
iter(self.request_queue.get, None),
|
iter(self.request_queue.get, None),
|
||||||
metadata=[("client_id", self._client_id)] + self._metadata,
|
metadata=self._metadata,
|
||||||
wait_for_ready=True)
|
wait_for_ready=True)
|
||||||
try:
|
try:
|
||||||
for response in resp_stream:
|
for response in resp_stream:
|
||||||
|
|
|
@ -49,7 +49,7 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||||
|
|
||||||
def Datapath(self, request_iterator, context):
|
def Datapath(self, request_iterator, context):
|
||||||
metadata = {k: v for k, v in context.invocation_metadata()}
|
metadata = {k: v for k, v in context.invocation_metadata()}
|
||||||
client_id = metadata["client_id"]
|
client_id = metadata.get("client_id") or ""
|
||||||
if client_id == "":
|
if client_id == "":
|
||||||
logger.error("Client connecting with no client_id")
|
logger.error("Client connecting with no client_id")
|
||||||
return
|
return
|
||||||
|
|
335
python/ray/util/client/server/proxier.py
Normal file
335
python/ray/util/client/server/proxier.py
Normal file
|
@ -0,0 +1,335 @@
|
||||||
|
from concurrent import futures
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import grpc
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from queue import Queue
|
||||||
|
import socket
|
||||||
|
from threading import Thread, Lock
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
|
from ray.job_config import JobConfig
|
||||||
|
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||||
|
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
|
from ray.util.client.common import (ClientServerHandle,
|
||||||
|
CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS)
|
||||||
|
from ray._private.services import ProcessInfo, start_ray_client_server
|
||||||
|
from ray._private.utils import detect_fate_sharing_support
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CHECK_PROCESS_INTERVAL_S = 30
|
||||||
|
|
||||||
|
MIN_SPECIFIC_SERVER_PORT = 23000
|
||||||
|
MAX_SPECIFIC_SERVER_PORT = 24000
|
||||||
|
|
||||||
|
CHECK_CHANNEL_TIMEOUT_S = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client_id_from_context(context: Any) -> str:
    """
    Return the `client_id` entry of the gRPC invocation metadata.

    If the entry is missing or empty, an error is logged, the status code
    of `context` is set to INVALID_ARGUMENT and "" is returned.
    """
    invocation_metadata = dict(context.invocation_metadata())
    client_id = invocation_metadata.get("client_id") or ""
    if not client_id:
        logger.error("Client connecting with no client_id")
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
    return client_id
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SpecificServer:
    """Record kept by ProxyManager for one per-client server process."""
    # Port the server process was started on.
    port: int
    # Value returned by start_ray_client_server for the spawned process.
    process_handle: ProcessInfo
    # gRPC channel opened to that port on localhost.
    channel: "grpc._channel.Channel"
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyManager():
    """Starts and tracks one Ray Client server process per client.

    Servers are keyed by `client_id` in `self.servers`. Ports are drawn from
    `self._free_ports` and handed back once the owning process has exited.
    Both structures are guarded by `self.server_lock`. A daemon thread
    periodically drops entries whose process is no longer running.
    """

    def __init__(self, redis_address):
        self.servers: Dict[str, SpecificServer] = dict()
        self.server_lock = Lock()
        self.redis_address = redis_address
        self._free_ports: List[int] = list(
            range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT))

        self.fate_share = bool(detect_fate_sharing_support())

        # Start the watcher last so that every attribute exists before the
        # background thread can run.
        self._check_thread = Thread(target=self._check_processes, daemon=True)
        self._check_thread.start()

    def _get_unused_port(self) -> int:
        """
        Search for a port in _free_ports that is unused.

        Each candidate is removed from the front of the list and probed by
        binding a socket to it. Busy ports are rotated to the back.

        Raises:
            RuntimeError: if none of the candidate ports could be bound.
        """
        with self.server_lock:
            for _ in range(len(self._free_ports)):
                port = self._free_ports.pop(0)
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    try:
                        s.bind(("", port))
                    except OSError:
                        # In use by someone else right now; try it again
                        # later.
                        self._free_ports.append(port)
                        continue
                return port
        raise RuntimeError("Unable to succeed in selecting a random port.")

    def start_specific_server(self, client_id) -> None:
        """
        Start up a RayClient Server for an incoming client to
        communicate with.
        """
        port = self._get_unused_port()
        try:
            process_handle = start_ray_client_server(
                self.redis_address,
                port,
                fate_share=self.fate_share,
                server_type="specific-server")
            channel = grpc.insecure_channel(
                f"localhost:{port}", options=GRPC_OPTIONS)
        except Exception:
            # The port was already taken out of the free list; give it back
            # so a failed start does not shrink the pool permanently.
            with self.server_lock:
                self._free_ports.append(port)
            raise

        specific_server = SpecificServer(
            port=port, process_handle=process_handle, channel=channel)
        with self.server_lock:
            self.servers[client_id] = specific_server

    def get_channel(self, client_id: str) -> Optional["grpc._channel.Channel"]:
        """
        Find the gRPC Channel for the given client_id

        Returns None when no server is registered for the client or when the
        channel does not become ready within CHECK_CHANNEL_TIMEOUT_S seconds.
        """
        client = None
        with self.server_lock:
            client = self.servers.get(client_id)
            if client is None:
                logger.error(f"Unable to find channel for client: {client_id}")
                return None
        try:
            grpc.channel_ready_future(
                client.channel).result(timeout=CHECK_CHANNEL_TIMEOUT_S)
            return client.channel
        except grpc.FutureTimeoutError:
            return None

    def _reap_dead_servers(self) -> None:
        """Drop every server whose process has exited and recycle its port."""
        with self.server_lock:
            for client_id, specific_server in list(self.servers.items()):
                poll_result = specific_server.process_handle.process.poll()
                if poll_result is None:
                    # Still running.
                    continue
                del self.servers[client_id]
                # Nothing is listening any more; release the channel.
                specific_server.channel.close()
                # Port is available to use again.
                self._free_ports.append(specific_server.port)

    def _check_processes(self):
        """
        Keeps the internal servers dictionary up-to-date with running servers.
        """
        while True:
            self._reap_dead_servers()
            time.sleep(CHECK_PROCESS_INTERVAL_S)
|
||||||
|
|
||||||
|
|
||||||
|
class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
    """Forwards unary RayletDriver RPCs to the calling client's own server.

    The target server is looked up through the ProxyManager using the
    `client_id` found in the request's gRPC metadata.
    """

    def __init__(self, ray_connect_handler: Callable,
                 proxy_manager: ProxyManager):
        self.proxy_manager = proxy_manager
        self.ray_connect_handler = ray_connect_handler

    def _call_inner_function(self, request, context,
                             method: str) -> Optional[Any]:
        """Invoke `method` on the client's server and return its response.

        Returns None, after setting NOT_FOUND on `context`, when no ready
        channel exists for the client.
        """
        client_id = _get_client_id_from_context(context)
        chan = self.proxy_manager.get_channel(client_id)
        if not chan:
            logger.error(f"Channel for Client: {client_id} not found!")
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None

        stub = ray_client_pb2_grpc.RayletDriverStub(chan)
        return getattr(stub, method)(
            request, metadata=[("client_id", client_id)])

    def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
        return self._call_inner_function(request, context, "Init")

    def PrepRuntimeEnv(self, request,
                       context=None) -> ray_client_pb2.PrepRuntimeEnvResponse:
        return self._call_inner_function(request, context, "PrepRuntimeEnv")

    def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
        return self._call_inner_function(request, context, "KVPut")

    def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
        return self._call_inner_function(request, context, "KVGet")

    def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
        # Must forward to KVDel; forwarding to KVGet would read the key
        # instead of deleting it.
        return self._call_inner_function(request, context, "KVDel")

    def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
        return self._call_inner_function(request, context, "KVList")

    def KVExists(self, request,
                 context=None) -> ray_client_pb2.KVExistsResponse:
        return self._call_inner_function(request, context, "KVExists")

    def ClusterInfo(self, request,
                    context=None) -> ray_client_pb2.ClusterInfoResponse:

        # NOTE: We need to respond to the PING request here to allow the client
        # to continue with connecting.
        if request.type == ray_client_pb2.ClusterInfoType.PING:
            resp = ray_client_pb2.ClusterInfoResponse(json=json.dumps({}))
            return resp
        return self._call_inner_function(request, context, "ClusterInfo")

    def Terminate(self, req, context=None):
        return self._call_inner_function(req, context, "Terminate")

    def GetObject(self, request, context=None):
        return self._call_inner_function(request, context, "GetObject")

    def PutObject(self, request: ray_client_pb2.PutRequest,
                  context=None) -> ray_client_pb2.PutResponse:
        return self._call_inner_function(request, context, "PutObject")

    def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
        return self._call_inner_function(request, context, "WaitObject")

    def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
        return self._call_inner_function(task, context, "Schedule")
|
||||||
|
|
||||||
|
|
||||||
|
def forward_streaming_requests(grpc_input_generator: Iterator[Any],
                               output_queue: "Queue") -> None:
    """
    Copy every request yielded by grpc_input_generator into output_queue.

    A gRPC error raised while reading the input is logged at debug level
    and ends the copy. In all cases None is enqueued last as the sentinel.
    """
    enqueue = output_queue.put
    try:
        for incoming in grpc_input_generator:
            enqueue(incoming)
    except grpc.RpcError as rpc_err:
        logger.debug("closing dataservicer reader thread "
                     f"grpc error reading request_iterator: {rpc_err}")
    finally:
        # Set the sentinel value for the output_queue
        enqueue(None)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_runtime_init_req(req: ray_client_pb2.InitRequest
                             ) -> Tuple[ray_client_pb2.InitRequest, JobConfig]:
    """
    Pull the JobConfig out of an InitRequest before the request is handed
    to the specific RayClient Server.

    Returns the request (currently unmodified) together with the unpickled
    `req.job_config`, or a default JobConfig when that field is empty.
    """
    import pickle

    job_config = JobConfig()
    serialized_config = req.job_config
    if serialized_config:
        # NOTE(review): these bytes come from the connecting client;
        # unpickling them is only safe if clients are trusted -- confirm.
        job_config = pickle.loads(serialized_config)

    return (req, job_config)
|
||||||
|
|
||||||
|
|
||||||
|
class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
    """Proxies a client's Datapath stream to a server started for it."""

    def __init__(self, proxy_manager: ProxyManager):
        self.proxy_manager = proxy_manager

    def Datapath(self, request_iterator, context):
        """Start a server for the client and relay the stream both ways.

        The first request must be an `init` message. It is re-queued ahead
        of the remaining requests, which a daemon thread forwards while
        this generator yields the server's responses.
        """
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        logger.debug(f"New data connection from client {client_id}: ")

        # Use a default: a bare StopIteration escaping inside a generator
        # is converted to RuntimeError (PEP 479) if the client closes the
        # stream without sending anything.
        init_req = next(request_iterator, None)
        if init_req is None:
            logger.error(
                f"Client {client_id} closed the stream before sending 'init'.")
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            return
        init_type = init_req.WhichOneof("type")
        # Explicit check instead of `assert`, which is stripped under -O.
        if init_type != "init":
            logger.error("Received initial message of type "
                         f"{init_type}, not 'init'.")
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            return

        modified_init_req, job_config = prepare_runtime_init_req(init_req.init)
        init_req.init.CopyFrom(modified_init_req)
        queue = Queue()
        queue.put(init_req)

        self.proxy_manager.start_specific_server(client_id)

        channel = self.proxy_manager.get_channel(client_id)
        if channel is None:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
        thread = Thread(
            target=forward_streaming_requests,
            args=(request_iterator, queue),
            daemon=True)
        thread.start()

        # iter(queue.get, None) ends when the forwarding thread enqueues
        # its None sentinel.
        resp_stream = stub.Datapath(
            iter(queue.get, None), metadata=[("client_id", client_id)])
        for resp in resp_stream:
            yield resp
|
||||||
|
|
||||||
|
|
||||||
|
class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
    """Proxies a client's Logstream to the server started for that client."""

    def __init__(self, proxy_manager: ProxyManager):
        super().__init__()
        self.proxy_manager = proxy_manager

    def Logstream(self, request_iterator, context):
        """Relay the log stream once the client's server is reachable.

        The channel lookup is attempted up to 10 times, 5 seconds apart.
        If it never succeeds, NOT_FOUND is set on `context` and the stream
        ends without yielding.
        """
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        logger.debug(f"New log stream connection from client {client_id}: ")

        channel = None
        max_attempts = 10
        for i in range(max_attempts):
            # TODO(ilr) Ensure LogClient starts after startup has happened.
            # This will remove the need for retries here.
            channel = self.proxy_manager.get_channel(client_id)

            if channel is not None:
                break
            # Only announce a retry and wait when another attempt follows.
            if i + 1 < max_attempts:
                logger.warning(
                    f"Retrying Logstream connection. {i+1} attempts failed.")
                time.sleep(5)

        if channel is None:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None

        stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
        queue = Queue()
        thread = Thread(
            target=forward_streaming_requests,
            args=(request_iterator, queue),
            daemon=True)
        thread.start()

        # iter(queue.get, None) ends when the forwarding thread enqueues
        # its None sentinel.
        resp_stream = stub.Logstream(
            iter(queue.get, None), metadata=[("client_id", client_id)])
        for resp in resp_stream:
            yield resp
|
||||||
|
|
||||||
|
|
||||||
|
def serve_proxier(connection_str: str, redis_address: str):
    """Start the proxying gRPC server listening on `connection_str`.

    Builds a ProxyManager for `redis_address`, registers the three proxy
    servicers on a thread-pool backed gRPC server, opens an insecure port,
    starts the server and returns everything in a ClientServerHandle.
    """
    executor = futures.ThreadPoolExecutor(
        max_workers=CLIENT_SERVER_MAX_THREADS)
    grpc_server = grpc.server(executor, options=GRPC_OPTIONS)
    manager = ProxyManager(redis_address)

    handle = ClientServerHandle(
        task_servicer=RayletServicerProxy(None, manager),
        data_servicer=DataServicerProxy(manager),
        logs_servicer=LogstreamServicerProxy(manager),
        grpc_server=grpc_server,
    )

    registrations = (
        (ray_client_pb2_grpc.add_RayletDriverServicer_to_server,
         handle.task_servicer),
        (ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server,
         handle.data_servicer),
        (ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server,
         handle.logs_servicer),
    )
    for register, servicer in registrations:
        register(servicer, grpc_server)

    grpc_server.add_insecure_port(connection_str)
    grpc_server.start()
    return handle
|
|
@ -3,7 +3,6 @@ from concurrent import futures
|
||||||
import grpc
|
import grpc
|
||||||
import base64
|
import base64
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
|
||||||
import os
|
import os
|
||||||
import queue
|
import queue
|
||||||
|
|
||||||
|
@ -22,7 +21,9 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
import time
|
import time
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
from ray.util.client.common import (GRPC_OPTIONS, CLIENT_SERVER_MAX_THREADS)
|
from ray.util.client.common import (ClientServerHandle, GRPC_OPTIONS,
|
||||||
|
CLIENT_SERVER_MAX_THREADS)
|
||||||
|
from ray.util.client.server.proxier import serve_proxier
|
||||||
from ray.util.client.server.server_pickler import convert_from_arg
|
from ray.util.client.server.server_pickler import convert_from_arg
|
||||||
from ray.util.client.server.server_pickler import dumps_from_server
|
from ray.util.client.server.server_pickler import dumps_from_server
|
||||||
from ray.util.client.server.server_pickler import loads_from_client
|
from ray.util.client.server.server_pickler import loads_from_client
|
||||||
|
@ -34,6 +35,8 @@ from ray._private.client_mode_hook import disable_client_hook
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TIMEOUT_FOR_SPECIFIC_SERVER_S = 30
|
||||||
|
|
||||||
|
|
||||||
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||||
def __init__(self, ray_connect_handler: Callable):
|
def __init__(self, ray_connect_handler: Callable):
|
||||||
|
@ -170,6 +173,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||||
data = ray.is_initialized()
|
data = ray.is_initialized()
|
||||||
elif request.type == ray_client_pb2.ClusterInfoType.TIMELINE:
|
elif request.type == ray_client_pb2.ClusterInfoType.TIMELINE:
|
||||||
data = ray.timeline()
|
data = ray.timeline()
|
||||||
|
elif request.type == ray_client_pb2.ClusterInfoType.PING:
|
||||||
|
data = {}
|
||||||
else:
|
else:
|
||||||
raise TypeError("Unsupported cluster info type")
|
raise TypeError("Unsupported cluster info type")
|
||||||
return json.dumps(data)
|
return json.dumps(data)
|
||||||
|
@ -560,20 +565,6 @@ def decode_options(
|
||||||
return opts
|
return opts
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ClientServerHandle:
|
|
||||||
"""Holds the handles to the registered gRPC servicers and their server."""
|
|
||||||
task_servicer: RayletServicer
|
|
||||||
data_servicer: DataServicer
|
|
||||||
logs_servicer: LogstreamServicer
|
|
||||||
grpc_server: grpc.Server
|
|
||||||
|
|
||||||
# Add a hook for all the cases that previously
|
|
||||||
# expected simply a gRPC server
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.grpc_server, attr)
|
|
||||||
|
|
||||||
|
|
||||||
def serve(connection_str, ray_connect_handler=None):
|
def serve(connection_str, ray_connect_handler=None):
|
||||||
def default_connect_handler(job_config: JobConfig = None):
|
def default_connect_handler(job_config: JobConfig = None):
|
||||||
with disable_client_hook():
|
with disable_client_hook():
|
||||||
|
@ -666,6 +657,11 @@ def main():
|
||||||
"--host", type=str, default="0.0.0.0", help="Host IP to bind to")
|
"--host", type=str, default="0.0.0.0", help="Host IP to bind to")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-p", "--port", type=int, default=50051, help="Port to bind to")
|
"-p", "--port", type=int, default=50051, help="Port to bind to")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
type=str,
|
||||||
|
choices=["proxy", "legacy", "specific-server"],
|
||||||
|
default="proxy")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--redis-address",
|
"--redis-address",
|
||||||
required=False,
|
required=False,
|
||||||
|
@ -689,8 +685,13 @@ def main():
|
||||||
|
|
||||||
hostport = "%s:%d" % (args.host, args.port)
|
hostport = "%s:%d" % (args.host, args.port)
|
||||||
logger.info(f"Starting Ray Client server on {hostport}")
|
logger.info(f"Starting Ray Client server on {hostport}")
|
||||||
server = serve(hostport, ray_connect_handler)
|
if args.mode == "proxy":
|
||||||
|
server = serve_proxier(hostport, args.redis_address)
|
||||||
|
else:
|
||||||
|
server = serve(hostport, ray_connect_handler)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
|
||||||
while True:
|
while True:
|
||||||
health_report = {
|
health_report = {
|
||||||
"time": time.time(),
|
"time": time.time(),
|
||||||
|
@ -706,6 +707,18 @@ def main():
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
if args.mode == "specific-server":
|
||||||
|
if server.data_servicer.num_clients > 0:
|
||||||
|
idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
|
||||||
|
else:
|
||||||
|
idle_checks_remaining -= 1
|
||||||
|
if idle_checks_remaining == 0:
|
||||||
|
raise KeyboardInterrupt()
|
||||||
|
if (idle_checks_remaining % 5 == 0 and idle_checks_remaining !=
|
||||||
|
TIMEOUT_FOR_SPECIFIC_SERVER_S):
|
||||||
|
logger.info(
|
||||||
|
f"{idle_checks_remaining} idle checks before shutdown."
|
||||||
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
server.stop(0)
|
server.stop(0)
|
||||||
|
|
|
@ -77,11 +77,12 @@ class Worker:
|
||||||
at least once. For infinite retries, catch the ConnectionError
|
at least once. For infinite retries, catch the ConnectionError
|
||||||
exception.
|
exception.
|
||||||
"""
|
"""
|
||||||
self.metadata = metadata if metadata else []
|
self._client_id = make_client_id()
|
||||||
|
self.metadata = [("client_id", self._client_id)] + (metadata if
|
||||||
|
metadata else [])
|
||||||
self.channel = None
|
self.channel = None
|
||||||
self.server = None
|
self.server = None
|
||||||
self._conn_state = grpc.ChannelConnectivity.IDLE
|
self._conn_state = grpc.ChannelConnectivity.IDLE
|
||||||
self._client_id = make_client_id()
|
|
||||||
self._converted: Dict[str, ClientStub] = {}
|
self._converted: Dict[str, ClientStub] = {}
|
||||||
|
|
||||||
if secure:
|
if secure:
|
||||||
|
@ -439,8 +440,7 @@ class Worker:
|
||||||
"""
|
"""
|
||||||
if self.server is not None:
|
if self.server is not None:
|
||||||
logger.debug("Pinging server.")
|
logger.debug("Pinging server.")
|
||||||
result = self.get_cluster_info(
|
result = self.get_cluster_info(ray_client_pb2.ClusterInfoType.PING)
|
||||||
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
|
|
||||||
return result is not None
|
return result is not None
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -157,6 +157,7 @@ message ClusterInfoType {
|
||||||
AVAILABLE_RESOURCES = 3;
|
AVAILABLE_RESOURCES = 3;
|
||||||
RUNTIME_CONTEXT = 4;
|
RUNTIME_CONTEXT = 4;
|
||||||
TIMELINE = 5;
|
TIMELINE = 5;
|
||||||
|
PING = 6;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue