Mirror of https://github.com/vale981/ray, synced 2025-03-07 02:51:39 -05:00
[client] One Driver per RayClient Server (#15875)
parent c636bc3065 · commit 97d1414f23
12 changed files with 402 additions and 32 deletions
@@ -1868,7 +1868,8 @@ def start_ray_client_server(redis_address,
                             stdout_file=None,
                             stderr_file=None,
                             redis_password=None,
-                            fate_share=None):
+                            fate_share=None,
+                            server_type="proxy"):
     """Run the server process of the Ray client.

     Args:
@@ -1878,6 +1879,7 @@ def start_ray_client_server(redis_address,
         stderr_file: A file handle opened for writing to redirect stderr to. If
             no redirection should happen, then this should be None.
         redis_password (str): The password of the redis server.
+        server_type (str): Which type of server to start ("proxy" or "specific-server").

     Returns:
         ProcessInfo for the process that was started.
@@ -1885,7 +1887,7 @@ def start_ray_client_server(redis_address,
     command = [
         sys.executable, "-m", "ray.util.client.server",
         "--redis-address=" + str(redis_address),
-        "--port=" + str(ray_client_server_port)
+        "--port=" + str(ray_client_server_port), "--mode=" + server_type
     ]
     if redis_password:
         command.append("--redis-password=" + redis_password)
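Note: with the new flag, the launched command differs only in its --mode argument. A minimal sketch of the assembled command (build_client_server_command is a hypothetical helper; the address and port are illustrative):

    import sys

    def build_client_server_command(redis_address, port, server_type="proxy"):
        # Mirrors how the diff above appends "--mode=" + server_type.
        return [
            sys.executable, "-m", "ray.util.client.server",
            "--redis-address=" + str(redis_address),
            "--port=" + str(port), "--mode=" + server_type,
        ]

    print(build_client_server_command("127.0.0.1:6379", 23000,
                                      "specific-server"))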
@@ -1,5 +1,6 @@
 import random
+import subprocess
 import sys

 import pytest
 import requests
@@ -26,6 +27,7 @@ def ray_client_instance():
     subprocess.check_output(["ray", "stop", "--force"])


+@pytest.mark.skipif(sys.platform != "linux", reason="Buggy on MacOS + Windows")
 def test_ray_client(ray_client_instance):
     ray.util.connect(ray_client_instance)

@@ -113,7 +113,7 @@ def test_kill_cancel_metadata(ray_start_regular):
         pass

     def mock_terminate(term, metadata):
-        raise MetadataIsCorrectlyPassedException(metadata[0][0])
+        raise MetadataIsCorrectlyPassedException(metadata[1][0])

     # Mock stub's Terminate method to raise an exception.
     ray.api.worker.server.Terminate = mock_terminate
@@ -523,11 +523,11 @@ sleep(600)
     # waiting for it to be up
     sleep(5)
     runtime_env = f"""{{ "working_dir": "{working_dir}" }}"""
-    # Execute the second one which should trigger an error
+    # Execute the second one, which should work because of Ray Client servers.
     execute_statement = "print(sum(ray.get([run_test.remote()] * 1000)))"
     script = driver_script.format(**locals())
     out = run_string_as_driver(script, env)
-    assert out.strip().split()[-1] == "ERROR"
+    assert out.strip().split()[-1] == "1000"
     proc.kill()
     proc.wait()
@@ -573,14 +573,14 @@ sleep(600)
     sleep(5)
     runtime_env = f"""
{{ "working_dir": test_module.__path__[0] }}"""  # noqa: F541
-    # Execute the following cmd in the second one which should
-    # fail
+    # Execute the following cmd in the second one and ensure that
+    # it is able to run.
     execute_statement = "print('OK')"
     script = driver_script.format(**locals())
     out = run_string_as_driver(script, env)
     proc.kill()
     proc.wait()
-    assert out.strip().split()[-1] == "ERROR"
+    assert out.strip().split()[-1] == "OK"


 @unittest.skipIf(sys.platform == "win32", "Fail to create temp dir.")
@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)

 # This version string is incremented to indicate breaking changes in the
 # protocol that require upgrading the client version.
-CURRENT_PROTOCOL_VERSION = "2021-04-19"
+CURRENT_PROTOCOL_VERSION = "2021-05-18"


 class RayAPIStub:
@@ -1,9 +1,12 @@
 import ray.core.generated.ray_client_pb2 as ray_client_pb2
+import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
 from ray.util.client import ray
 from ray.util.client.options import validate_options

 import asyncio
 import concurrent.futures
+from dataclasses import dataclass
+import grpc
 import os
 import uuid
 import inspect
@@ -421,3 +424,17 @@ def remote_decorator(options: Optional[Dict[str, Any]]):
             "either a function or to a class.")

     return decorator
+
+
+@dataclass
+class ClientServerHandle:
+    """Holds the handles to the registered gRPC servicers and their server."""
+    task_servicer: ray_client_pb2_grpc.RayletDriverServicer
+    data_servicer: ray_client_pb2_grpc.RayletDataStreamerServicer
+    logs_servicer: ray_client_pb2_grpc.RayletLogStreamerServicer
+    grpc_server: grpc.Server
+
+    # Add a hook for all the cases that previously
+    # expected simply a gRPC server
+    def __getattr__(self, attr):
+        return getattr(self.grpc_server, attr)
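Note: the __getattr__ hook keeps old call sites working that expect a bare grpc.Server. The delegation pattern in isolation (stand-in types, not Ray code):

    from dataclasses import dataclass

    class Server:
        def stop(self, grace):
            return f"stopped (grace={grace})"

    @dataclass
    class Handle:
        grpc_server: Server

        def __getattr__(self, attr):
            # Invoked only when normal attribute lookup fails, so the
            # dataclass fields themselves still resolve directly.
            return getattr(self.grpc_server, attr)

    print(Handle(grpc_server=Server()).stop(0))  # delegates to Server.stop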
@@ -63,7 +63,7 @@ class DataClient:
         stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
         resp_stream = stub.Datapath(
             iter(self.request_queue.get, None),
-            metadata=[("client_id", self._client_id)] + self._metadata,
+            metadata=self._metadata,
             wait_for_ready=True)
         try:
             for response in resp_stream:
@@ -49,7 +49,7 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):

     def Datapath(self, request_iterator, context):
         metadata = {k: v for k, v in context.invocation_metadata()}
-        client_id = metadata["client_id"]
+        client_id = metadata.get("client_id") or ""
         if client_id == "":
             logger.error("Client connecting with no client_id")
             return
python/ray/util/client/server/proxier.py (new file, 335 lines)
@@ -0,0 +1,335 @@
from concurrent import futures
from dataclasses import dataclass
import grpc
import logging
import json
from queue import Queue
import socket
from threading import Thread, Lock
import time
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

from ray.job_config import JobConfig
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.util.client.common import (ClientServerHandle,
                                    CLIENT_SERVER_MAX_THREADS, GRPC_OPTIONS)
from ray._private.services import ProcessInfo, start_ray_client_server
from ray._private.utils import detect_fate_sharing_support

logger = logging.getLogger(__name__)

CHECK_PROCESS_INTERVAL_S = 30

MIN_SPECIFIC_SERVER_PORT = 23000
MAX_SPECIFIC_SERVER_PORT = 24000

CHECK_CHANNEL_TIMEOUT_S = 5


def _get_client_id_from_context(context: Any) -> str:
    """
    Get `client_id` from gRPC metadata. If the `client_id` is not present,
    this function logs an error and sets the status code.
    """
    metadata = {k: v for k, v in context.invocation_metadata()}
    client_id = metadata.get("client_id") or ""
    if client_id == "":
        logger.error("Client connecting with no client_id")
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
    return client_id
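Note: the helper above simply dict-ifies the gRPC invocation metadata. A minimal stand-in (no real gRPC context; the class and values are illustrative) shows the behavior:

    # Stand-in for a grpc.ServicerContext.
    class FakeContext:
        def invocation_metadata(self):
            return [("client_id", "abc123"), ("other", "x")]

        def set_code(self, code):
            self.code = code

    metadata = {k: v for k, v in FakeContext().invocation_metadata()}
    print(metadata.get("client_id") or "")  # -> "abc123"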
@dataclass
class SpecificServer:
    port: int
    process_handle: ProcessInfo
    channel: "grpc._channel.Channel"


class ProxyManager():
    def __init__(self, redis_address):
        self.servers: Dict[str, SpecificServer] = dict()
        self.server_lock = Lock()
        self.redis_address = redis_address
        self._free_ports: List[int] = list(
            range(MIN_SPECIFIC_SERVER_PORT, MAX_SPECIFIC_SERVER_PORT))

        self._check_thread = Thread(target=self._check_processes, daemon=True)
        self._check_thread.start()

        self.fate_share = bool(detect_fate_sharing_support())
    def _get_unused_port(self) -> int:
        """
        Search for a port in _free_ports that is unused.
        """
        with self.server_lock:
            num_ports = len(self._free_ports)
            for _ in range(num_ports):
                port = self._free_ports.pop(0)
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                try:
                    s.bind(("", port))
                except OSError:
                    self._free_ports.append(port)
                    continue
                finally:
                    s.close()
                return port
        raise RuntimeError("Unable to find an unused port.")
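Note: the bind-and-close probe above is a standard way to test port availability; busy ports are pushed back onto the free list for a later retry. The same technique, self-contained:

    import socket

    def port_is_free(port: int) -> bool:
        # bind() succeeds only if nothing else currently holds the port.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.bind(("", port))
            return True
        except OSError:
            return False
        finally:
            s.close()

    print(port_is_free(23000))  # port number is illustrative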
    def start_specific_server(self, client_id) -> None:
        """
        Start up a RayClient Server for an incoming client to
        communicate with.
        """
        port = self._get_unused_port()
        specific_server = SpecificServer(
            port=port,
            process_handle=start_ray_client_server(
                self.redis_address,
                port,
                fate_share=self.fate_share,
                server_type="specific-server"),
            channel=grpc.insecure_channel(
                f"localhost:{port}", options=GRPC_OPTIONS))
        with self.server_lock:
            self.servers[client_id] = specific_server
    def get_channel(self, client_id: str) -> Optional["grpc._channel.Channel"]:
        """
        Find the gRPC Channel for the given client_id.
        """
        client = None
        with self.server_lock:
            client = self.servers.get(client_id)
            if client is None:
                logger.error(f"Unable to find channel for client: {client_id}")
                return None
        try:
            grpc.channel_ready_future(
                client.channel).result(timeout=CHECK_CHANNEL_TIMEOUT_S)
            return client.channel
        except grpc.FutureTimeoutError:
            return None
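Note: grpc.channel_ready_future returns a Future that resolves once the channel connects, which is how get_channel bounds its wait. Minimal standalone use (the address is illustrative):

    import grpc

    channel = grpc.insecure_channel("localhost:23000")
    try:
        grpc.channel_ready_future(channel).result(timeout=5)
        print("channel ready")
    except grpc.FutureTimeoutError:
        print("timed out waiting for the channel")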
    def _check_processes(self):
        """
        Keeps the internal servers dictionary up-to-date with running servers.
        """
        while True:
            with self.server_lock:
                for client_id, specific_server in list(self.servers.items()):
                    poll_result = specific_server.process_handle.process.poll()
                    if poll_result is not None:
                        del self.servers[client_id]
                        # Port is available to use again.
                        self._free_ports.append(specific_server.port)

            time.sleep(CHECK_PROCESS_INTERVAL_S)
class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
    def __init__(self, ray_connect_handler: Callable,
                 proxy_manager: ProxyManager):
        self.proxy_manager = proxy_manager
        self.ray_connect_handler = ray_connect_handler

    def _call_inner_function(
            self, request, context,
            method: str) -> Optional[Any]:
        client_id = _get_client_id_from_context(context)
        chan = self.proxy_manager.get_channel(client_id)
        if not chan:
            logger.error(f"Channel for Client: {client_id} not found!")
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None

        stub = ray_client_pb2_grpc.RayletDriverStub(chan)
        return getattr(stub, method)(
            request, metadata=[("client_id", client_id)])

    def Init(self, request, context=None) -> ray_client_pb2.InitResponse:
        return self._call_inner_function(request, context, "Init")

    def PrepRuntimeEnv(self, request,
                       context=None) -> ray_client_pb2.PrepRuntimeEnvResponse:
        return self._call_inner_function(request, context, "PrepRuntimeEnv")

    def KVPut(self, request, context=None) -> ray_client_pb2.KVPutResponse:
        return self._call_inner_function(request, context, "KVPut")

    def KVGet(self, request, context=None) -> ray_client_pb2.KVGetResponse:
        return self._call_inner_function(request, context, "KVGet")

    def KVDel(self, request, context=None) -> ray_client_pb2.KVDelResponse:
        return self._call_inner_function(request, context, "KVDel")

    def KVList(self, request, context=None) -> ray_client_pb2.KVListResponse:
        return self._call_inner_function(request, context, "KVList")

    def KVExists(self, request,
                 context=None) -> ray_client_pb2.KVExistsResponse:
        return self._call_inner_function(request, context, "KVExists")

    def ClusterInfo(self, request,
                    context=None) -> ray_client_pb2.ClusterInfoResponse:

        # NOTE: We need to respond to the PING request here to allow the
        # client to continue with connecting.
        if request.type == ray_client_pb2.ClusterInfoType.PING:
            resp = ray_client_pb2.ClusterInfoResponse(json=json.dumps({}))
            return resp
        return self._call_inner_function(request, context, "ClusterInfo")

    def Terminate(self, req, context=None):
        return self._call_inner_function(req, context, "Terminate")

    def GetObject(self, request, context=None):
        return self._call_inner_function(request, context, "GetObject")

    def PutObject(self, request: ray_client_pb2.PutRequest,
                  context=None) -> ray_client_pb2.PutResponse:
        return self._call_inner_function(request, context, "PutObject")

    def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
        return self._call_inner_function(request, context, "WaitObject")

    def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
        return self._call_inner_function(task, context, "Schedule")
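Note: _call_inner_function forwards every driver RPC by method name. The dispatch pattern in isolation (stand-in classes, no gRPC; all names here are illustrative):

    class Stub:
        # Stand-in for ray_client_pb2_grpc.RayletDriverStub.
        def Init(self, request, metadata=None):
            return f"Init({request!r}) with metadata={metadata}"

    def call(stub, method: str, request, client_id: str):
        # Look the RPC up by name and attach the client's id as metadata.
        return getattr(stub, method)(
            request, metadata=[("client_id", client_id)])

    print(call(Stub(), "Init", "req", "abc123"))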
def forward_streaming_requests(grpc_input_generator: Iterator[Any],
                               output_queue: "Queue") -> None:
    """
    Forwards streaming requests from the grpc_input_generator into the
    output_queue.
    """
    try:
        for req in grpc_input_generator:
            output_queue.put(req)
    except grpc.RpcError as e:
        logger.debug("closing dataservicer reader thread "
                     f"grpc error reading request_iterator: {e}")
    finally:
        # Set the sentinel value for the output_queue
        output_queue.put(None)
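Note: downstream code consumes this queue with iter(queue.get, None), the two-argument form of iter that stops at a sentinel. A self-contained demonstration:

    from queue import Queue

    q = Queue()
    q.put("a")
    q.put("b")
    q.put(None)  # the sentinel ends the iteration below

    for item in iter(q.get, None):
        print(item)  # prints "a", then "b", then the loop ends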
def prepare_runtime_init_req(req: ray_client_pb2.InitRequest
                             ) -> Tuple[ray_client_pb2.InitRequest, JobConfig]:
    """
    Extract JobConfig and possibly mutate InitRequest before it is passed to
    the specific RayClient Server.
    """
    job_config = JobConfig()
    if req.job_config:
        import pickle
        job_config = pickle.loads(req.job_config)

    return (req, job_config)
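Note: the job_config field travels as pickled bytes; the client side is assumed to serialize with pickle.dumps. The round trip, sketched (requires ray installed):

    import pickle
    from ray.job_config import JobConfig

    payload = pickle.dumps(JobConfig())  # what the client is assumed to send
    job_config = pickle.loads(payload)   # what the proxy recovers
    print(type(job_config).__name__)     # -> "JobConfig"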
class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
    def __init__(self, proxy_manager: ProxyManager):
        self.proxy_manager = proxy_manager

    def Datapath(self, request_iterator, context):
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        logger.debug(f"New data connection from client {client_id}")

        init_req = next(request_iterator)
        init_type = init_req.WhichOneof("type")
        assert init_type == "init", ("Received initial message of type "
                                     f"{init_type}, not 'init'.")

        modified_init_req, job_config = prepare_runtime_init_req(init_req.init)
        init_req.init.CopyFrom(modified_init_req)
        queue = Queue()
        queue.put(init_req)

        self.proxy_manager.start_specific_server(client_id)

        channel = self.proxy_manager.get_channel(client_id)
        if channel is None:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
        thread = Thread(
            target=forward_streaming_requests,
            args=(request_iterator, queue),
            daemon=True)
        thread.start()

        resp_stream = stub.Datapath(
            iter(queue.get, None), metadata=[("client_id", client_id)])
        for resp in resp_stream:
            yield resp
class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
    def __init__(self, proxy_manager: ProxyManager):
        super().__init__()
        self.proxy_manager = proxy_manager

    def Logstream(self, request_iterator, context):
        client_id = _get_client_id_from_context(context)
        if client_id == "":
            return
        logger.debug(f"New logstream connection from client {client_id}")

        channel = None
        for i in range(10):
            # TODO(ilr) Ensure LogClient starts after startup has happened.
            # This will remove the need for retries here.
            channel = self.proxy_manager.get_channel(client_id)

            if channel is not None:
                break
            logger.warning(
                f"Retrying Logstream connection. {i+1} attempts failed.")
            time.sleep(5)

        if channel is None:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return None

        stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
        queue = Queue()
        thread = Thread(
            target=forward_streaming_requests,
            args=(request_iterator, queue),
            daemon=True)
        thread.start()

        resp_stream = stub.Logstream(
            iter(queue.get, None), metadata=[("client_id", client_id)])
        for resp in resp_stream:
            yield resp
def serve_proxier(connection_str: str, redis_address: str):
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS),
        options=GRPC_OPTIONS)
    proxy_manager = ProxyManager(redis_address)
    task_servicer = RayletServicerProxy(None, proxy_manager)
    data_servicer = DataServicerProxy(proxy_manager)
    logs_servicer = LogstreamServicerProxy(proxy_manager)
    ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
        task_servicer, server)
    ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
        data_servicer, server)
    ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
        logs_servicer, server)
    server.add_insecure_port(connection_str)
    server.start()
    return ClientServerHandle(
        task_servicer=task_servicer,
        data_servicer=data_servicer,
        logs_servicer=logs_servicer,
        grpc_server=server,
    )
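Note: a hedged sketch of standing the proxier up by hand (both addresses are illustrative). The returned ClientServerHandle delegates unknown attributes to the wrapped gRPC server, so the usual shutdown path still works:

    from ray.util.client.server.proxier import serve_proxier

    handle = serve_proxier("0.0.0.0:10001", "127.0.0.1:6379")
    # stop() is not defined on ClientServerHandle; __getattr__
    # forwards it to the underlying grpc.Server.
    handle.stop(0)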
@@ -3,7 +3,6 @@ from concurrent import futures
 import grpc
 import base64
 from collections import defaultdict
-from dataclasses import dataclass
 import os
 import queue

@@ -22,7 +21,9 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
 import time
 import inspect
 import json
-from ray.util.client.common import (GRPC_OPTIONS, CLIENT_SERVER_MAX_THREADS)
+from ray.util.client.common import (ClientServerHandle, GRPC_OPTIONS,
+                                    CLIENT_SERVER_MAX_THREADS)
+from ray.util.client.server.proxier import serve_proxier
 from ray.util.client.server.server_pickler import convert_from_arg
 from ray.util.client.server.server_pickler import dumps_from_server
 from ray.util.client.server.server_pickler import loads_from_client
@@ -34,6 +35,8 @@ from ray._private.client_mode_hook import disable_client_hook

 logger = logging.getLogger(__name__)

+TIMEOUT_FOR_SPECIFIC_SERVER_S = 30
+

 class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
     def __init__(self, ray_connect_handler: Callable):
@@ -170,6 +173,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
             data = ray.is_initialized()
         elif request.type == ray_client_pb2.ClusterInfoType.TIMELINE:
             data = ray.timeline()
+        elif request.type == ray_client_pb2.ClusterInfoType.PING:
+            data = {}
         else:
             raise TypeError("Unsupported cluster info type")
         return json.dumps(data)
@@ -560,20 +565,6 @@ def decode_options(
     return opts


-@dataclass
-class ClientServerHandle:
-    """Holds the handles to the registered gRPC servicers and their server."""
-    task_servicer: RayletServicer
-    data_servicer: DataServicer
-    logs_servicer: LogstreamServicer
-    grpc_server: grpc.Server
-
-    # Add a hook for all the cases that previously
-    # expected simply a gRPC server
-    def __getattr__(self, attr):
-        return getattr(self.grpc_server, attr)
-
-
 def serve(connection_str, ray_connect_handler=None):
     def default_connect_handler(job_config: JobConfig = None):
         with disable_client_hook():
@@ -666,6 +657,11 @@ def main():
         "--host", type=str, default="0.0.0.0", help="Host IP to bind to")
     parser.add_argument(
         "-p", "--port", type=int, default=50051, help="Port to bind to")
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["proxy", "legacy", "specific-server"],
+        default="proxy")
     parser.add_argument(
         "--redis-address",
         required=False,
@@ -689,8 +685,13 @@ def main():

     hostport = "%s:%d" % (args.host, args.port)
     logger.info(f"Starting Ray Client server on {hostport}")
-    server = serve(hostport, ray_connect_handler)
+    if args.mode == "proxy":
+        server = serve_proxier(hostport, args.redis_address)
+    else:
+        server = serve(hostport, ray_connect_handler)

     try:
+        idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
         while True:
             health_report = {
                 "time": time.time(),
@@ -706,6 +707,18 @@ def main():
                 logger.exception(e)

             time.sleep(1)
+            if args.mode == "specific-server":
+                if server.data_servicer.num_clients > 0:
+                    idle_checks_remaining = TIMEOUT_FOR_SPECIFIC_SERVER_S
+                else:
+                    idle_checks_remaining -= 1
+                if idle_checks_remaining == 0:
+                    raise KeyboardInterrupt()
+                if (idle_checks_remaining % 5 == 0 and idle_checks_remaining !=
+                        TIMEOUT_FOR_SPECIFIC_SERVER_S):
+                    logger.info(
+                        f"{idle_checks_remaining} idle checks before shutdown."
+                    )

     except KeyboardInterrupt:
         server.stop(0)
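Note: the idle shutdown is a simple countdown that resets while clients are connected and ticks down once per one-second health check. The logic in isolation (next_remaining is a hypothetical helper, not part of the diff):

    TIMEOUT_FOR_SPECIFIC_SERVER_S = 30

    def next_remaining(remaining: int, num_clients: int) -> int:
        # Any connected client resets the countdown; otherwise it
        # ticks down toward shutdown.
        if num_clients > 0:
            return TIMEOUT_FOR_SPECIFIC_SERVER_S
        return remaining - 1

    print(next_remaining(30, 0))  # -> 29
    print(next_remaining(12, 2))  # -> 30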
@@ -77,11 +77,12 @@ class Worker:
                 at least once. For infinite retries, catch the ConnectionError
                 exception.
         """
-        self.metadata = metadata if metadata else []
+        self._client_id = make_client_id()
+        self.metadata = [("client_id", self._client_id)] + (metadata if
+                                                            metadata else [])
         self.channel = None
         self.server = None
         self._conn_state = grpc.ChannelConnectivity.IDLE
-        self._client_id = make_client_id()
         self._converted: Dict[str, ClientStub] = {}

         if secure:
@@ -439,8 +440,7 @@ class Worker:
         """
         if self.server is not None:
             logger.debug("Pinging server.")
-            result = self.get_cluster_info(
-                ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
+            result = self.get_cluster_info(ray_client_pb2.ClusterInfoType.PING)
             return result is not None
         return False

@@ -157,6 +157,7 @@ message ClusterInfoType {
     AVAILABLE_RESOURCES = 3;
     RUNTIME_CONTEXT = 4;
     TIMELINE = 5;
+    PING = 6;
   }
 }
