[jobs] Monitor jobs in the background to avoid requiring clients to poll (#22180)

Edward Oakes 2022-02-07 15:25:25 -06:00 committed by GitHub
parent 8e1e783596
commit 8806b2d5c4
4 changed files with 251 additions and 143 deletions

dashboard/modules/job/common.py

@ -7,6 +7,7 @@ from ray import ray_constants
from ray.experimental.internal_kv import (
_internal_kv_initialized,
_internal_kv_get,
_internal_kv_list,
_internal_kv_put,
)
from ray._private.runtime_env.packaging import parse_uri
@ -21,15 +22,18 @@ CURRENT_VERSION = "1"
class JobStatus(str, Enum):
def __str__(self):
return f"{self.value}"
PENDING = "PENDING"
RUNNING = "RUNNING"
STOPPED = "STOPPED"
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
def __str__(self):
return f"{self.value}"
def is_terminal(self):
return self.value in {"STOPPED", "SUCCEEDED", "FAILED"}
@dataclass
class JobStatusInfo:
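A quick illustration of the new helper for readers skimming the diff. This is a minimal sketch, not part of the commit, assuming JobStatus is importable from ray.dashboard.modules.job.common as the tests in this diff import from that module:

from ray.dashboard.modules.job.common import JobStatus

# str() still yields the bare value; is_terminal() marks the end states so
# callers don't have to re-enumerate STOPPED/SUCCEEDED/FAILED themselves.
assert str(JobStatus.RUNNING) == "RUNNING"
assert not JobStatus.RUNNING.is_terminal()
assert JobStatus.SUCCEEDED.is_terminal()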
@ -58,7 +62,8 @@ class JobStatusStorageClient:
Handles formatting of status storage key given job id.
"""
JOB_STATUS_KEY = "_ray_internal_job_status_{job_id}"
JOB_STATUS_KEY_PREFIX = "_ray_internal_job_status"
JOB_STATUS_KEY = f"{JOB_STATUS_KEY_PREFIX}_{{job_id}}"
def __init__(self):
assert _internal_kv_initialized()
@ -85,6 +90,11 @@ class JobStatusStorageClient:
else:
return pickle.loads(pickled_status)
def get_all_jobs(self) -> Dict[str, JobStatusInfo]:
raw_job_ids = _internal_kv_list(self.JOB_STATUS_KEY_PREFIX)
job_ids = [job_id.decode() for job_id in raw_job_ids]
return {job_id: self.get_status(job_id) for job_id in job_ids}
def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
if not package_uri.endswith(".zip"):
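Taken together, get_all_jobs() and is_terminal() let callers enumerate the persisted job statuses and pick out only the jobs that still need monitoring. A minimal sketch (not part of the diff), assuming JobStatusInfo and JobStatusStorageClient are importable from ray.dashboard.modules.job.common and that Ray's internal KV is already initialized in the calling process, as it is where JobManager runs:

from typing import Dict

from ray.dashboard.modules.job.common import (
    JobStatusInfo,
    JobStatusStorageClient,
)


def list_active_jobs() -> Dict[str, JobStatusInfo]:
    # Only jobs that have not yet reached STOPPED/SUCCEEDED/FAILED.
    client = JobStatusStorageClient()
    return {
        job_id: info
        for job_id, info in client.get_all_jobs().items()
        if info is not None and not info.status.is_terminal()
    }

This mirrors what JobManager._recover_running_jobs does in job_manager.py below.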

dashboard/modules/job/job_manager.py

@ -115,11 +115,8 @@ class JobSupervisor:
# fire and forget call from outer job manager to this actor
self._stop_event = asyncio.Event()
def ready(self):
"""Dummy object ref. Return of this function represents job supervisor
actor stated successfully with runtime_env configured, and is ready to
move on to running state.
"""
def ping(self):
"""Used to check the health of the actor."""
pass
def _exec_entrypoint(self, logs_path: str) -> subprocess.Popen:
@ -193,8 +190,10 @@ class JobSupervisor:
variables.
3) Handle concurrent events of driver execution and
"""
cur_status = self._get_status()
assert cur_status.status == JobStatus.PENDING, "Run should only be called once."
curr_status = self._status_client.get_status(self._job_id)
assert (
curr_status.status == JobStatus.PENDING
), "Run should only be called once."
if _start_signal_actor:
# Block in PENDING state until start signal received.
@ -269,9 +268,6 @@ class JobSupervisor:
# clean up actor after tasks are finished
ray.actor.exit_actor()
def _get_status(self) -> Optional[JobStatusInfo]:
return self._status_client.get_status(self._job_id)
def stop(self):
"""Set step_event and let run() handle the rest in its asyncio.wait()."""
self._stop_event.set()
@ -288,18 +284,92 @@ class JobManager:
# Time that we will sleep while tailing logs if no new log line is
# available.
LOG_TAIL_SLEEP_S = 1
JOB_MONITOR_LOOP_PERIOD_S = 1
def __init__(self):
self._status_client = JobStatusStorageClient()
self._log_client = JobLogStorageClient()
self._supervisor_actor_cls = ray.remote(JobSupervisor)
self._recover_running_jobs()
def _recover_running_jobs(self):
"""Recovers all running jobs from the status client.
For each job, we will spawn a coroutine to monitor it.
Each will be added to self._running_jobs and reconciled.
"""
all_jobs = self._status_client.get_all_jobs()
for job_id, status_info in all_jobs.items():
if not status_info.status.is_terminal():
create_task(self._monitor_job(job_id))
def _get_actor_for_job(self, job_id: str) -> Optional[ActorHandle]:
try:
return ray.get_actor(self.JOB_ACTOR_NAME.format(job_id=job_id))
except ValueError: # Ray returns ValueError for nonexistent actor.
return None
async def _monitor_job(
self, job_id: str, job_supervisor: Optional[ActorHandle] = None
):
"""Monitors the specified job until it enters a terminal state.
This is necessary because we need to handle the case where the
JobSupervisor dies unexpectedly.
"""
is_alive = True
if job_supervisor is None:
job_supervisor = self._get_actor_for_job(job_id)
if job_supervisor is None:
logger.error(f"Failed to get job supervisor for job {job_id}.")
self._status_client.put_status(
job_id,
JobStatusInfo(
status=JobStatus.FAILED,
message=(
"Unexpected error occurred: Failed to get job supervisor."
),
),
)
is_alive = False
while is_alive:
try:
await job_supervisor.ping.remote()
await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S)
except Exception as e:
is_alive = False
if self._status_client.get_status(job_id).status.is_terminal():
# If the job is already in a terminal state, then the actor
# exiting is expected.
pass
elif isinstance(e, RuntimeEnvSetupError):
logger.info(f"Failed to set up runtime_env for job {job_id}.")
self._status_client.put_status(
job_id,
JobStatusInfo(
status=JobStatus.FAILED,
message=(f"runtime_env setup failed: {e}"),
),
)
else:
logger.warning(
f"Job supervisor for job {job_id} failed unexpectedly: {e}."
)
self._status_client.put_status(
job_id,
JobStatusInfo(
status=JobStatus.FAILED,
message=f"Unexpected error occurred: {e}",
),
)
# Kill the actor defensively to avoid leaking actors in unexpected error cases.
if job_supervisor is not None:
ray.kill(job_supervisor, no_restart=True)
def _get_current_node_resource_key(self) -> str:
"""Get the Ray resource key for current node.
@ -326,26 +396,6 @@ class JobManager:
"""
if result is None:
return
elif isinstance(result, RuntimeEnvSetupError):
logger.info(f"Failed to set up runtime_env for job {job_id}.")
self._status_client.put_status(
job_id,
JobStatusInfo(
status=JobStatus.FAILED,
message=(f"runtime_env setup failed: {result}"),
),
)
elif isinstance(result, Exception):
logger.error(f"Failed to start supervisor for job {job_id}: {result}.")
self._status_client.put_status(
job_id,
JobStatusInfo(
status=JobStatus.FAILED,
message=f"Error occurred while starting the job: {result}",
),
)
else:
assert False, "This should not be reached."
def submit_job(
self,
@ -394,9 +444,9 @@ class JobManager:
# Wait for the actor to start up asynchronously so this call always
# returns immediately and we can catch errors with the actor starting
# up. We may want to put this in an actor instead in the future.
# up.
try:
actor = self._supervisor_actor_cls.options(
supervisor = self._supervisor_actor_cls.options(
lifetime="detached",
name=self.JOB_ACTOR_NAME.format(job_id=job_id),
num_cpus=0,
@ -407,26 +457,26 @@ class JobManager:
},
runtime_env=runtime_env,
).remote(job_id, entrypoint, metadata or {})
actor.run.remote(_start_signal_actor=_start_signal_actor)
supervisor.run.remote(_start_signal_actor=_start_signal_actor)
def callback(result: Optional[Exception]):
return self._handle_supervisor_startup(job_id, result)
actor.ready.remote()._on_completed(callback)
# Monitor the job in the background so we can detect errors without
# requiring a client to poll.
create_task(self._monitor_job(job_id, job_supervisor=supervisor))
except Exception as e:
self._handle_supervisor_startup(job_id, e)
self._status_client.put_status(
job_id,
JobStatusInfo(
status=JobStatus.FAILED,
message=f"Failed to start job supervisor: {e}.",
),
)
return job_id
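Note the pattern in submit_job above: the method stays synchronous, but it schedules _monitor_job as a fire-and-forget task on the already-running event loop, so the call returns immediately and failures are reconciled server-side instead of requiring the client to poll. The create_task helper used there isn't shown being imported in these hunks; the sketch below uses asyncio.create_task from the standard library to illustrate the same idea, with hypothetical start_work/watch names standing in for submit_job/_monitor_job:

import asyncio


async def watch(job_id):
    # Stand-in for JobManager._monitor_job: runs until the job finishes.
    await asyncio.sleep(2)
    print(f"{job_id} reached a terminal state")


def start_work(job_id):
    # Stand-in for submit_job: synchronous, but must be called while an event
    # loop is running (as the dashboard's async HTTP handlers do), otherwise
    # asyncio.create_task() raises RuntimeError.
    asyncio.create_task(watch(job_id))
    return job_id  # Returns immediately; monitoring continues in the background.


async def handler():
    print(start_work("raysubmit_example"))
    await asyncio.sleep(3)  # Keep the loop alive so the monitor can finish.


asyncio.run(handler())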
def stop_job(self, job_id) -> bool:
"""Request job to exit, fire and forget.
"""Request a job to exit, fire and forget.
Args:
job_id: ID of the job.
Returns:
stopped:
True if there's running job
False if no running job found
Returns whether or not the job was running.
"""
job_supervisor_actor = self._get_actor_for_job(job_id)
if job_supervisor_actor is not None:
@ -438,35 +488,15 @@ class JobManager:
return False
def get_job_status(self, job_id: str) -> Optional[JobStatus]:
"""Get latest status of a job. If job supervisor actor is no longer
alive, it will also attempt to make adjustments needed to bring job
to the correct termination state.
All job status is stored and read only from GCS.
Args:
job_id: ID of the job.
Returns:
job_status: Latest known job status
"""
job_supervisor_actor = self._get_actor_for_job(job_id)
if job_supervisor_actor is None:
# Job actor either exited or failed, we need to ensure never
# left job in non-terminal status in case actor failed without
# updating GCS with latest status.
last_status = self._status_client.get_status(job_id)
if last_status and last_status.status in {
JobStatus.PENDING,
JobStatus.RUNNING,
}:
self._status_client.put_status(job_id, JobStatus.FAILED)
"""Get latest status of a job."""
return self._status_client.get_status(job_id)
def get_job_logs(self, job_id: str) -> str:
"""Get all logs produced by a job."""
return self._log_client.get_logs(job_id)
async def tail_job_logs(self, job_id: str) -> Iterator[str]:
"""Return an iterator following the logs of a job."""
if self.get_job_status(job_id) is None:
raise RuntimeError(f"Job '{job_id}' does not exist.")

dashboard/modules/job/tests/test_job_manager.py

@ -1,3 +1,4 @@
import asyncio
import os
import psutil
import tempfile
@ -14,7 +15,7 @@ from ray.dashboard.modules.job.common import (
JOB_NAME_METADATA_KEY,
)
from ray.dashboard.modules.job.job_manager import generate_job_id, JobManager
from ray._private.test_utils import SignalActor, wait_for_condition
from ray._private.test_utils import SignalActor, async_wait_for_condition
TEST_NAMESPACE = "jobs_test_namespace"
@ -40,7 +41,7 @@ def _driver_script_path(file_name: str) -> str:
)
def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
tmp_file = os.path.join(tmp_dir, "hello")
pid_file = os.path.join(tmp_dir, "pid")
@ -61,10 +62,14 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
assert status.status == JobStatus.PENDING
logs = job_manager.get_job_logs(job_id)
assert logs == ""
await asyncio.sleep(0.01)
else:
wait_for_condition(check_job_running, job_manager=job_manager, job_id=job_id)
wait_for_condition(lambda: "Waiting..." in job_manager.get_job_logs(job_id))
await async_wait_for_condition(
check_job_running, job_manager=job_manager, job_id=job_id
)
await async_wait_for_condition(
lambda: "Waiting..." in job_manager.get_job_logs(job_id)
)
return pid_file, tmp_file, job_id
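The test changes below are largely mechanical: JobManager now schedules monitor coroutines on the running event loop, so a test that blocks inside the synchronous wait_for_condition would keep those coroutines from running. The helper and tests therefore switch to pytest-asyncio plus the new async_wait_for_condition, and each converted test follows roughly this shape (a sketch; check_done is a hypothetical predicate standing in for check_job_succeeded and friends):

import pytest

from ray._private.test_utils import async_wait_for_condition


def check_done(*, job_manager, job_id):
    # Hypothetical predicate; exceptions raised here are swallowed and retried.
    info = job_manager.get_job_status(job_id)
    return info is not None and info.status.is_terminal()


@pytest.mark.asyncio
async def test_echo(job_manager):  # Reuses this file's job_manager fixture.
    job_id = job_manager.submit_job(entrypoint="echo hello")
    # Awaiting yields the loop to the background _monitor_job task between retries.
    await async_wait_for_condition(check_done, job_manager=job_manager, job_id=job_id)
    assert job_manager.get_job_logs(job_id) == "hello\n"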
@ -112,40 +117,50 @@ def test_generate_job_id():
assert len(ids) == 10000
def test_pass_job_id(job_manager):
@pytest.mark.asyncio
async def test_pass_job_id(job_manager):
job_id = "my_custom_id"
returned_id = job_manager.submit_job(entrypoint="echo hello", job_id=job_id)
assert returned_id == job_id
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
# Check that the same job_id is rejected.
with pytest.raises(RuntimeError):
job_manager.submit_job(entrypoint="echo hello", job_id=job_id)
@pytest.mark.asyncio
class TestShellScriptExecution:
def test_submit_basic_echo(self, job_manager):
async def test_submit_basic_echo(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo hello")
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert job_manager.get_job_logs(job_id) == "hello\n"
def test_submit_stderr(self, job_manager):
async def test_submit_stderr(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo error 1>&2")
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert job_manager.get_job_logs(job_id) == "error\n"
def test_submit_ls_grep(self, job_manager):
async def test_submit_ls_grep(self, job_manager):
grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
job_id = job_manager.submit_job(entrypoint=grep_cmd)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"
def test_subprocess_exception(self, job_manager):
async def test_subprocess_exception(self, job_manager):
"""
Run a python script with exception, ensure:
1) Job status is marked as failed
@ -165,26 +180,25 @@ class TestShellScriptExecution:
return job_manager._get_actor_for_job(job_id) is None
wait_for_condition(cleaned_up)
await async_wait_for_condition(cleaned_up)
def test_submit_with_s3_runtime_env(self, job_manager):
async def test_submit_with_s3_runtime_env(self, job_manager):
job_id = job_manager.submit_job(
entrypoint="python script.py",
runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"},
)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert (
job_manager.get_job_logs(job_id) == "Executing main() from script.py !!\n"
)
@pytest.mark.asyncio
class TestRuntimeEnv:
def test_inheritance(self, job_manager):
# Test that the driver and actors/tasks inherit the right runtime_env.
pass
def test_pass_env_var(self, job_manager):
async def test_pass_env_var(self, job_manager):
"""Test we can pass env vars in the subprocess that executes job's
driver script.
"""
@ -193,10 +207,12 @@ class TestRuntimeEnv:
runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}},
)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert job_manager.get_job_logs(job_id) == "233\n"
def test_multiple_runtime_envs(self, job_manager):
async def test_multiple_runtime_envs(self, job_manager):
# Test that you can run two jobs in different envs without conflict.
job_id_1 = job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
@ -205,7 +221,7 @@ class TestRuntimeEnv:
},
)
wait_for_condition(
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id_1
)
logs = job_manager.get_job_logs(job_id_1)
@ -220,7 +236,7 @@ class TestRuntimeEnv:
},
)
wait_for_condition(
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id_2
)
logs = job_manager.get_job_logs(job_id_2)
@ -228,7 +244,7 @@ class TestRuntimeEnv:
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs
) # noqa: E501
def test_env_var_and_driver_job_config_warning(self, job_manager):
async def test_env_var_and_driver_job_config_warning(self, job_manager):
"""Ensure we got error message from worker.py and job logs
if user provided runtime_env in both driver script and submit()
"""
@ -239,14 +255,16 @@ class TestRuntimeEnv:
},
)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
logs = job_manager.get_job_logs(job_id)
assert logs.startswith(
"Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) " "are provided"
)
assert "JOB_1_VAR" in logs
def test_failed_runtime_env_validation(self, job_manager):
async def test_failed_runtime_env_validation(self, job_manager):
"""Ensure job status is correctly set as failed if job has an invalid
runtime_env.
"""
@ -259,7 +277,7 @@ class TestRuntimeEnv:
assert status.status == JobStatus.FAILED
assert "path_not_exist is not a valid URI" in status.message
def test_failed_runtime_env_setup(self, job_manager):
async def test_failed_runtime_env_setup(self, job_manager):
"""Ensure job status is correctly set as failed if job has a valid
runtime_env that fails to be set up.
"""
@ -268,12 +286,14 @@ class TestRuntimeEnv:
entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"}
)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id
)
status = job_manager.get_job_status(job_id)
assert "runtime_env setup failed" in status.message
def test_pass_metadata(self, job_manager):
async def test_pass_metadata(self, job_manager):
def dict_to_str(d):
return str(dict(sorted(d.items())))
@ -289,7 +309,9 @@ class TestRuntimeEnv:
# Check that we default to only the job ID and job name.
job_id = job_manager.submit_job(entrypoint=print_metadata_cmd)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert dict_to_str(
{JOB_NAME_METADATA_KEY: job_id, JOB_ID_METADATA_KEY: job_id}
) in job_manager.get_job_logs(job_id)
@ -299,7 +321,9 @@ class TestRuntimeEnv:
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert (
dict_to_str(
{
@ -318,16 +342,21 @@ class TestRuntimeEnv:
metadata={JOB_NAME_METADATA_KEY: "custom_name"},
)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert dict_to_str(
{JOB_NAME_METADATA_KEY: "custom_name", JOB_ID_METADATA_KEY: job_id}
) in job_manager.get_job_logs(job_id)
@pytest.mark.asyncio
class TestAsyncAPI:
def test_status_and_logs_while_blocking(self, job_manager):
async def test_status_and_logs_while_blocking(self, job_manager):
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, tmp_file, job_id = _run_hanging_command(job_manager, tmp_dir)
pid_file, tmp_file, job_id = await _run_hanging_command(
job_manager, tmp_dir
)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), "driver subprocess should be running"
@ -336,27 +365,29 @@ class TestAsyncAPI:
with open(tmp_file, "w") as f:
print("hello", file=f)
wait_for_condition(
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
# Ensure driver subprocess gets cleaned up after job reached
# termination state
wait_for_condition(check_subprocess_cleaned, pid=pid)
await async_wait_for_condition(check_subprocess_cleaned, pid=pid)
def test_stop_job(self, job_manager):
async def test_stop_job(self, job_manager):
with tempfile.TemporaryDirectory() as tmp_dir:
_, _, job_id = _run_hanging_command(job_manager, tmp_dir)
_, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
assert job_manager.stop_job(job_id) is True
wait_for_condition(
await async_wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id
)
# Assert re-stopping a stopped job also returns False
wait_for_condition(lambda: job_manager.stop_job(job_id) is False)
await async_wait_for_condition(
lambda: job_manager.stop_job(job_id) is False
)
# Assert stopping non-existent job returns False
assert job_manager.stop_job(str(uuid4())) is False
def test_kill_job_actor_in_before_driver_finish(self, job_manager):
async def test_kill_job_actor_in_before_driver_finish(self, job_manager):
"""
Test submitting a long running / blocker driver script, and kill
the job supervisor actor before script returns and ensure
@ -366,20 +397,22 @@ class TestAsyncAPI:
"""
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), "driver subprocess should be running"
actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id
)
# Ensure driver subprocess gets cleaned up after job reached
# termination state
wait_for_condition(check_subprocess_cleaned, pid=pid)
await async_wait_for_condition(check_subprocess_cleaned, pid=pid)
def test_stop_job_in_pending(self, job_manager):
async def test_stop_job_in_pending(self, job_manager):
"""
Kick off a job that is in PENDING state, stop the job and ensure
@ -389,7 +422,7 @@ class TestAsyncAPI:
start_signal_actor = SignalActor.remote()
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(
pid_file, _, job_id = await _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)
assert not os.path.exists(pid_file), (
@ -399,11 +432,11 @@ class TestAsyncAPI:
assert job_manager.stop_job(job_id) is True
# Send run signal to unblock run function
ray.get(start_signal_actor.send.remote())
wait_for_condition(
await async_wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id
)
def test_kill_job_actor_in_pending(self, job_manager):
async def test_kill_job_actor_in_pending(self, job_manager):
"""
Kick off a job that is in PENDING state, kill the job actor and ensure
@ -413,7 +446,7 @@ class TestAsyncAPI:
start_signal_actor = SignalActor.remote()
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(
pid_file, _, job_id = await _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)
@ -423,9 +456,11 @@ class TestAsyncAPI:
actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True)
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id
)
def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
async def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
"""
Ensure the driver script's subprocess is cleaned up properly when we
stop a running job.
@ -433,21 +468,22 @@ class TestAsyncAPI:
SIGTERM first, SIGKILL after 3 seconds.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
with open(pid_file, "r") as file:
pid = int(file.read())
assert psutil.pid_exists(pid), "driver subprocess should be running"
assert job_manager.stop_job(job_id) is True
wait_for_condition(
await async_wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id
)
# Ensure driver subprocess gets cleaned up after job reached
# termination state
wait_for_condition(check_subprocess_cleaned, pid=pid)
await async_wait_for_condition(check_subprocess_cleaned, pid=pid)
@pytest.mark.asyncio
class TestTailLogs:
async def _tail_and_assert_logs(
self, job_id, job_manager, expected_log="", num_iteration=5
@ -460,19 +496,17 @@ class TestTailLogs:
break
i += 1
@pytest.mark.asyncio
async def test_unknown_job(self, job_manager):
with pytest.raises(RuntimeError, match="Job 'unknown' does not exist."):
async for _ in job_manager.tail_job_logs("unknown"):
pass
@pytest.mark.asyncio
async def test_successful_job(self, job_manager):
"""Test tailing logs for a PENDING -> RUNNING -> SUCCESSFUL job."""
start_signal_actor = SignalActor.remote()
with tempfile.TemporaryDirectory() as tmp_dir:
_, tmp_file, job_id = _run_hanging_command(
_, tmp_file, job_id = await _run_hanging_command(
job_manager, tmp_dir, start_signal_actor=start_signal_actor
)
@ -495,15 +529,14 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
wait_for_condition(
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
@pytest.mark.asyncio
async def test_failed_job(self, job_manager):
"""Test tailing logs for a job that unexpectedly exits."""
with tempfile.TemporaryDirectory() as tmp_dir:
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
await self._tail_and_assert_logs(
job_id, job_manager, expected_log="Waiting...", num_iteration=5
@ -517,13 +550,14 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_failed, job_manager=job_manager, job_id=job_id
)
@pytest.mark.asyncio
async def test_stopped_job(self, job_manager):
"""Test tailing logs for a job that unexpectedly exits."""
with tempfile.TemporaryDirectory() as tmp_dir:
_, _, job_id = _run_hanging_command(job_manager, tmp_dir)
_, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
await self._tail_and_assert_logs(
job_id, job_manager, expected_log="Waiting...", num_iteration=5
@ -536,12 +570,13 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
wait_for_condition(
await async_wait_for_condition(
check_job_stopped, job_manager=job_manager, job_id=job_id
)
def test_logs_streaming(job_manager):
@pytest.mark.asyncio
async def test_logs_streaming(job_manager):
"""Test that logs are streamed during the job, not just at the end."""
stream_logs_script = """
@ -554,12 +589,15 @@ while True:
stream_logs_cmd = f'python -c "{stream_logs_script}"'
job_id = job_manager.submit_job(entrypoint=stream_logs_cmd)
wait_for_condition(lambda: "STREAMED" in job_manager.get_job_logs(job_id))
await async_wait_for_condition(
lambda: "STREAMED" in job_manager.get_job_logs(job_id)
)
job_manager.stop_job(job_id)
def test_bootstrap_address(job_manager, monkeypatch):
@pytest.mark.asyncio
async def test_bootstrap_address(job_manager, monkeypatch):
"""Ensure we always use bootstrap address in job manager even though ray
cluster might be started with http://ip:{dashboard_port} from previous
runs.
@ -574,7 +612,9 @@ def test_bootstrap_address(job_manager, monkeypatch):
job_id = job_manager.submit_job(entrypoint=print_ray_address_cmd)
wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
await async_wait_for_condition(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert "SUCCESS!" in job_manager.get_job_logs(job_id)

python/ray/_private/test_utils.py

@ -377,6 +377,34 @@ def wait_for_condition(
raise RuntimeError(message)
async def async_wait_for_condition(
condition_predictor, timeout=10, retry_interval_ms=100, **kwargs: Any
):
"""Wait until a condition is met or time out with an exception.
Args:
condition_predictor: A function that returns True when the condition is met.
timeout: Maximum timeout in seconds.
retry_interval_ms: Retry interval in milliseconds.
Raises:
RuntimeError: If the condition is not met before the timeout expires.
"""
start = time.time()
last_ex = None
while time.time() - start <= timeout:
try:
if condition_predictor(**kwargs):
return
except Exception as ex:
last_ex = ex
await asyncio.sleep(retry_interval_ms / 1000.0)
message = "The condition wasn't met before the timeout expired."
if last_ex is not None:
message += f" Last exception: {last_ex}"
raise RuntimeError(message)
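Usage mirrors the synchronous wait_for_condition above: the predicate may take keyword arguments or be a zero-argument lambda, exceptions it raises are swallowed and retried until the timeout, and awaiting between retries keeps the surrounding event loop responsive. A small self-contained example with a hypothetical producer coroutine:

import asyncio

from ray._private.test_utils import async_wait_for_condition


async def example():
    results = []

    async def producer():
        await asyncio.sleep(0.3)
        results.append("done")

    asyncio.create_task(producer())
    # Retries every 100 ms by default until the lambda returns True or 5 s pass.
    await async_wait_for_condition(lambda: len(results) > 0, timeout=5)
    print(results)


asyncio.run(example())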
def wait_until_succeeded_without_exception(
func, exceptions, *args, timeout_ms=1000, retry_interval_ms=100, raise_last_ex=False
):