mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Client][Proxy] Prevent Logstream from Timing Out when Delays in DataClient (#16180)
This commit is contained in:
parent
fa292a4edf
commit
22bd7cebeb
3 changed files with 124 additions and 50 deletions
|
@ -32,6 +32,7 @@ def test_proxy_manager_lifecycle(shutdown_only):
|
|||
pm._free_ports = [45000, 45001]
|
||||
client = "client1"
|
||||
|
||||
pm.create_specific_server(client)
|
||||
assert pm.start_specific_server(client, JobConfig())
|
||||
# Channel should be ready and corresponding to an existing server
|
||||
grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5)
|
||||
|
@ -39,7 +40,7 @@ def test_proxy_manager_lifecycle(shutdown_only):
|
|||
proc = pm._get_server_for_client(client)
|
||||
assert proc.port == 45000
|
||||
|
||||
proc.process_handle().process.wait(10)
|
||||
proc.process_handle_future.result().process.wait(10)
|
||||
# Wait for reconcile loop
|
||||
time.sleep(2)
|
||||
|
||||
|
@ -63,6 +64,7 @@ def test_proxy_manager_bad_startup(shutdown_only):
|
|||
pm._free_ports = [46000, 46001]
|
||||
client = "client1"
|
||||
|
||||
pm.create_specific_server(client)
|
||||
assert not pm.start_specific_server(
|
||||
client,
|
||||
JobConfig(
|
||||
|
@ -122,6 +124,40 @@ def test_correct_num_clients(call_ray_start):
|
|||
run_string_as_driver(check_we_are_second.format(num_clients=1))
|
||||
|
||||
|
||||
check_connection = """
|
||||
import ray
|
||||
ray.client("localhost:25010").connect()
|
||||
assert ray.util.client.ray.worker.log_client.log_thread.is_alive()
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "win32",
|
||||
reason="PSUtil does not work the same on windows.")
|
||||
def test_delay_in_rewriting_environment(shutdown_only):
|
||||
"""
|
||||
Check that a delay in `ray_client_server_env_prep` does not break
|
||||
a Client connecting.
|
||||
"""
|
||||
proxier.LOGSTREAM_RETRIES = 3
|
||||
proxier.LOGSTREAM_RETRY_INTERVAL_SEC = 1
|
||||
ray_instance = ray.init()
|
||||
|
||||
def delay_in_rewrite(input: JobConfig):
|
||||
import time
|
||||
time.sleep(6)
|
||||
return input
|
||||
|
||||
proxier.ray_client_server_env_prep = delay_in_rewrite
|
||||
|
||||
server = proxier.serve_proxier("localhost:25010",
|
||||
ray_instance["redis_address"],
|
||||
ray_instance["session_dir"])
|
||||
|
||||
run_string_as_driver(check_connection)
|
||||
server.stop(0)
|
||||
|
||||
|
||||
def test_prepare_runtime_init_req_fails():
|
||||
"""
|
||||
Check that a connection that is initiated with a non-Init request
|
||||
|
|
|
@ -31,6 +31,9 @@ MAX_SPECIFIC_SERVER_PORT = 24000
|
|||
|
||||
CHECK_CHANNEL_TIMEOUT_S = 10
|
||||
|
||||
LOGSTREAM_RETRIES = 5
|
||||
LOGSTREAM_RETRY_INTERVAL_SEC = 2
|
||||
|
||||
|
||||
def _get_client_id_from_context(context: Any) -> str:
|
||||
"""
|
||||
|
@ -55,11 +58,34 @@ class SpecificServer:
|
|||
"""
|
||||
Wait for the server to actually start up.
|
||||
"""
|
||||
self.process_handle_future.result(timeout=timeout)
|
||||
res = self.process_handle_future.result(timeout=timeout)
|
||||
if res is None:
|
||||
# This is only set to none when server creation specifically fails.
|
||||
raise RuntimeError("Server startup failed.")
|
||||
|
||||
def poll(self) -> Optional[int]:
|
||||
"""Check if the process has exited."""
|
||||
try:
|
||||
proc = self.process_handle_future.result(timeout=0.1)
|
||||
if proc is not None:
|
||||
return proc.process.poll()
|
||||
except futures.TimeoutError:
|
||||
return
|
||||
|
||||
def process_handle(self) -> ProcessInfo:
|
||||
return self.process_handle_future.result()
|
||||
def kill(self) -> None:
|
||||
"""Try to send a KILL signal to the process."""
|
||||
try:
|
||||
proc = self.process_handle_future.result(timeout=0.1)
|
||||
if proc is not None:
|
||||
proc.process.kill()
|
||||
except futures.TimeoutError:
|
||||
# Server has not been started yet.
|
||||
pass
|
||||
|
||||
def set_result(self, proc: Optional[ProcessInfo]) -> None:
|
||||
"""Set the result of the internal future if it is currently unset."""
|
||||
if not self.process_handle_future.done():
|
||||
self.process_handle_future.set_result(proc)
|
||||
|
||||
|
||||
def _match_running_client_server(command: List[str]) -> bool:
|
||||
|
@ -136,27 +162,37 @@ class ProxyManager():
|
|||
self._session_dir = connection_tuple["session_dir"]
|
||||
return self._session_dir
|
||||
|
||||
def create_specific_server(self, client_id: str) -> SpecificServer:
|
||||
"""
|
||||
Create, but not start a SpecificServer for a given client. This
|
||||
method must be called once per client.
|
||||
"""
|
||||
with self.server_lock:
|
||||
assert self.servers.get(client_id) is None, (
|
||||
f"Server already created for Client: {client_id}")
|
||||
port = self._get_unused_port()
|
||||
server = SpecificServer(
|
||||
port=port,
|
||||
process_handle_future=futures.Future(),
|
||||
channel=grpc.insecure_channel(
|
||||
f"localhost:{port}", options=GRPC_OPTIONS))
|
||||
self.servers[client_id] = server
|
||||
return server
|
||||
|
||||
def start_specific_server(self, client_id: str,
|
||||
job_config: JobConfig) -> bool:
|
||||
"""
|
||||
Start up a RayClient Server for an incoming client to
|
||||
communicate with. Returns whether creation was successful.
|
||||
"""
|
||||
with self.server_lock:
|
||||
port = self._get_unused_port()
|
||||
handle_ready = futures.Future()
|
||||
specific_server = SpecificServer(
|
||||
port=port,
|
||||
process_handle_future=handle_ready,
|
||||
channel=grpc.insecure_channel(
|
||||
f"localhost:{port}", options=GRPC_OPTIONS))
|
||||
self.servers[client_id] = specific_server
|
||||
specific_server = self._get_server_for_client(client_id)
|
||||
assert specific_server, f"Server has not been created for: {client_id}"
|
||||
|
||||
serialized_runtime_env = job_config.get_serialized_runtime_env()
|
||||
|
||||
proc = start_ray_client_server(
|
||||
self._get_redis_address(),
|
||||
port,
|
||||
specific_server.port,
|
||||
fate_share=self.fate_share,
|
||||
server_type="specific-server",
|
||||
serialized_runtime_env=serialized_runtime_env,
|
||||
|
@ -181,9 +217,9 @@ class ProxyManager():
|
|||
logger.debug(
|
||||
"Waiting for Process to reach the actual client server.")
|
||||
time.sleep(0.5)
|
||||
handle_ready.set_result(proc)
|
||||
logger.info(f"SpecificServer started on port: {port} with PID: {pid} "
|
||||
f"for client: {client_id}")
|
||||
specific_server.set_result(proc)
|
||||
logger.info(f"SpecificServer started on port: {specific_server.port} "
|
||||
f"with PID: {pid} for client: {client_id}")
|
||||
return proc.process.poll() is None
|
||||
|
||||
def _get_server_for_client(self,
|
||||
|
@ -205,6 +241,7 @@ class ProxyManager():
|
|||
server = self._get_server_for_client(client_id)
|
||||
if server is None:
|
||||
return None
|
||||
# Wait for the SpecificServer to become ready.
|
||||
server.wait_ready()
|
||||
try:
|
||||
grpc.channel_ready_future(
|
||||
|
@ -221,9 +258,7 @@ class ProxyManager():
|
|||
while True:
|
||||
with self.server_lock:
|
||||
for client_id, specific_server in list(self.servers.items()):
|
||||
poll_result = specific_server.process_handle(
|
||||
).process.poll()
|
||||
if poll_result is not None:
|
||||
if specific_server.poll() is not None:
|
||||
del self.servers[client_id]
|
||||
# Port is available to use again.
|
||||
self._free_ports.append(specific_server.port)
|
||||
|
@ -236,12 +271,7 @@ class ProxyManager():
|
|||
for platforms where fate sharing is not supported.
|
||||
"""
|
||||
for server in self.servers.values():
|
||||
try:
|
||||
server.wait_ready(0.1)
|
||||
server.process_handle().process.kill()
|
||||
except TimeoutError:
|
||||
# Server has not been started yet.
|
||||
pass
|
||||
server.kill()
|
||||
|
||||
|
||||
class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
|
@ -367,11 +397,18 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
|||
if client_id == "":
|
||||
return
|
||||
|
||||
# Create Placeholder *before* reading the first request.
|
||||
server = self.proxy_manager.create_specific_server(client_id)
|
||||
try:
|
||||
with self.clients_lock:
|
||||
self.num_clients += 1
|
||||
|
||||
logger.info(f"New data connection from client {client_id}: ")
|
||||
modified_init_req, job_config = prepare_runtime_init_req(
|
||||
request_iterator)
|
||||
|
||||
if not self.proxy_manager.start_specific_server(client_id, job_config):
|
||||
if not self.proxy_manager.start_specific_server(
|
||||
client_id, job_config):
|
||||
logger.error(f"Server startup failed for client: {client_id}, "
|
||||
f"using JobConfig: {job_config}!")
|
||||
context.set_code(grpc.StatusCode.ABORTED)
|
||||
|
@ -383,15 +420,14 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
|||
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||
return None
|
||||
stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
|
||||
try:
|
||||
with self.clients_lock:
|
||||
self.num_clients += 1
|
||||
|
||||
new_iter = chain([modified_init_req], request_iterator)
|
||||
resp_stream = stub.Datapath(
|
||||
new_iter, metadata=[("client_id", client_id)])
|
||||
for resp in resp_stream:
|
||||
yield self.modify_connection_info_resp(resp)
|
||||
finally:
|
||||
server.set_result(None)
|
||||
with self.clients_lock:
|
||||
logger.debug(f"Client detached: {client_id}")
|
||||
self.num_clients -= 1
|
||||
|
@ -406,22 +442,22 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
|
|||
client_id = _get_client_id_from_context(context)
|
||||
if client_id == "":
|
||||
return
|
||||
logger.debug(f"New data connection from client {client_id}: ")
|
||||
logger.debug(f"New logstream connection from client {client_id}: ")
|
||||
|
||||
channel = None
|
||||
# We need to retry a few times because the LogClient *may* connect
|
||||
# Before the DataClient has finished connecting.
|
||||
for i in range(5):
|
||||
for i in range(LOGSTREAM_RETRIES):
|
||||
channel = self.proxy_manager.get_channel(client_id)
|
||||
|
||||
if channel is not None:
|
||||
break
|
||||
logger.warning(
|
||||
f"Retrying Logstream connection. {i+1} attempts failed.")
|
||||
time.sleep(2)
|
||||
time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC)
|
||||
|
||||
if channel is None:
|
||||
context.set_code(grpc.StatusCode.NOT_FOUND)
|
||||
context.set_code(grpc.StatusCode.UNAVAILABLE)
|
||||
return None
|
||||
|
||||
stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
|
||||
|
@ -432,11 +468,13 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
|
|||
yield resp
|
||||
|
||||
|
||||
def serve_proxier(connection_str: str, redis_address: str):
|
||||
def serve_proxier(connection_str: str,
|
||||
redis_address: str,
|
||||
session_dir: Optional[str] = None):
|
||||
server = grpc.server(
|
||||
futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS),
|
||||
options=GRPC_OPTIONS)
|
||||
proxy_manager = ProxyManager(redis_address)
|
||||
proxy_manager = ProxyManager(redis_address, session_dir)
|
||||
task_servicer = RayletServicerProxy(None, proxy_manager)
|
||||
data_servicer = DataServicerProxy(proxy_manager)
|
||||
logs_servicer = LogstreamServicerProxy(proxy_manager)
|
||||
|
|
|
@ -341,8 +341,8 @@ class Worker:
|
|||
self.reference_count[id] += 1
|
||||
|
||||
def close(self):
|
||||
self.log_client.close()
|
||||
self.data_client.close()
|
||||
self.log_client.close()
|
||||
if self.channel:
|
||||
self.channel.close()
|
||||
self.channel = None
|
||||
|
|
Loading…
Add table
Reference in a new issue