[Client][Proxy] Prevent Logstream from Timing Out when Delays in DataClient (#16180)

This commit is contained in:
Ian Rodney 2021-06-03 11:59:52 -07:00 committed by GitHub
parent fa292a4edf
commit 22bd7cebeb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 50 deletions

View file

@ -32,6 +32,7 @@ def test_proxy_manager_lifecycle(shutdown_only):
pm._free_ports = [45000, 45001]
client = "client1"
pm.create_specific_server(client)
assert pm.start_specific_server(client, JobConfig())
# Channel should be ready and corresponding to an existing server
grpc.channel_ready_future(pm.get_channel(client)).result(timeout=5)
@ -39,7 +40,7 @@ def test_proxy_manager_lifecycle(shutdown_only):
proc = pm._get_server_for_client(client)
assert proc.port == 45000
proc.process_handle().process.wait(10)
proc.process_handle_future.result().process.wait(10)
# Wait for reconcile loop
time.sleep(2)
@ -63,6 +64,7 @@ def test_proxy_manager_bad_startup(shutdown_only):
pm._free_ports = [46000, 46001]
client = "client1"
pm.create_specific_server(client)
assert not pm.start_specific_server(
client,
JobConfig(
@ -122,6 +124,40 @@ def test_correct_num_clients(call_ray_start):
run_string_as_driver(check_we_are_second.format(num_clients=1))
check_connection = """
import ray
ray.client("localhost:25010").connect()
assert ray.util.client.ray.worker.log_client.log_thread.is_alive()
"""
@pytest.mark.skipif(
sys.platform == "win32",
reason="PSUtil does not work the same on windows.")
def test_delay_in_rewriting_environment(shutdown_only):
"""
Check that a delay in `ray_client_server_env_prep` does not break
a Client connecting.
"""
proxier.LOGSTREAM_RETRIES = 3
proxier.LOGSTREAM_RETRY_INTERVAL_SEC = 1
ray_instance = ray.init()
def delay_in_rewrite(input: JobConfig):
import time
time.sleep(6)
return input
proxier.ray_client_server_env_prep = delay_in_rewrite
server = proxier.serve_proxier("localhost:25010",
ray_instance["redis_address"],
ray_instance["session_dir"])
run_string_as_driver(check_connection)
server.stop(0)
def test_prepare_runtime_init_req_fails():
"""
Check that a connection that is initiated with a non-Init request

View file

@ -31,6 +31,9 @@ MAX_SPECIFIC_SERVER_PORT = 24000
CHECK_CHANNEL_TIMEOUT_S = 10
LOGSTREAM_RETRIES = 5
LOGSTREAM_RETRY_INTERVAL_SEC = 2
def _get_client_id_from_context(context: Any) -> str:
"""
@ -55,11 +58,34 @@ class SpecificServer:
"""
Wait for the server to actually start up.
"""
self.process_handle_future.result(timeout=timeout)
res = self.process_handle_future.result(timeout=timeout)
if res is None:
# This is only set to none when server creation specifically fails.
raise RuntimeError("Server startup failed.")
def poll(self) -> Optional[int]:
"""Check if the process has exited."""
try:
proc = self.process_handle_future.result(timeout=0.1)
if proc is not None:
return proc.process.poll()
except futures.TimeoutError:
return
def process_handle(self) -> ProcessInfo:
return self.process_handle_future.result()
def kill(self) -> None:
"""Try to send a KILL signal to the process."""
try:
proc = self.process_handle_future.result(timeout=0.1)
if proc is not None:
proc.process.kill()
except futures.TimeoutError:
# Server has not been started yet.
pass
def set_result(self, proc: Optional[ProcessInfo]) -> None:
"""Set the result of the internal future if it is currently unset."""
if not self.process_handle_future.done():
self.process_handle_future.set_result(proc)
def _match_running_client_server(command: List[str]) -> bool:
@ -136,27 +162,37 @@ class ProxyManager():
self._session_dir = connection_tuple["session_dir"]
return self._session_dir
def create_specific_server(self, client_id: str) -> SpecificServer:
"""
Create, but not start a SpecificServer for a given client. This
method must be called once per client.
"""
with self.server_lock:
assert self.servers.get(client_id) is None, (
f"Server already created for Client: {client_id}")
port = self._get_unused_port()
server = SpecificServer(
port=port,
process_handle_future=futures.Future(),
channel=grpc.insecure_channel(
f"localhost:{port}", options=GRPC_OPTIONS))
self.servers[client_id] = server
return server
def start_specific_server(self, client_id: str,
job_config: JobConfig) -> bool:
"""
Start up a RayClient Server for an incoming client to
communicate with. Returns whether creation was successful.
"""
with self.server_lock:
port = self._get_unused_port()
handle_ready = futures.Future()
specific_server = SpecificServer(
port=port,
process_handle_future=handle_ready,
channel=grpc.insecure_channel(
f"localhost:{port}", options=GRPC_OPTIONS))
self.servers[client_id] = specific_server
specific_server = self._get_server_for_client(client_id)
assert specific_server, f"Server has not been created for: {client_id}"
serialized_runtime_env = job_config.get_serialized_runtime_env()
proc = start_ray_client_server(
self._get_redis_address(),
port,
specific_server.port,
fate_share=self.fate_share,
server_type="specific-server",
serialized_runtime_env=serialized_runtime_env,
@ -181,9 +217,9 @@ class ProxyManager():
logger.debug(
"Waiting for Process to reach the actual client server.")
time.sleep(0.5)
handle_ready.set_result(proc)
logger.info(f"SpecificServer started on port: {port} with PID: {pid} "
f"for client: {client_id}")
specific_server.set_result(proc)
logger.info(f"SpecificServer started on port: {specific_server.port} "
f"with PID: {pid} for client: {client_id}")
return proc.process.poll() is None
def _get_server_for_client(self,
@ -205,6 +241,7 @@ class ProxyManager():
server = self._get_server_for_client(client_id)
if server is None:
return None
# Wait for the SpecificServer to become ready.
server.wait_ready()
try:
grpc.channel_ready_future(
@ -221,9 +258,7 @@ class ProxyManager():
while True:
with self.server_lock:
for client_id, specific_server in list(self.servers.items()):
poll_result = specific_server.process_handle(
).process.poll()
if poll_result is not None:
if specific_server.poll() is not None:
del self.servers[client_id]
# Port is available to use again.
self._free_ports.append(specific_server.port)
@ -236,12 +271,7 @@ class ProxyManager():
for platforms where fate sharing is not supported.
"""
for server in self.servers.values():
try:
server.wait_ready(0.1)
server.process_handle().process.kill()
except TimeoutError:
# Server has not been started yet.
pass
server.kill()
class RayletServicerProxy(ray_client_pb2_grpc.RayletDriverServicer):
@ -367,11 +397,18 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
if client_id == "":
return
# Create Placeholder *before* reading the first request.
server = self.proxy_manager.create_specific_server(client_id)
try:
with self.clients_lock:
self.num_clients += 1
logger.info(f"New data connection from client {client_id}: ")
modified_init_req, job_config = prepare_runtime_init_req(
request_iterator)
if not self.proxy_manager.start_specific_server(client_id, job_config):
if not self.proxy_manager.start_specific_server(
client_id, job_config):
logger.error(f"Server startup failed for client: {client_id}, "
f"using JobConfig: {job_config}!")
context.set_code(grpc.StatusCode.ABORTED)
@ -383,15 +420,14 @@ class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
context.set_code(grpc.StatusCode.NOT_FOUND)
return None
stub = ray_client_pb2_grpc.RayletDataStreamerStub(channel)
try:
with self.clients_lock:
self.num_clients += 1
new_iter = chain([modified_init_req], request_iterator)
resp_stream = stub.Datapath(
new_iter, metadata=[("client_id", client_id)])
for resp in resp_stream:
yield self.modify_connection_info_resp(resp)
finally:
server.set_result(None)
with self.clients_lock:
logger.debug(f"Client detached: {client_id}")
self.num_clients -= 1
@ -406,22 +442,22 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
client_id = _get_client_id_from_context(context)
if client_id == "":
return
logger.debug(f"New data connection from client {client_id}: ")
logger.debug(f"New logstream connection from client {client_id}: ")
channel = None
# We need to retry a few times because the LogClient *may* connect
# Before the DataClient has finished connecting.
for i in range(5):
for i in range(LOGSTREAM_RETRIES):
channel = self.proxy_manager.get_channel(client_id)
if channel is not None:
break
logger.warning(
f"Retrying Logstream connection. {i+1} attempts failed.")
time.sleep(2)
time.sleep(LOGSTREAM_RETRY_INTERVAL_SEC)
if channel is None:
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_code(grpc.StatusCode.UNAVAILABLE)
return None
stub = ray_client_pb2_grpc.RayletLogStreamerStub(channel)
@ -432,11 +468,13 @@ class LogstreamServicerProxy(ray_client_pb2_grpc.RayletLogStreamerServicer):
yield resp
def serve_proxier(connection_str: str, redis_address: str):
def serve_proxier(connection_str: str,
redis_address: str,
session_dir: Optional[str] = None):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=CLIENT_SERVER_MAX_THREADS),
options=GRPC_OPTIONS)
proxy_manager = ProxyManager(redis_address)
proxy_manager = ProxyManager(redis_address, session_dir)
task_servicer = RayletServicerProxy(None, proxy_manager)
data_servicer = DataServicerProxy(proxy_manager)
logs_servicer = LogstreamServicerProxy(proxy_manager)

View file

@ -341,8 +341,8 @@ class Worker:
self.reference_count[id] += 1
def close(self):
self.log_client.close()
self.data_client.close()
self.log_client.close()
if self.channel:
self.channel.close()
self.channel = None