mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[Client][Proxy] Prevent Logstream from Timing Out when Delays in DataClient (#16180)
This commit is contained in:
parent fa292a4edf
commit 22bd7cebeb
3 changed files with 124 additions and 50 deletions
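
In short: the Logstream proxy previously retried its lookup of the client's channel a hard-coded 5 times, 2 seconds apart, which was not enough when something on the DataClient path (for example `ray_client_server_env_prep`) delayed startup of the specific server. With this change the retry policy lives in the module-level LOGSTREAM_RETRIES and LOGSTREAM_RETRY_INTERVAL_SEC constants, the DataServicer creates a SpecificServer placeholder before reading the first request so the log stream has something to wait on, and a still-missing channel is reported as UNAVAILABLE instead of NOT_FOUND. Below is a minimal sketch of the retry pattern, not the proxier code itself; `get_channel` here is a hypothetical lookup that returns None until the server is reachable.

import time
from typing import Callable, Optional

# Tunable module-level knobs, mirroring the constants added in this commit.
LOGSTREAM_RETRIES = 5
LOGSTREAM_RETRY_INTERVAL_SEC = 2


def wait_for_channel(get_channel: Callable[[str], Optional[object]],
                     client_id: str) -> Optional[object]:
    """Poll `get_channel` until it returns a channel or retries run out."""
    channel = None
    for i in range(LOGSTREAM_RETRIES):
        channel = get_channel(client_id)
        if channel is not None:
            break
        # The LogClient may connect before the DataClient has finished,
        # so back off and try again instead of failing immediately.
        print(f"Retrying Logstream connection. {i + 1} attempts failed.")
        time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC)
    return channel


# Example: a fake lookup that is ready immediately returns without sleeping.
assert wait_for_channel(lambda _id: "channel", "client1") == "channel"

The new test_delay_in_rewriting_environment test exercises exactly this path by sleeping 6 seconds inside ray_client_server_env_prep while lowering the retry settings to 3 attempts at 1-second intervals.
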
@@ -32,6 +32,7 @@ def test_proxy_manager_lifecycle(shutdown_only):
     pm._free_ports = [45000, 45001]
     client = "client1"
 
+    pm.create_specific_server(client)
     assert pm.start_specific_server(client, JobConfig())
     # Channel should be ready and corresponding to an existing server
     grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5)

@@ -39,7 +40,7 @@ def test_proxy_manager_lifecycle(shutdown_only):
     proc = pm._get_server_for_client(client)
     assert proc.port == 45000
 
-    proc.process_handle().process.wait(10)
+    proc.process_handle_future.result().process.wait(10)
     # Wait for reconcile loop
     time.sleep(2)
 

@@ -63,6 +64,7 @@ def test_proxy_manager_bad_startup(shutdown_only):
     pm._free_ports = [46000, 46001]
     client = "client1"
 
+    pm.create_specific_server(client)
     assert not pm.start_specific_server(
         client,
         JobConfig(

@@ -122,6 +124,40 @@ def test_correct_num_clients(call_ray_start):
     run_string_as_driver(check_we_are_second.format(num_clients=1))
 
 
+check_connection = """
+import ray
+ray.client("localhost:25010").connect()
+assert ray.util.client.ray.worker.log_client.log_thread.is_alive()
+"""
+
+
+@pytest.mark.skipif(
+    sys.platform == "win32",
+    reason="PSUtil does not work the same on windows.")
+def test_delay_in_rewriting_environment(shutdown_only):
+    """
+    Check that a delay in `ray_client_server_env_prep` does not break
+    a Client connecting.
+    """
+    proxier.LOGSTREAM_RETRIES = 3
+    proxier.LOGSTREAM_RETRY_INTERVAL_SEC = 1
+    ray_instance = ray.init()
+
+    def delay_in_rewrite(input: JobConfig):
+        import time
+        time.sleep(6)
+        return input
+
+    proxier.ray_client_server_env_prep = delay_in_rewrite
+
+    server = proxier.serve_proxier("localhost:25010",
+                                   ray_instance["redis_address"],
+                                   ray_instance["session_dir"])
+
+    run_string_as_driver(check_connection)
+    server.stop(0)
+
+
 def test_prepare_runtime_init_req_fails():
     """
     Check that a connection that is initiated with a non-Init request

@@ -31,6 +31,9 @@ MAX_SPECIFIC_SERVER_PORT = 24000
 
 CHECK_CHANNEL_TIMEOUT_S = 10
 
+LOGSTREAM_RETRIES = 5
+LOGSTREAM_RETRY_INTERVAL_SEC = 2
+
 
 def _get_client_id_from_context(context: Any) -> str:
     """

@@ -55,11 +58,34 @@ class SpecificServer:
         """
         Wait for the server to actually start up.
         """
-        self.process_handle_future.result(timeout=timeout)
-        return
+        res = self.process_handle_future.result(timeout=timeout)
+        if res is None:
+            # This is only set to none when server creation specifically fails.
+            raise RuntimeError("Server startup failed.")
 
-    def process_handle(self) -> ProcessInfo:
-        return self.process_handle_future.result()
+    def poll(self) -> Optional[int]:
+        """Check if the process has exited."""
+        try:
+            proc = self.process_handle_future.result(timeout=0.1)
+            if proc is not None:
+                return proc.process.poll()
+        except futures.TimeoutError:
+            return
+
+    def kill(self) -> None:
+        """Try to send a KILL signal to the process."""
+        try:
+            proc = self.process_handle_future.result(timeout=0.1)
+            if proc is not None:
+                proc.process.kill()
+        except futures.TimeoutError:
+            # Server has not been started yet.
+            pass
+
+    def set_result(self, proc: Optional[ProcessInfo]) -> None:
+        """Set the result of the internal future if it is currently unset."""
+        if not self.process_handle_future.done():
+            self.process_handle_future.set_result(proc)
 
 
 def _match_running_client_server(command: List[str]) -> bool:

@@ -136,27 +162,37 @@ class ProxyManager():
         self._session_dir = connection_tuple["session_dir"]
         return self._session_dir
 
+    def create_specific_server(self, client_id: str) -> SpecificServer:
+        """
+        Create, but not start a SpecificServer for a given client. This
+        method must be called once per client.
+        """
+        with self.server_lock:
+            assert self.servers.get(client_id) is None, (
+                f"Server already created for Client: {client_id}")
+            port = self._get_unused_port()
+            server = SpecificServer(
+                port=port,
+                process_handle_future=futures.Future(),
+                channel=grpc.insecure_channel(
+                    f"localhost:{port}", options=GRPC_OPTIONS))
+            self.servers[client_id] = server
+            return server
+
     def start_specific_server(self, client_id: str,
                               job_config: JobConfig) -> bool:
         """
         Start up a RayClient Server for an incoming client to
         communicate with. Returns whether creation was successful.
         """
-        with self.server_lock:
-            port = self._get_unused_port()
-            handle_ready = futures.Future()
-            specific_server = SpecificServer(
-                port=port,
-                process_handle_future=handle_ready,
-                channel=grpc.insecure_channel(
-                    f"localhost:{port}", options=GRPC_OPTIONS))
-            self.servers[client_id] = specific_server
+        specific_server = self._get_server_for_client(client_id)
+        assert specific_server, f"Server has not been created for: {client_id}"
 
         serialized_runtime_env = job_config.get_serialized_runtime_env()
 
         proc = start_ray_client_server(
             self._get_redis_address(),
-            port,
+            specific_server.port,
             fate_share=self.fate_share,
             server_type="specific-server",
             serialized_runtime_env=serialized_runtime_env,

@@ -181,9 +217,9 @@ class ProxyManager():
             logger.debug(
                 "Waiting for Process to reach the actual client server.")
             time.sleep(0.5)
-        handle_ready.set_result(proc)
-        logger.info(f"SpecificServer started on port: {port} with PID: {pid} "
-                    f"for client: {client_id}")
+        specific_server.set_result(proc)
+        logger.info(f"SpecificServer started on port: {specific_server.port} "
+                    f"with PID: {pid} for client: {client_id}")
         return proc.process.poll() is None
 
     def _get_server_for_client(self,

@@ -205,6 +241,7 @@ class ProxyManager():
         server = self._get_server_for_client(client_id)
         if server is None:
             return None
+        # Wait for the SpecificServer to become ready.
         server.wait_ready()
         try:
             grpc.channel_ready_future(

@@ -221,9 +258,7 @@ class ProxyManager():
         while True:
             with self.server_lock:
                 for client_id, specific_server in list(self.servers.items()):
-                    poll_result = specific_server.process_handle(
-                    ).process.poll()
-                    if poll_result is not None:
+                    if specific_server.poll() is not None:
                         del self.servers[client_id]
                         # Port is available to use again.
                         self._free_ports.append(specific_server.port)

@@ -236,12 +271,7 @@ class ProxyManager():
         for platforms where fate sharing is not supported.
         """
         for server in self.servers.values():
-            try:
-                server.wait_ready(0.1)
-                server.process_handle().process.kill()
-            except TimeoutError:
-                # Server has not been started yet.
-                pass
+            server.kill()
 
 
 class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):

@@ -367,31 +397,37 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
         if client_id == "":
             return
 
-        logger.info(f"New data connection from client {client_id}: ")
-        modified_init_req, job_config = prepare_runtime_init_req(
-            request_iterator)
-        if not self.proxy_manager.start_specific_server(client_id, job_config):
-            logger.error(f"Server startup failed for client: {client_id}, "
-                         f"using JobConfig: {job_config}!")
-            context.set_code(grpc.StatusCode.ABORTED)
-            return None
-
-        channel = self.proxy_manager.get_channel(client_id)
-        if channel is None:
-            logger.error(f"Channel not found for {client_id}")
-            context.set_code(grpc.StatusCode.NOT_FOUND)
-            return None
-        stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
+        # Create Placeholder *before* reading the first request.
+        server = self.proxy_manager.create_specific_server(client_id)
         try:
             with self.clients_lock:
                 self.num_clients += 1
 
+            logger.info(f"New data connection from client {client_id}: ")
+            modified_init_req, job_config = prepare_runtime_init_req(
+                request_iterator)
+
+            if not self.proxy_manager.start_specific_server(
+                    client_id, job_config):
+                logger.error(f"Server startup failed for client: {client_id}, "
+                             f"using JobConfig: {job_config}!")
+                context.set_code(grpc.StatusCode.ABORTED)
+                return None
+
+            channel = self.proxy_manager.get_channel(client_id)
+            if channel is None:
+                logger.error(f"Channel not found for {client_id}")
+                context.set_code(grpc.StatusCode.NOT_FOUND)
+                return None
+            stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
+
             new_iter = chain([modified_init_req], request_iterator)
             resp_stream = stub.Datapath(
                 new_iter, metadata=[("client_id", client_id)])
             for resp in resp_stream:
                 yield self.modify_connection_info_resp(resp)
         finally:
+            server.set_result(None)
             with self.clients_lock:
                 logger.debug(f"Client detached: {client_id}")
                 self.num_clients -= 1

@@ -406,22 +442,22 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
         client_id = _get_client_id_from_context(context)
         if client_id == "":
             return
-        logger.debug(f"New data connection from client {client_id}: ")
+        logger.debug(f"New logstream connection from client {client_id}: ")
 
         channel = None
         # We need to retry a few times because the LogClient *may* connect
         # Before the DataClient has finished connecting.
-        for i in range(5):
+        for i in range(LOGSTREAM_RETRIES):
             channel = self.proxy_manager.get_channel(client_id)
 
             if channel is not None:
                 break
             logger.warning(
                 f"Retrying Logstream connection. {i+1} attempts failed.")
-            time.sleep(2)
+            time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC)
 
         if channel is None:
-            context.set_code(grpc.StatusCode.NOT_FOUND)
+            context.set_code(grpc.StatusCode.UNAVAILABLE)
             return None
 
         stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)

@@ -432,11 +468,13 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
             yield resp
 
 
-def serve_proxier(connection_str: str, redis_address: str):
+def serve_proxier(connection_str: str,
+                  redis_address: str,
+                  session_dir: Optional[str] = None):
     server = grpc.server(
         futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS),
         options=GRPC_OPTIONS)
-    proxy_manager = ProxyManager(redis_address)
+    proxy_manager = ProxyManager(redis_address, session_dir)
     task_servicer = RayletServicerProxy(None, proxy_manager)
     data_servicer = DataServicerProxy(proxy_manager)
     logs_servicer = LogstreamServicerProxy(proxy_manager)

@@ -341,8 +341,8 @@ class Worker:
             self.reference_count[id] += 1
 
     def close(self):
-        self.log_client.close()
         self.data_client.close()
+        self.log_client.close()
         if self.channel:
             self.channel.close()
             self.channel = None

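For reference, the placeholder handshake that create_specific_server, set_result and wait_ready implement above reduces to a single concurrent.futures.Future shared between the data path (producer) and whoever waits for the server (consumer), with None meaning startup failed. A self-contained sketch of that pattern under those assumptions (the Placeholder class here is hypothetical, not the SpecificServer implementation):

from concurrent import futures


class Placeholder:
    """Future-backed handshake: producer fills it once, consumers block on it."""

    def __init__(self):
        self._future = futures.Future()

    def set_result(self, value):
        # Only the first writer wins; a later call (e.g. a cleanup path
        # setting None) is ignored if a real handle was already recorded.
        if not self._future.done():
            self._future.set_result(value)

    def wait_ready(self, timeout=None):
        res = self._future.result(timeout=timeout)
        if res is None:
            # None signals that startup failed before a handle existed.
            raise RuntimeError("Server startup failed.")
        return res


placeholder = Placeholder()
placeholder.set_result("process-handle")   # producer: startup succeeded
print(placeholder.wait_ready(timeout=1))   # consumer: returns immediately

Because set_result only fills an unset future, the `finally: server.set_result(None)` in Datapath cannot clobber a handle that start_specific_server already recorded, while still unblocking anything waiting in wait_ready if the connection dies before startup completes.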