[debugger] Clean up breakpoint state for dead jobs (#17095)

This commit is contained in:
Edward Oakes 2021-07-15 22:20:09 -05:00 committed by GitHub
parent 2a53d22438
commit 90a1667b29
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 7 deletions

View file

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Set
import click
import copy
@ -166,7 +166,7 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port,
from None
def continue_debug_session():
def continue_debug_session(live_jobs: Set[str]):
"""Continue active debugging session.
This function will connect 'ray debug' to the right debugger
@ -177,19 +177,27 @@ def continue_debug_session():
for active_session in active_sessions:
if active_session.startswith(b"RAY_PDB_CONTINUE"):
# Check to see that the relevant job is still alive.
data = ray.experimental.internal_kv._internal_kv_get(
active_session)
if json.loads(data)["job_id"] not in live_jobs:
ray.experimental.internal_kv._internal_kv_del(active_session)
continue
print("Continuing pdb session in different process...")
key = b"RAY_PDB_" + active_session[len("RAY_PDB_CONTINUE_"):]
while True:
data = ray.experimental.internal_kv._internal_kv_get(key)
if data:
session = json.loads(data)
if "exit_debugger" in session:
if ("exit_debugger" in session
or session["job_id"] not in live_jobs):
ray.experimental.internal_kv._internal_kv_del(key)
return
host, port = session["pdb_address"].split(":")
ray.util.rpdb.connect_pdb_client(host, int(port))
ray.experimental.internal_kv._internal_kv_del(key)
continue_debug_session()
continue_debug_session(live_jobs)
return
time.sleep(1.0)
@ -217,7 +225,12 @@ def debug(address):
logger.info(f"Connecting to Ray instance at {address}.")
ray.init(address=address, log_to_driver=False)
while True:
continue_debug_session()
# Used to filter out and clean up entries from dead jobs.
live_jobs = {
job["JobID"]
for job in ray.state.jobs() if not job["IsDead"]
}
continue_debug_session(live_jobs)
active_sessions = ray.experimental.internal_kv._internal_kv_list(
"RAY_PDB_")
@ -226,7 +239,11 @@ def debug(address):
for active_session in active_sessions:
data = json.loads(
ray.experimental.internal_kv._internal_kv_get(active_session))
sessions_data.append(data)
# Check that the relevant job is alive, else clean up the entry.
if data["job_id"] in live_jobs:
sessions_data.append(data)
else:
ray.experimental.internal_kv._internal_kv_del(active_session)
sessions_data = sorted(
sessions_data, key=lambda data: data["timestamp"], reverse=True)
table = [["index", "timestamp", "Ray task", "filename:lineno"]]

View file

@ -6,7 +6,9 @@ from telnetlib import Telnet
import pexpect
import pytest
import ray
from ray.test_utils import run_string_as_driver, wait_for_condition
def test_ray_debugger_breakpoint(shutdown_only):
@ -134,6 +136,46 @@ def test_ray_debugger_recursive(shutdown_only):
ray.get(result)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="Failing on Windows.")
def test_job_exit_cleanup(ray_start_regular):
    """Check that `ray debug` removes breakpoint entries left by dead jobs.

    A short-lived driver triggers a breakpoint inside a remote task and then
    exits. The test waits until a "RAY_PDB_" entry shows up in the internal
    KV store, starts the `ray debug` CLI, and then waits for every
    "RAY_PDB_" entry to disappear.

    Args:
        ray_start_regular: Fixture result; only its "redis_address" entry is
            read, and it is passed to the driver as the cluster address.
    """
    address = ray_start_regular["redis_address"]

    # The script is executed as a standalone driver, so its top-level
    # statements must sit at column 0 and the body of `f` must be indented;
    # otherwise the driver fails to parse before reaching the breakpoint.
    driver_script = """
import time

import ray
ray.init(address="{}")

@ray.remote
def f():
    ray.util.rpdb.set_trace()

f.remote()
# Give the remote function long enough to actually run.
time.sleep(5)
""".format(address)

    def active_session_keys():
        # All debugger-related keys currently stored in the internal KV.
        return ray.experimental.internal_kv._internal_kv_list("RAY_PDB_")

    # Precondition: nothing is left over from earlier tests.
    assert not len(active_session_keys())

    run_string_as_driver(driver_script)

    def one_active_session():
        # Truthy as soon as at least one breakpoint entry is present.
        return len(active_session_keys())

    wait_for_condition(one_active_session)

    # Start the debugger. This should clean up any existing sessions that
    # belong to dead jobs.
    p = pexpect.spawn("ray debug")
    try:

        def no_active_sessions():
            return not len(active_session_keys())

        wait_for_condition(no_active_sessions)
    finally:
        # Always terminate the spawned CLI so it does not outlive the test,
        # even when the wait above fails.
        p.close()
if __name__ == "__main__":
import pytest
# Make subprocess happy in bazel.

View file

@ -165,8 +165,11 @@ class RemotePdb(Pdb):
# Tell the next task to drop into the debugger.
ray.worker.global_worker.debugger_breakpoint = self._breakpoint_uuid
# Tell the debug loop to connect to the next task.
data = json.dumps({
"job_id": ray.get_runtime_context().job_id.hex(),
})
_internal_kv_put("RAY_PDB_CONTINUE_{}".format(self._breakpoint_uuid),
"")
data)
self.__restore()
self.handle.connection.close()
return Pdb.do_continue(self, arg)
@ -214,6 +217,7 @@ def connect_ray_pdb(host=None,
"lineno": parentframeinfo.lineno,
"traceback": "\n".join(traceback.format_exception(*sys.exc_info())),
"timestamp": time.time(),
"job_id": ray.get_runtime_context().job_id.hex(),
}
_internal_kv_put(
"RAY_PDB_{}".format(breakpoint_uuid), json.dumps(data), overwrite=True)