Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[jobs] Monitor jobs in the background to avoid requiring clients to poll (#22180)
This commit is contained in:
parent 8e1e783596
commit 8806b2d5c4
4 changed files with 251 additions and 143 deletions
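The gist of the change, sketched below with made-up stand-ins (supervisor, status_store) rather than the real actor handle and GCS-backed status client: the manager spawns one background coroutine per job that periodically pings the supervisor actor and records a terminal FAILED status if the actor dies, so clients no longer have to poll to learn about failures.

# Illustrative sketch only (not the Ray implementation): one watchdog
# coroutine per job marks the job FAILED if its supervisor stops answering
# pings before the job reached a terminal state.
import asyncio

TERMINAL = {"STOPPED", "SUCCEEDED", "FAILED"}

async def monitor(job_id, supervisor, status_store, period_s=1.0):
    while True:
        try:
            await supervisor.ping()              # stand-in for ping.remote()
            await asyncio.sleep(period_s)
        except Exception:
            # Only record a failure if the job had not already finished.
            if status_store.get(job_id) not in TERMINAL:
                status_store[job_id] = "FAILED"
            break

# The manager spawns this once per job at submission (and again for any
# non-terminal job it finds at startup), e.g.:
#     asyncio.get_event_loop().create_task(monitor(job_id, supervisor, store))
# so clients never need to poll to find out that a job died.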
@@ -7,6 +7,7 @@ from ray import ray_constants
from ray.experimental.internal_kv import (
    _internal_kv_initialized,
    _internal_kv_get,
    _internal_kv_list,
    _internal_kv_put,
)
from ray._private.runtime_env.packaging import parse_uri

@@ -21,15 +22,18 @@ CURRENT_VERSION = "1"


class JobStatus(str, Enum):
    def __str__(self):
        return f"{self.value}"

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    STOPPED = "STOPPED"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"

    def __str__(self):
        return f"{self.value}"

    def is_terminal(self):
        return self.value in {"STOPPED", "SUCCEEDED", "FAILED"}


@dataclass
class JobStatusInfo:

@@ -58,7 +62,8 @@ class JobStatusStorageClient:
    Handles formatting of status storage key given job id.
    """

    JOB_STATUS_KEY = "_ray_internal_job_status_{job_id}"
    JOB_STATUS_KEY_PREFIX = "_ray_internal_job_status"
    JOB_STATUS_KEY = f"{JOB_STATUS_KEY_PREFIX}_{{job_id}}"

    def __init__(self):
        assert _internal_kv_initialized()

@@ -85,6 +90,11 @@ class JobStatusStorageClient:
        else:
            return pickle.loads(pickled_status)

    def get_all_jobs(self) -> Dict[str, JobStatusInfo]:
        raw_job_ids = _internal_kv_list(self.JOB_STATUS_KEY_PREFIX)
        job_ids = [job_id.decode() for job_id in raw_job_ids]
        return {job_id: self.get_status(job_id) for job_id in job_ids}


def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
    if not package_uri.endswith(".zip"):
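The new JOB_STATUS_KEY_PREFIX together with _internal_kv_list is what lets the manager enumerate every stored job on startup. A minimal illustration of that prefix-listing idea, using a plain dict as a stand-in for the internal KV store (kv_list and PREFIX here are made up for the example):

# Sketch of prefix-based listing over a key-value store, with a plain dict
# standing in for the internal KV store.
PREFIX = "_ray_internal_job_status"

def kv_list(store, prefix):
    # Same spirit as _internal_kv_list(prefix): return all keys that start
    # with the prefix.
    return [k for k in store if k.startswith(prefix)]

store = {
    f"{PREFIX}_job_a": "RUNNING",
    f"{PREFIX}_job_b": "SUCCEEDED",
    "unrelated_key": "ignored",
}

# Strip the prefix (and the joining underscore) to recover the job IDs.
job_ids = [k[len(PREFIX) + 1:] for k in kv_list(store, PREFIX)]
assert job_ids == ["job_a", "job_b"]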
@@ -115,11 +115,8 @@ class JobSupervisor:
        # fire and forget call from outer job manager to this actor
        self._stop_event = asyncio.Event()

    def ready(self):
        """Dummy object ref. Return of this function represents job supervisor
        actor stated successfully with runtime_env configured, and is ready to
        move on to running state.
        """

    def ping(self):
        """Used to check the health of the actor."""
        pass

    def _exec_entrypoint(self, logs_path: str) -> subprocess.Popen:

@@ -193,8 +190,10 @@ class JobSupervisor:
        variables.
        3) Handle concurrent events of driver execution and
        """
        cur_status = self._get_status()
        assert cur_status.status == JobStatus.PENDING, "Run should only be called once."
        curr_status = self._status_client.get_status(self._job_id)
        assert (
            curr_status.status == JobStatus.PENDING
        ), "Run should only be called once."

        if _start_signal_actor:
            # Block in PENDING state until start signal received.

@@ -269,9 +268,6 @@ class JobSupervisor:
            # clean up actor after tasks are finished
            ray.actor.exit_actor()

    def _get_status(self) -> Optional[JobStatusInfo]:
        return self._status_client.get_status(self._job_id)

    def stop(self):
        """Set step_event and let run() handle the rest in its asyncio.wait()."""
        self._stop_event.set()
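The stop()/run() pair relies on a standard asyncio pattern: run() awaits both the driver task and the stop event with asyncio.wait(..., return_when=FIRST_COMPLETED) and then acts on whichever finished first. A self-contained sketch of that pattern, with fake_driver standing in for the real subprocess-polling coroutine:

# Standalone sketch of the "race the driver against a stop event" pattern.
import asyncio

async def fake_driver():
    await asyncio.sleep(10)  # pretend the entrypoint runs for a while
    return "SUCCEEDED"

async def run(stop_event: asyncio.Event) -> str:
    driver_task = asyncio.create_task(fake_driver())
    stop_task = asyncio.create_task(stop_event.wait())
    done, _ = await asyncio.wait(
        [driver_task, stop_task], return_when=asyncio.FIRST_COMPLETED
    )
    if stop_task in done:
        driver_task.cancel()  # the real code terminates the driver subprocess
        return "STOPPED"
    return driver_task.result()

async def main():
    stop_event = asyncio.Event()
    job = asyncio.create_task(run(stop_event))
    stop_event.set()          # corresponds to JobSupervisor.stop()
    print(await job)          # -> STOPPED

asyncio.run(main())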
@@ -288,18 +284,92 @@ class JobManager:
    # Time that we will sleep while tailing logs if no new log line is
    # available.
    LOG_TAIL_SLEEP_S = 1
    JOB_MONITOR_LOOP_PERIOD_S = 1

    def __init__(self):
        self._status_client = JobStatusStorageClient()
        self._log_client = JobLogStorageClient()
        self._supervisor_actor_cls = ray.remote(JobSupervisor)

        self._recover_running_jobs()

    def _recover_running_jobs(self):
        """Recovers all running jobs from the status client.

        For each job, we will spawn a coroutine to monitor it.
        Each will be added to self._running_jobs and reconciled.
        """
        all_jobs = self._status_client.get_all_jobs()
        for job_id, status_info in all_jobs.items():
            if not status_info.status.is_terminal():
                create_task(self._monitor_job(job_id))

    def _get_actor_for_job(self, job_id: str) -> Optional[ActorHandle]:
        try:
            return ray.get_actor(self.JOB_ACTOR_NAME.format(job_id=job_id))
        except ValueError:  # Ray returns ValueError for nonexistent actor.
            return None

    async def _monitor_job(
        self, job_id: str, job_supervisor: Optional[ActorHandle] = None
    ):
        """Monitors the specified job until it enters a terminal state.

        This is necessary because we need to handle the case where the
        JobSupervisor dies unexpectedly.
        """
        is_alive = True
        if job_supervisor is None:
            job_supervisor = self._get_actor_for_job(job_id)

        if job_supervisor is None:
            logger.error(f"Failed to get job supervisor for job {job_id}.")
            self._status_client.put_status(
                job_id,
                JobStatusInfo(
                    status=JobStatus.FAILED,
                    message=(
                        "Unexpected error occurred: Failed to get job supervisor."
                    ),
                ),
            )
            is_alive = False

        while is_alive:
            try:
                await job_supervisor.ping.remote()
                await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S)
            except Exception as e:
                is_alive = False
                if self._status_client.get_status(job_id).status.is_terminal():
                    # If the job is already in a terminal state, then the actor
                    # exiting is expected.
                    pass
                elif isinstance(e, RuntimeEnvSetupError):
                    logger.info(f"Failed to set up runtime_env for job {job_id}.")
                    self._status_client.put_status(
                        job_id,
                        JobStatusInfo(
                            status=JobStatus.FAILED,
                            message=(f"runtime_env setup failed: {e}"),
                        ),
                    )
                else:
                    logger.warning(
                        f"Job supervisor for job {job_id} failed unexpectedly: {e}."
                    )
                    self._status_client.put_status(
                        job_id,
                        JobStatusInfo(
                            status=JobStatus.FAILED,
                            message=f"Unexpected error occurred: {e}",
                        ),
                    )

        # Kill the actor defensively to avoid leaking actors in unexpected error cases.
        if job_supervisor is not None:
            ray.kill(job_supervisor, no_restart=True)

    def _get_current_node_resource_key(self) -> str:
        """Get the Ray resource key for current node.
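The except branch above boils down to a three-way decision about what to record. A compact, illustrative restatement of that logic as a plain function (the real code writes JobStatusInfo objects to the status client rather than returning strings):

# Illustrative restatement of the monitor's failure classification.
# Returns the message that would be recorded, or None if no update is needed.
def classify_monitor_failure(exc, current_status_is_terminal, is_runtime_env_error):
    if current_status_is_terminal:
        # The supervisor exiting after the job finished is expected.
        return None
    if is_runtime_env_error:
        return f"runtime_env setup failed: {exc}"
    return f"Unexpected error occurred: {exc}"

assert classify_monitor_failure(RuntimeError("boom"), True, False) is None
assert "runtime_env setup failed" in classify_monitor_failure(
    Exception("bad URI"), False, True
)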
@@ -326,26 +396,6 @@ class JobManager:
        """
        if result is None:
            return
        elif isinstance(result, RuntimeEnvSetupError):
            logger.info(f"Failed to set up runtime_env for job {job_id}.")
            self._status_client.put_status(
                job_id,
                JobStatusInfo(
                    status=JobStatus.FAILED,
                    message=(f"runtime_env setup failed: {result}"),
                ),
            )
        elif isinstance(result, Exception):
            logger.error(f"Failed to start supervisor for job {job_id}: {result}.")
            self._status_client.put_status(
                job_id,
                JobStatusInfo(
                    status=JobStatus.FAILED,
                    message=f"Error occurred while starting the job: {result}",
                ),
            )
        else:
            assert False, "This should not be reached."

    def submit_job(
        self,

@@ -394,9 +444,9 @@ class JobManager:

        # Wait for the actor to start up asynchronously so this call always
        # returns immediately and we can catch errors with the actor starting
        # up. We may want to put this in an actor instead in the future.
        # up.
        try:
            actor = self._supervisor_actor_cls.options(
            supervisor = self._supervisor_actor_cls.options(
                lifetime="detached",
                name=self.JOB_ACTOR_NAME.format(job_id=job_id),
                num_cpus=0,
@@ -407,26 +457,26 @@ class JobManager:
                },
                runtime_env=runtime_env,
            ).remote(job_id, entrypoint, metadata or {})
            actor.run.remote(_start_signal_actor=_start_signal_actor)
            supervisor.run.remote(_start_signal_actor=_start_signal_actor)

            def callback(result: Optional[Exception]):
                return self._handle_supervisor_startup(job_id, result)

            actor.ready.remote()._on_completed(callback)
            # Monitor the job in the background so we can detect errors without
            # requiring a client to poll.
            create_task(self._monitor_job(job_id, job_supervisor=supervisor))
        except Exception as e:
            self._handle_supervisor_startup(job_id, e)
            self._status_client.put_status(
                job_id,
                JobStatusInfo(
                    status=JobStatus.FAILED,
                    message=f"Failed to start job supervisor: {e}.",
                ),
            )

        return job_id

    def stop_job(self, job_id) -> bool:
        """Request job to exit, fire and forget.
        """Request a job to exit, fire and forget.

        Args:
            job_id: ID of the job.
        Returns:
            stopped:
                True if there's running job
                False if no running job found
            Returns whether or not the job was running.
        """
        job_supervisor_actor = self._get_actor_for_job(job_id)
        if job_supervisor_actor is not None:
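submit_job now hands the supervisor straight to a background monitoring task instead of hooking a callback onto ready(), so startup failures surface through the same path as later crashes. A minimal asyncio illustration of that fire-and-forget shape (stand-ins only; the create_task used here is asyncio's, not necessarily the helper the Ray code imports):

# Minimal fire-and-forget submission shape, using plain asyncio stand-ins.
import asyncio

async def monitor_job(job_id: str):
    print(f"monitoring {job_id} in the background")
    await asyncio.sleep(0.1)
    print(f"{job_id} reached a terminal state")

def submit_job(job_id: str) -> str:
    # Schedule monitoring on the running event loop and return immediately;
    # the caller never has to poll for the outcome.
    asyncio.create_task(monitor_job(job_id))
    return job_id

async def main():
    print("submitted:", submit_job("job_a"))
    await asyncio.sleep(0.2)  # give the background task time to finish

asyncio.run(main())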
@@ -438,35 +488,15 @@ class JobManager:
            return False

    def get_job_status(self, job_id: str) -> Optional[JobStatus]:
        """Get latest status of a job. If job supervisor actor is no longer
        alive, it will also attempt to make adjustments needed to bring job
        to correct terminiation state.

        All job status is stored and read only from GCS.

        Args:
            job_id: ID of the job.
        Returns:
            job_status: Latest known job status
        """
        job_supervisor_actor = self._get_actor_for_job(job_id)
        if job_supervisor_actor is None:
            # Job actor either exited or failed, we need to ensure never
            # left job in non-terminal status in case actor failed without
            # updating GCS with latest status.
            last_status = self._status_client.get_status(job_id)
            if last_status and last_status.status in {
                JobStatus.PENDING,
                JobStatus.RUNNING,
            }:
                self._status_client.put_status(job_id, JobStatus.FAILED)

        """Get latest status of a job."""
        return self._status_client.get_status(job_id)

    def get_job_logs(self, job_id: str) -> str:
        """Get all logs produced by a job."""
        return self._log_client.get_logs(job_id)

    async def tail_job_logs(self, job_id: str) -> Iterator[str]:
        """Return an iterator following the logs of a job."""
        if self.get_job_status(job_id) is None:
            raise RuntimeError(f"Job '{job_id}' does not exist.")
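With the monitor keeping the stored status up to date, get_job_status() becomes a plain read, and tail_job_logs() remains an async generator. A hypothetical client-side usage, assuming a JobManager instance and a submitted job_id already exist:

# Hypothetical usage from an async context; assumes a JobManager instance
# and a submitted job_id are already available.
async def follow(job_manager, job_id):
    print(job_manager.get_job_status(job_id))   # read the stored status
    async for lines in job_manager.tail_job_logs(job_id):
        print(lines, end="")                    # stream log chunks as they appear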
@@ -1,3 +1,4 @@
import asyncio
import os
import psutil
import tempfile

@@ -14,7 +15,7 @@ from ray.dashboard.modules.job.common import (
    JOB_NAME_METADATA_KEY,
)
from ray.dashboard.modules.job.job_manager import generate_job_id, JobManager
from ray._private.test_utils import SignalActor, wait_for_condition
from ray._private.test_utils import SignalActor, async_wait_for_condition

TEST_NAMESPACE = "jobs_test_namespace"


@@ -40,7 +41,7 @@ def _driver_script_path(file_name: str) -> str:
    )


def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
    tmp_file = os.path.join(tmp_dir, "hello")
    pid_file = os.path.join(tmp_dir, "pid")

@@ -61,10 +62,14 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
            assert status.status == JobStatus.PENDING
            logs = job_manager.get_job_logs(job_id)
            assert logs == ""
            await asyncio.sleep(0.01)
    else:
        wait_for_condition(check_job_running, job_manager=job_manager, job_id=job_id)

        wait_for_condition(lambda: "Waiting..." in job_manager.get_job_logs(job_id))
        await async_wait_for_condition(
            check_job_running, job_manager=job_manager, job_id=job_id
        )
        await async_wait_for_condition(
            lambda: "Waiting..." in job_manager.get_job_logs(job_id)
        )

    return pid_file, tmp_file, job_id
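Because _run_hanging_command is now a coroutine, every test that calls it has to run on the event loop and await the new async_wait_for_condition helper, which is what the @pytest.mark.asyncio conversions below do. The minimal shape of that pattern, assuming pytest-asyncio is installed:

# Minimal pattern the converted tests follow (assumes pytest-asyncio).
import asyncio
import pytest

async def _helper():
    await asyncio.sleep(0.01)   # stand-in for awaiting job setup
    return "job_a"

@pytest.mark.asyncio
async def test_uses_async_helper():
    job_id = await _helper()    # async helpers must now be awaited
    assert job_id == "job_a"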
@@ -112,40 +117,50 @@ def test_generate_job_id():
    assert len(ids) == 10000


def test_pass_job_id(job_manager):
@pytest.mark.asyncio
async def test_pass_job_id(job_manager):
    job_id = "my_custom_id"

    returned_id = job_manager.submit_job(entrypoint="echo hello", job_id=job_id)
    assert returned_id == job_id

    wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
    await async_wait_for_condition(
        check_job_succeeded, job_manager=job_manager, job_id=job_id
    )

    # Check that the same job_id is rejected.
    with pytest.raises(RuntimeError):
        job_manager.submit_job(entrypoint="echo hello", job_id=job_id)


@pytest.mark.asyncio
class TestShellScriptExecution:
    def test_submit_basic_echo(self, job_manager):
    async def test_submit_basic_echo(self, job_manager):
        job_id = job_manager.submit_job(entrypoint="echo hello")

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        assert job_manager.get_job_logs(job_id) == "hello\n"

    def test_submit_stderr(self, job_manager):
    async def test_submit_stderr(self, job_manager):
        job_id = job_manager.submit_job(entrypoint="echo error 1>&2")

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        assert job_manager.get_job_logs(job_id) == "error\n"

    def test_submit_ls_grep(self, job_manager):
    async def test_submit_ls_grep(self, job_manager):
        grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
        job_id = job_manager.submit_job(entrypoint=grep_cmd)

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"

    def test_subprocess_exception(self, job_manager):
    async def test_subprocess_exception(self, job_manager):
        """
        Run a python script with exception, ensure:
        1) Job status is marked as failed
@@ -165,26 +180,25 @@ class TestShellScriptExecution:

            return job_manager._get_actor_for_job(job_id) is None

        wait_for_condition(cleaned_up)
        await async_wait_for_condition(cleaned_up)

    def test_submit_with_s3_runtime_env(self, job_manager):
    async def test_submit_with_s3_runtime_env(self, job_manager):
        job_id = job_manager.submit_job(
            entrypoint="python script.py",
            runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"},
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        assert (
            job_manager.get_job_logs(job_id) == "Executing main() from script.py !!\n"
        )


@pytest.mark.asyncio
class TestRuntimeEnv:
    def test_inheritance(self, job_manager):
        # Test that the driver and actors/tasks inherit the right runtime_env.
        pass

    def test_pass_env_var(self, job_manager):
    async def test_pass_env_var(self, job_manager):
        """Test we can pass env vars in the subprocess that executes job's
        driver script.
        """

@@ -193,10 +207,12 @@ class TestRuntimeEnv:
            runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}},
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        assert job_manager.get_job_logs(job_id) == "233\n"

    def test_multiple_runtime_envs(self, job_manager):
    async def test_multiple_runtime_envs(self, job_manager):
        # Test that you can run two jobs in different envs without conflict.
        job_id_1 = job_manager.submit_job(
            entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",

@@ -205,7 +221,7 @@ class TestRuntimeEnv:
            },
        )

        wait_for_condition(
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id_1
        )
        logs = job_manager.get_job_logs(job_id_1)

@@ -220,7 +236,7 @@ class TestRuntimeEnv:
            },
        )

        wait_for_condition(
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id_2
        )
        logs = job_manager.get_job_logs(job_id_2)

@@ -228,7 +244,7 @@ class TestRuntimeEnv:
            "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs
        )  # noqa: E501

    def test_env_var_and_driver_job_config_warning(self, job_manager):
    async def test_env_var_and_driver_job_config_warning(self, job_manager):
        """Ensure we got error message from worker.py and job logs
        if user provided runtime_env in both driver script and submit()
        """

@@ -239,14 +255,16 @@ class TestRuntimeEnv:
            },
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        logs = job_manager.get_job_logs(job_id)
        assert logs.startswith(
            "Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) " "are provided"
        )
        assert "JOB_1_VAR" in logs

    def test_failed_runtime_env_validation(self, job_manager):
    async def test_failed_runtime_env_validation(self, job_manager):
        """Ensure job status is correctly set as failed if job has an invalid
        runtime_env.
        """

@@ -259,7 +277,7 @@ class TestRuntimeEnv:
        assert status.status == JobStatus.FAILED
        assert "path_not_exist is not a valid URI" in status.message

    def test_failed_runtime_env_setup(self, job_manager):
    async def test_failed_runtime_env_setup(self, job_manager):
        """Ensure job status is correctly set as failed if job has a valid
        runtime_env that fails to be set up.
        """

@@ -268,12 +286,14 @@ class TestRuntimeEnv:
            entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"}
        )

        wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_failed, job_manager=job_manager, job_id=job_id
        )

        status = job_manager.get_job_status(job_id)
        assert "runtime_env setup failed" in status.message

    def test_pass_metadata(self, job_manager):
    async def test_pass_metadata(self, job_manager):
        def dict_to_str(d):
            return str(dict(sorted(d.items())))
@@ -289,7 +309,9 @@ class TestRuntimeEnv:
        # Check that we default to only the job ID and job name.
        job_id = job_manager.submit_job(entrypoint=print_metadata_cmd)

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        assert dict_to_str(
            {JOB_NAME_METADATA_KEY: job_id, JOB_ID_METADATA_KEY: job_id}
        ) in job_manager.get_job_logs(job_id)

@@ -299,7 +321,9 @@ class TestRuntimeEnv:
            entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        assert (
            dict_to_str(
                {

@@ -318,16 +342,21 @@ class TestRuntimeEnv:
            metadata={JOB_NAME_METADATA_KEY: "custom_name"},
        )

        wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
        await async_wait_for_condition(
            check_job_succeeded, job_manager=job_manager, job_id=job_id
        )
        assert dict_to_str(
            {JOB_NAME_METADATA_KEY: "custom_name", JOB_ID_METADATA_KEY: job_id}
        ) in job_manager.get_job_logs(job_id)


@pytest.mark.asyncio
class TestAsyncAPI:
    def test_status_and_logs_while_blocking(self, job_manager):
    async def test_status_and_logs_while_blocking(self, job_manager):
        with tempfile.TemporaryDirectory() as tmp_dir:
            pid_file, tmp_file, job_id = _run_hanging_command(job_manager, tmp_dir)
            pid_file, tmp_file, job_id = await _run_hanging_command(
                job_manager, tmp_dir
            )
            with open(pid_file, "r") as file:
                pid = int(file.read())
                assert psutil.pid_exists(pid), "driver subprocess should be running"

@@ -336,27 +365,29 @@ class TestAsyncAPI:
            with open(tmp_file, "w") as f:
                print("hello", file=f)

            wait_for_condition(
            await async_wait_for_condition(
                check_job_succeeded, job_manager=job_manager, job_id=job_id
            )
            # Ensure driver subprocess gets cleaned up after job reached
            # termination state
            wait_for_condition(check_subprocess_cleaned, pid=pid)
            await async_wait_for_condition(check_subprocess_cleaned, pid=pid)

    def test_stop_job(self, job_manager):
    async def test_stop_job(self, job_manager):
        with tempfile.TemporaryDirectory() as tmp_dir:
            _, _, job_id = _run_hanging_command(job_manager, tmp_dir)
            _, _, job_id = await _run_hanging_command(job_manager, tmp_dir)

            assert job_manager.stop_job(job_id) is True
            wait_for_condition(
            await async_wait_for_condition(
                check_job_stopped, job_manager=job_manager, job_id=job_id
            )
            # Assert re-stopping a stopped job also returns False
            wait_for_condition(lambda: job_manager.stop_job(job_id) is False)
            await async_wait_for_condition(
                lambda: job_manager.stop_job(job_id) is False
            )
            # Assert stopping non-existent job returns False
            assert job_manager.stop_job(str(uuid4())) is False

    def test_kill_job_actor_in_before_driver_finish(self, job_manager):
    async def test_kill_job_actor_in_before_driver_finish(self, job_manager):
        """
        Test submitting a long running / blocker driver script, and kill
        the job supervisor actor before script returns and ensure
|
|||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
||||
pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
|
||||
with open(pid_file, "r") as file:
|
||||
pid = int(file.read())
|
||||
assert psutil.pid_exists(pid), "driver subprocess should be running"
|
||||
|
||||
actor = job_manager._get_actor_for_job(job_id)
|
||||
ray.kill(actor, no_restart=True)
|
||||
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
|
||||
await async_wait_for_condition(
|
||||
check_job_failed, job_manager=job_manager, job_id=job_id
|
||||
)
|
||||
|
||||
# Ensure driver subprocess gets cleaned up after job reached
|
||||
# termination state
|
||||
wait_for_condition(check_subprocess_cleaned, pid=pid)
|
||||
await async_wait_for_condition(check_subprocess_cleaned, pid=pid)
|
||||
|
||||
def test_stop_job_in_pending(self, job_manager):
|
||||
async def test_stop_job_in_pending(self, job_manager):
|
||||
"""
|
||||
Kick off a job that is in PENDING state, stop the job and ensure
|
||||
|
||||
|
@ -389,7 +422,7 @@ class TestAsyncAPI:
|
|||
start_signal_actor = SignalActor.remote()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pid_file, _, job_id = _run_hanging_command(
|
||||
pid_file, _, job_id = await _run_hanging_command(
|
||||
job_manager, tmp_dir, start_signal_actor=start_signal_actor
|
||||
)
|
||||
assert not os.path.exists(pid_file), (
|
||||
|
@ -399,11 +432,11 @@ class TestAsyncAPI:
|
|||
assert job_manager.stop_job(job_id) is True
|
||||
# Send run signal to unblock run function
|
||||
ray.get(start_signal_actor.send.remote())
|
||||
wait_for_condition(
|
||||
await async_wait_for_condition(
|
||||
check_job_stopped, job_manager=job_manager, job_id=job_id
|
||||
)
|
||||
|
||||
def test_kill_job_actor_in_pending(self, job_manager):
|
||||
async def test_kill_job_actor_in_pending(self, job_manager):
|
||||
"""
|
||||
Kick off a job that is in PENDING state, kill the job actor and ensure
|
||||
|
||||
|
@ -413,7 +446,7 @@ class TestAsyncAPI:
|
|||
start_signal_actor = SignalActor.remote()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pid_file, _, job_id = _run_hanging_command(
|
||||
pid_file, _, job_id = await _run_hanging_command(
|
||||
job_manager, tmp_dir, start_signal_actor=start_signal_actor
|
||||
)
|
||||
|
||||
|
@ -423,9 +456,11 @@ class TestAsyncAPI:
|
|||
|
||||
actor = job_manager._get_actor_for_job(job_id)
|
||||
ray.kill(actor, no_restart=True)
|
||||
wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
|
||||
await async_wait_for_condition(
|
||||
check_job_failed, job_manager=job_manager, job_id=job_id
|
||||
)
|
||||
|
||||
def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
|
||||
async def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager):
|
||||
"""
|
||||
Ensure driver scripts' subprocess is cleaned up properly when we
|
||||
stopped a running job.
|
||||
|
@ -433,21 +468,22 @@ class TestAsyncAPI:
|
|||
SIGTERM first, SIGKILL after 3 seconds.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
|
||||
pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
|
||||
with open(pid_file, "r") as file:
|
||||
pid = int(file.read())
|
||||
assert psutil.pid_exists(pid), "driver subprocess should be running"
|
||||
|
||||
assert job_manager.stop_job(job_id) is True
|
||||
wait_for_condition(
|
||||
await async_wait_for_condition(
|
||||
check_job_stopped, job_manager=job_manager, job_id=job_id
|
||||
)
|
||||
|
||||
# Ensure driver subprocess gets cleaned up after job reached
|
||||
# termination state
|
||||
wait_for_condition(check_subprocess_cleaned, pid=pid)
|
||||
await async_wait_for_condition(check_subprocess_cleaned, pid=pid)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTailLogs:
|
||||
async def _tail_and_assert_logs(
|
||||
self, job_id, job_manager, expected_log="", num_iteration=5
|
||||
|
@@ -460,19 +496,17 @@ class TestTailLogs:
                break
            i += 1

    @pytest.mark.asyncio
    async def test_unknown_job(self, job_manager):
        with pytest.raises(RuntimeError, match="Job 'unknown' does not exist."):
            async for _ in job_manager.tail_job_logs("unknown"):
                pass

    @pytest.mark.asyncio
    async def test_successful_job(self, job_manager):
        """Test tailing logs for a PENDING -> RUNNING -> SUCCESSFUL job."""
        start_signal_actor = SignalActor.remote()

        with tempfile.TemporaryDirectory() as tmp_dir:
            _, tmp_file, job_id = _run_hanging_command(
            _, tmp_file, job_id = await _run_hanging_command(
                job_manager, tmp_dir, start_signal_actor=start_signal_actor
            )

@@ -495,15 +529,14 @@ class TestTailLogs:
                assert all(s == "Waiting..." for s in lines.strip().split("\n"))
                print(lines, end="")

            wait_for_condition(
            await async_wait_for_condition(
                check_job_succeeded, job_manager=job_manager, job_id=job_id
            )

    @pytest.mark.asyncio
    async def test_failed_job(self, job_manager):
        """Test tailing logs for a job that unexpectedly exits."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir)
            pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir)

            await self._tail_and_assert_logs(
                job_id, job_manager, expected_log="Waiting...", num_iteration=5

@@ -517,13 +550,14 @@ class TestTailLogs:
                assert all(s == "Waiting..." for s in lines.strip().split("\n"))
                print(lines, end="")

            wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id)
            await async_wait_for_condition(
                check_job_failed, job_manager=job_manager, job_id=job_id
            )

    @pytest.mark.asyncio
    async def test_stopped_job(self, job_manager):
        """Test tailing logs for a job that unexpectedly exits."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            _, _, job_id = _run_hanging_command(job_manager, tmp_dir)
            _, _, job_id = await _run_hanging_command(job_manager, tmp_dir)

            await self._tail_and_assert_logs(
                job_id, job_manager, expected_log="Waiting...", num_iteration=5

@@ -536,12 +570,13 @@ class TestTailLogs:
                assert all(s == "Waiting..." for s in lines.strip().split("\n"))
                print(lines, end="")

            wait_for_condition(
            await async_wait_for_condition(
                check_job_stopped, job_manager=job_manager, job_id=job_id
            )


def test_logs_streaming(job_manager):
@pytest.mark.asyncio
async def test_logs_streaming(job_manager):
    """Test that logs are streamed during the job, not just at the end."""

    stream_logs_script = """

@@ -554,12 +589,15 @@
    stream_logs_cmd = f'python -c "{stream_logs_script}"'

    job_id = job_manager.submit_job(entrypoint=stream_logs_cmd)
    wait_for_condition(lambda: "STREAMED" in job_manager.get_job_logs(job_id))
    await async_wait_for_condition(
        lambda: "STREAMED" in job_manager.get_job_logs(job_id)
    )

    job_manager.stop_job(job_id)


def test_bootstrap_address(job_manager, monkeypatch):
@pytest.mark.asyncio
async def test_bootstrap_address(job_manager, monkeypatch):
    """Ensure we always use bootstrap address in job manager even though ray
    cluster might be started with http://ip:{dashboard_port} from previous
    runs.

@@ -574,7 +612,9 @@ def test_bootstrap_address(job_manager, monkeypatch):

    job_id = job_manager.submit_job(entrypoint=print_ray_address_cmd)

    wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id)
    await async_wait_for_condition(
        check_job_succeeded, job_manager=job_manager, job_id=job_id
    )
    assert "SUCCESS!" in job_manager.get_job_logs(job_id)

@@ -377,6 +377,34 @@ def wait_for_condition(
    raise RuntimeError(message)


async def async_wait_for_condition(
    condition_predictor, timeout=10, retry_interval_ms=100, **kwargs: Any
):
    """Wait until a condition is met or time out with an exception.

    Args:
        condition_predictor: A function that predicts the condition.
        timeout: Maximum timeout in seconds.
        retry_interval_ms: Retry interval in milliseconds.

    Raises:
        RuntimeError: If the condition is not met before the timeout expires.
    """
    start = time.time()
    last_ex = None
    while time.time() - start <= timeout:
        try:
            if condition_predictor(**kwargs):
                return
        except Exception as ex:
            last_ex = ex
        await asyncio.sleep(retry_interval_ms / 1000.0)
    message = "The condition wasn't met before the timeout expired."
    if last_ex is not None:
        message += f" Last exception: {last_ex}"
    raise RuntimeError(message)


def wait_until_succeeded_without_exception(
    func, exceptions, *args, timeout_ms=1000, retry_interval_ms=100, raise_last_ex=False
):
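async_wait_for_condition mirrors wait_for_condition but sleeps with asyncio.sleep, so awaiting it inside an async test yields to the event loop (and to background tasks like the job monitor) between retries. A hypothetical usage inside a pytest-asyncio test; the state dict and predicate are stand-ins for whatever condition a real test checks:

# Hypothetical usage of the new helper inside an async test.
import pytest
from ray._private.test_utils import async_wait_for_condition

@pytest.mark.asyncio
async def test_condition_eventually_holds():
    state = {"done": False}
    state["done"] = True  # in a real test, background work would set this

    # Extra kwargs are forwarded to the predicate, mirroring wait_for_condition.
    await async_wait_for_condition(lambda s: s["done"], s=state)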