diff --git a/dashboard/modules/job/common.py b/dashboard/modules/job/common.py index 44b218445..eb00d67d9 100644 --- a/dashboard/modules/job/common.py +++ b/dashboard/modules/job/common.py @@ -7,6 +7,7 @@ from ray import ray_constants from ray.experimental.internal_kv import ( _internal_kv_initialized, _internal_kv_get, + _internal_kv_list, _internal_kv_put, ) from ray._private.runtime_env.packaging import parse_uri @@ -21,15 +22,18 @@ CURRENT_VERSION = "1" class JobStatus(str, Enum): - def __str__(self): - return f"{self.value}" - PENDING = "PENDING" RUNNING = "RUNNING" STOPPED = "STOPPED" SUCCEEDED = "SUCCEEDED" FAILED = "FAILED" + def __str__(self): + return f"{self.value}" + + def is_terminal(self): + return self.value in {"STOPPED", "SUCCEEDED", "FAILED"} + @dataclass class JobStatusInfo: @@ -58,7 +62,8 @@ class JobStatusStorageClient: Handles formatting of status storage key given job id. """ - JOB_STATUS_KEY = "_ray_internal_job_status_{job_id}" + JOB_STATUS_KEY_PREFIX = "_ray_internal_job_status" + JOB_STATUS_KEY = f"{JOB_STATUS_KEY_PREFIX}_{{job_id}}" def __init__(self): assert _internal_kv_initialized() @@ -85,6 +90,11 @@ class JobStatusStorageClient: else: return pickle.loads(pickled_status) + def get_all_jobs(self) -> Dict[str, JobStatusInfo]: + raw_job_ids = _internal_kv_list(self.JOB_STATUS_KEY_PREFIX) + job_ids = [job_id.decode() for job_id in raw_job_ids] + return {job_id: self.get_status(job_id) for job_id in job_ids} + def uri_to_http_components(package_uri: str) -> Tuple[str, str]: if not package_uri.endswith(".zip"): diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index 68713877c..ad217e1d2 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -115,11 +115,8 @@ class JobSupervisor: # fire and forget call from outer job manager to this actor self._stop_event = asyncio.Event() - def ready(self): - """Dummy object ref. Return of this function represents job supervisor - actor stated successfully with runtime_env configured, and is ready to - move on to running state. - """ + def ping(self): + """Used to check the health of the actor.""" pass def _exec_entrypoint(self, logs_path: str) -> subprocess.Popen: @@ -193,8 +190,10 @@ class JobSupervisor: variables. 3) Handle concurrent events of driver execution and """ - cur_status = self._get_status() - assert cur_status.status == JobStatus.PENDING, "Run should only be called once." + curr_status = self._status_client.get_status(self._job_id) + assert ( + curr_status.status == JobStatus.PENDING + ), "Run should only be called once." if _start_signal_actor: # Block in PENDING state until start signal received. @@ -269,9 +268,6 @@ class JobSupervisor: # clean up actor after tasks are finished ray.actor.exit_actor() - def _get_status(self) -> Optional[JobStatusInfo]: - return self._status_client.get_status(self._job_id) - def stop(self): """Set step_event and let run() handle the rest in its asyncio.wait().""" self._stop_event.set() @@ -288,18 +284,92 @@ class JobManager: # Time that we will sleep while tailing logs if no new log line is # available. LOG_TAIL_SLEEP_S = 1 + JOB_MONITOR_LOOP_PERIOD_S = 1 def __init__(self): self._status_client = JobStatusStorageClient() self._log_client = JobLogStorageClient() self._supervisor_actor_cls = ray.remote(JobSupervisor) + self._recover_running_jobs() + + def _recover_running_jobs(self): + """Recovers all running jobs from the status client. + + For each job, we will spawn a coroutine to monitor it. 
+ Each will be added to self._running_jobs and reconciled. + """ + all_jobs = self._status_client.get_all_jobs() + for job_id, status_info in all_jobs.items(): + if not status_info.status.is_terminal(): + create_task(self._monitor_job(job_id)) + def _get_actor_for_job(self, job_id: str) -> Optional[ActorHandle]: try: return ray.get_actor(self.JOB_ACTOR_NAME.format(job_id=job_id)) except ValueError: # Ray returns ValueError for nonexistent actor. return None + async def _monitor_job( + self, job_id: str, job_supervisor: Optional[ActorHandle] = None + ): + """Monitors the specified job until it enters a terminal state. + + This is necessary because we need to handle the case where the + JobSupervisor dies unexpectedly. + """ + is_alive = True + if job_supervisor is None: + job_supervisor = self._get_actor_for_job(job_id) + + if job_supervisor is None: + logger.error(f"Failed to get job supervisor for job {job_id}.") + self._status_client.put_status( + job_id, + JobStatusInfo( + status=JobStatus.FAILED, + message=( + "Unexpected error occurred: Failed to get job supervisor." + ), + ), + ) + is_alive = False + + while is_alive: + try: + await job_supervisor.ping.remote() + await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S) + except Exception as e: + is_alive = False + if self._status_client.get_status(job_id).status.is_terminal(): + # If the job is already in a terminal state, then the actor + # exiting is expected. + pass + elif isinstance(e, RuntimeEnvSetupError): + logger.info(f"Failed to set up runtime_env for job {job_id}.") + self._status_client.put_status( + job_id, + JobStatusInfo( + status=JobStatus.FAILED, + message=(f"runtime_env setup failed: {e}"), + ), + ) + else: + logger.warning( + f"Job supervisor for job {job_id} failed unexpectedly: {e}." + ) + self._status_client.put_status( + job_id, + JobStatusInfo( + status=JobStatus.FAILED, + message=f"Unexpected error occurred: {e}", + ), + ) + + # Kill the actor defensively to avoid leaking actors in unexpected error cases. + if job_supervisor is not None: + ray.kill(job_supervisor, no_restart=True) + def _get_current_node_resource_key(self) -> str: """Get the Ray resource key for current node. @@ -326,26 +396,6 @@ class JobManager: """ if result is None: return - elif isinstance(result, RuntimeEnvSetupError): - logger.info(f"Failed to set up runtime_env for job {job_id}.") - self._status_client.put_status( - job_id, - JobStatusInfo( - status=JobStatus.FAILED, - message=(f"runtime_env setup failed: {result}"), - ), - ) - elif isinstance(result, Exception): - logger.error(f"Failed to start supervisor for job {job_id}: {result}.") - self._status_client.put_status( - job_id, - JobStatusInfo( - status=JobStatus.FAILED, - message=f"Error occurred while starting the job: {result}", - ), - ) - else: - assert False, "This should not be reached." def submit_job( self, @@ -394,9 +444,9 @@ class JobManager: # Wait for the actor to start up asynchronously so this call always # returns immediately and we can catch errors with the actor starting - # up. We may want to put this in an actor instead in the future. + # up. 
try: - actor = self._supervisor_actor_cls.options( + supervisor = self._supervisor_actor_cls.options( lifetime="detached", name=self.JOB_ACTOR_NAME.format(job_id=job_id), num_cpus=0, @@ -407,26 +457,26 @@ class JobManager: }, runtime_env=runtime_env, ).remote(job_id, entrypoint, metadata or {}) - actor.run.remote(_start_signal_actor=_start_signal_actor) + supervisor.run.remote(_start_signal_actor=_start_signal_actor) - def callback(result: Optional[Exception]): - return self._handle_supervisor_startup(job_id, result) - - actor.ready.remote()._on_completed(callback) + # Monitor the job in the background so we can detect errors without + # requiring a client to poll. + create_task(self._monitor_job(job_id, job_supervisor=supervisor)) except Exception as e: - self._handle_supervisor_startup(job_id, e) + self._status_client.put_status( + job_id, + JobStatusInfo( + status=JobStatus.FAILED, + message=f"Failed to start job supervisor: {e}.", + ), + ) return job_id def stop_job(self, job_id) -> bool: - """Request job to exit, fire and forget. + """Request a job to exit, fire and forget. - Args: - job_id: ID of the job. - Returns: - stopped: - True if there's running job - False if no running job found + Returns whether or not the job was running. """ job_supervisor_actor = self._get_actor_for_job(job_id) if job_supervisor_actor is not None: @@ -438,35 +488,15 @@ class JobManager: return False def get_job_status(self, job_id: str) -> Optional[JobStatus]: - """Get latest status of a job. If job supervisor actor is no longer - alive, it will also attempt to make adjustments needed to bring job - to correct terminiation state. - - All job status is stored and read only from GCS. - - Args: - job_id: ID of the job. - Returns: - job_status: Latest known job status - """ - job_supervisor_actor = self._get_actor_for_job(job_id) - if job_supervisor_actor is None: - # Job actor either exited or failed, we need to ensure never - # left job in non-terminal status in case actor failed without - # updating GCS with latest status. 
- last_status = self._status_client.get_status(job_id) - if last_status and last_status.status in { - JobStatus.PENDING, - JobStatus.RUNNING, - }: - self._status_client.put_status(job_id, JobStatus.FAILED) - + """Get latest status of a job.""" return self._status_client.get_status(job_id) def get_job_logs(self, job_id: str) -> str: + """Get all logs produced by a job.""" return self._log_client.get_logs(job_id) async def tail_job_logs(self, job_id: str) -> Iterator[str]: + """Return an iterator following the logs of a job.""" if self.get_job_status(job_id) is None: raise RuntimeError(f"Job '{job_id}' does not exist.") diff --git a/dashboard/modules/job/tests/test_job_manager.py b/dashboard/modules/job/tests/test_job_manager.py index 8c566a17b..29fd46d0a 100644 --- a/dashboard/modules/job/tests/test_job_manager.py +++ b/dashboard/modules/job/tests/test_job_manager.py @@ -1,3 +1,4 @@ +import asyncio import os import psutil import tempfile @@ -14,7 +15,7 @@ from ray.dashboard.modules.job.common import ( JOB_NAME_METADATA_KEY, ) from ray.dashboard.modules.job.job_manager import generate_job_id, JobManager -from ray._private.test_utils import SignalActor, wait_for_condition +from ray._private.test_utils import SignalActor, async_wait_for_condition TEST_NAMESPACE = "jobs_test_namespace" @@ -40,7 +41,7 @@ def _driver_script_path(file_name: str) -> str: ) -def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None): +async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None): tmp_file = os.path.join(tmp_dir, "hello") pid_file = os.path.join(tmp_dir, "pid") @@ -61,10 +62,14 @@ def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None): assert status.status == JobStatus.PENDING logs = job_manager.get_job_logs(job_id) assert logs == "" + await asyncio.sleep(0.01) else: - wait_for_condition(check_job_running, job_manager=job_manager, job_id=job_id) - - wait_for_condition(lambda: "Waiting..." in job_manager.get_job_logs(job_id)) + await async_wait_for_condition( + check_job_running, job_manager=job_manager, job_id=job_id + ) + await async_wait_for_condition( + lambda: "Waiting..." in job_manager.get_job_logs(job_id) + ) return pid_file, tmp_file, job_id @@ -112,40 +117,50 @@ def test_generate_job_id(): assert len(ids) == 10000 -def test_pass_job_id(job_manager): +@pytest.mark.asyncio +async def test_pass_job_id(job_manager): job_id = "my_custom_id" returned_id = job_manager.submit_job(entrypoint="echo hello", job_id=job_id) assert returned_id == job_id - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) # Check that the same job_id is rejected. 
with pytest.raises(RuntimeError): job_manager.submit_job(entrypoint="echo hello", job_id=job_id) +@pytest.mark.asyncio class TestShellScriptExecution: - def test_submit_basic_echo(self, job_manager): + async def test_submit_basic_echo(self, job_manager): job_id = job_manager.submit_job(entrypoint="echo hello") - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert job_manager.get_job_logs(job_id) == "hello\n" - def test_submit_stderr(self, job_manager): + async def test_submit_stderr(self, job_manager): job_id = job_manager.submit_job(entrypoint="echo error 1>&2") - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert job_manager.get_job_logs(job_id) == "error\n" - def test_submit_ls_grep(self, job_manager): + async def test_submit_ls_grep(self, job_manager): grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py" job_id = job_manager.submit_job(entrypoint=grep_cmd) - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n" - def test_subprocess_exception(self, job_manager): + async def test_subprocess_exception(self, job_manager): """ Run a python script with exception, ensure: 1) Job status is marked as failed @@ -165,26 +180,25 @@ class TestShellScriptExecution: return job_manager._get_actor_for_job(job_id) is None - wait_for_condition(cleaned_up) + await async_wait_for_condition(cleaned_up) - def test_submit_with_s3_runtime_env(self, job_manager): + async def test_submit_with_s3_runtime_env(self, job_manager): job_id = job_manager.submit_job( entrypoint="python script.py", runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"}, ) - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert ( job_manager.get_job_logs(job_id) == "Executing main() from script.py !!\n" ) +@pytest.mark.asyncio class TestRuntimeEnv: - def test_inheritance(self, job_manager): - # Test that the driver and actors/tasks inherit the right runtime_env. - pass - - def test_pass_env_var(self, job_manager): + async def test_pass_env_var(self, job_manager): """Test we can pass env vars in the subprocess that executes job's driver script. """ @@ -193,10 +207,12 @@ class TestRuntimeEnv: runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}}, ) - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert job_manager.get_job_logs(job_id) == "233\n" - def test_multiple_runtime_envs(self, job_manager): + async def test_multiple_runtime_envs(self, job_manager): # Test that you can run two jobs in different envs without conflict. 
job_id_1 = job_manager.submit_job( entrypoint=f"python {_driver_script_path('print_runtime_env.py')}", @@ -205,7 +221,7 @@ class TestRuntimeEnv: }, ) - wait_for_condition( + await async_wait_for_condition( check_job_succeeded, job_manager=job_manager, job_id=job_id_1 ) logs = job_manager.get_job_logs(job_id_1) @@ -220,7 +236,7 @@ class TestRuntimeEnv: }, ) - wait_for_condition( + await async_wait_for_condition( check_job_succeeded, job_manager=job_manager, job_id=job_id_2 ) logs = job_manager.get_job_logs(job_id_2) @@ -228,7 +244,7 @@ class TestRuntimeEnv: "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_2_VAR'}}" in logs ) # noqa: E501 - def test_env_var_and_driver_job_config_warning(self, job_manager): + async def test_env_var_and_driver_job_config_warning(self, job_manager): """Ensure we got error message from worker.py and job logs if user provided runtime_env in both driver script and submit() """ @@ -239,14 +255,16 @@ class TestRuntimeEnv: }, ) - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) logs = job_manager.get_job_logs(job_id) assert logs.startswith( "Both RAY_JOB_CONFIG_JSON_ENV_VAR and ray.init(runtime_env) " "are provided" ) assert "JOB_1_VAR" in logs - def test_failed_runtime_env_validation(self, job_manager): + async def test_failed_runtime_env_validation(self, job_manager): """Ensure job status is correctly set as failed if job has an invalid runtime_env. """ @@ -259,7 +277,7 @@ class TestRuntimeEnv: assert status.status == JobStatus.FAILED assert "path_not_exist is not a valid URI" in status.message - def test_failed_runtime_env_setup(self, job_manager): + async def test_failed_runtime_env_setup(self, job_manager): """Ensure job status is correctly set as failed if job has a valid runtime_env that fails to be set up. """ @@ -268,12 +286,14 @@ class TestRuntimeEnv: entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"} ) - wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_failed, job_manager=job_manager, job_id=job_id + ) status = job_manager.get_job_status(job_id) assert "runtime_env setup failed" in status.message - def test_pass_metadata(self, job_manager): + async def test_pass_metadata(self, job_manager): def dict_to_str(d): return str(dict(sorted(d.items()))) @@ -289,7 +309,9 @@ class TestRuntimeEnv: # Check that we default to only the job ID and job name. 
job_id = job_manager.submit_job(entrypoint=print_metadata_cmd) - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert dict_to_str( {JOB_NAME_METADATA_KEY: job_id, JOB_ID_METADATA_KEY: job_id} ) in job_manager.get_job_logs(job_id) @@ -299,7 +321,9 @@ class TestRuntimeEnv: entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"} ) - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert ( dict_to_str( { @@ -318,16 +342,21 @@ class TestRuntimeEnv: metadata={JOB_NAME_METADATA_KEY: "custom_name"}, ) - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert dict_to_str( {JOB_NAME_METADATA_KEY: "custom_name", JOB_ID_METADATA_KEY: job_id} ) in job_manager.get_job_logs(job_id) +@pytest.mark.asyncio class TestAsyncAPI: - def test_status_and_logs_while_blocking(self, job_manager): + async def test_status_and_logs_while_blocking(self, job_manager): with tempfile.TemporaryDirectory() as tmp_dir: - pid_file, tmp_file, job_id = _run_hanging_command(job_manager, tmp_dir) + pid_file, tmp_file, job_id = await _run_hanging_command( + job_manager, tmp_dir + ) with open(pid_file, "r") as file: pid = int(file.read()) assert psutil.pid_exists(pid), "driver subprocess should be running" @@ -336,27 +365,29 @@ class TestAsyncAPI: with open(tmp_file, "w") as f: print("hello", file=f) - wait_for_condition( + await async_wait_for_condition( check_job_succeeded, job_manager=job_manager, job_id=job_id ) # Ensure driver subprocess gets cleaned up after job reached # termination state - wait_for_condition(check_subprocess_cleaned, pid=pid) + await async_wait_for_condition(check_subprocess_cleaned, pid=pid) - def test_stop_job(self, job_manager): + async def test_stop_job(self, job_manager): with tempfile.TemporaryDirectory() as tmp_dir: - _, _, job_id = _run_hanging_command(job_manager, tmp_dir) + _, _, job_id = await _run_hanging_command(job_manager, tmp_dir) assert job_manager.stop_job(job_id) is True - wait_for_condition( + await async_wait_for_condition( check_job_stopped, job_manager=job_manager, job_id=job_id ) # Assert re-stopping a stopped job also returns False - wait_for_condition(lambda: job_manager.stop_job(job_id) is False) + await async_wait_for_condition( + lambda: job_manager.stop_job(job_id) is False + ) # Assert stopping non-existent job returns False assert job_manager.stop_job(str(uuid4())) is False - def test_kill_job_actor_in_before_driver_finish(self, job_manager): + async def test_kill_job_actor_in_before_driver_finish(self, job_manager): """ Test submitting a long running / blocker driver script, and kill the job supervisor actor before script returns and ensure @@ -366,20 +397,22 @@ class TestAsyncAPI: """ with tempfile.TemporaryDirectory() as tmp_dir: - pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir) + pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir) with open(pid_file, "r") as file: pid = int(file.read()) assert psutil.pid_exists(pid), "driver subprocess should be running" actor = job_manager._get_actor_for_job(job_id) ray.kill(actor, no_restart=True) - wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id) + await 
async_wait_for_condition( + check_job_failed, job_manager=job_manager, job_id=job_id + ) # Ensure driver subprocess gets cleaned up after job reached # termination state - wait_for_condition(check_subprocess_cleaned, pid=pid) + await async_wait_for_condition(check_subprocess_cleaned, pid=pid) - def test_stop_job_in_pending(self, job_manager): + async def test_stop_job_in_pending(self, job_manager): """ Kick off a job that is in PENDING state, stop the job and ensure @@ -389,7 +422,7 @@ class TestAsyncAPI: start_signal_actor = SignalActor.remote() with tempfile.TemporaryDirectory() as tmp_dir: - pid_file, _, job_id = _run_hanging_command( + pid_file, _, job_id = await _run_hanging_command( job_manager, tmp_dir, start_signal_actor=start_signal_actor ) assert not os.path.exists(pid_file), ( @@ -399,11 +432,11 @@ class TestAsyncAPI: assert job_manager.stop_job(job_id) is True # Send run signal to unblock run function ray.get(start_signal_actor.send.remote()) - wait_for_condition( + await async_wait_for_condition( check_job_stopped, job_manager=job_manager, job_id=job_id ) - def test_kill_job_actor_in_pending(self, job_manager): + async def test_kill_job_actor_in_pending(self, job_manager): """ Kick off a job that is in PENDING state, kill the job actor and ensure @@ -413,7 +446,7 @@ class TestAsyncAPI: start_signal_actor = SignalActor.remote() with tempfile.TemporaryDirectory() as tmp_dir: - pid_file, _, job_id = _run_hanging_command( + pid_file, _, job_id = await _run_hanging_command( job_manager, tmp_dir, start_signal_actor=start_signal_actor ) @@ -423,9 +456,11 @@ class TestAsyncAPI: actor = job_manager._get_actor_for_job(job_id) ray.kill(actor, no_restart=True) - wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_failed, job_manager=job_manager, job_id=job_id + ) - def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager): + async def test_stop_job_subprocess_cleanup_upon_stop(self, job_manager): """ Ensure driver scripts' subprocess is cleaned up properly when we stopped a running job. @@ -433,21 +468,22 @@ class TestAsyncAPI: SIGTERM first, SIGKILL after 3 seconds. 
""" with tempfile.TemporaryDirectory() as tmp_dir: - pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir) + pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir) with open(pid_file, "r") as file: pid = int(file.read()) assert psutil.pid_exists(pid), "driver subprocess should be running" assert job_manager.stop_job(job_id) is True - wait_for_condition( + await async_wait_for_condition( check_job_stopped, job_manager=job_manager, job_id=job_id ) # Ensure driver subprocess gets cleaned up after job reached # termination state - wait_for_condition(check_subprocess_cleaned, pid=pid) + await async_wait_for_condition(check_subprocess_cleaned, pid=pid) +@pytest.mark.asyncio class TestTailLogs: async def _tail_and_assert_logs( self, job_id, job_manager, expected_log="", num_iteration=5 @@ -460,19 +496,17 @@ class TestTailLogs: break i += 1 - @pytest.mark.asyncio async def test_unknown_job(self, job_manager): with pytest.raises(RuntimeError, match="Job 'unknown' does not exist."): async for _ in job_manager.tail_job_logs("unknown"): pass - @pytest.mark.asyncio async def test_successful_job(self, job_manager): """Test tailing logs for a PENDING -> RUNNING -> SUCCESSFUL job.""" start_signal_actor = SignalActor.remote() with tempfile.TemporaryDirectory() as tmp_dir: - _, tmp_file, job_id = _run_hanging_command( + _, tmp_file, job_id = await _run_hanging_command( job_manager, tmp_dir, start_signal_actor=start_signal_actor ) @@ -495,15 +529,14 @@ class TestTailLogs: assert all(s == "Waiting..." for s in lines.strip().split("\n")) print(lines, end="") - wait_for_condition( + await async_wait_for_condition( check_job_succeeded, job_manager=job_manager, job_id=job_id ) - @pytest.mark.asyncio async def test_failed_job(self, job_manager): """Test tailing logs for a job that unexpectedly exits.""" with tempfile.TemporaryDirectory() as tmp_dir: - pid_file, _, job_id = _run_hanging_command(job_manager, tmp_dir) + pid_file, _, job_id = await _run_hanging_command(job_manager, tmp_dir) await self._tail_and_assert_logs( job_id, job_manager, expected_log="Waiting...", num_iteration=5 @@ -517,13 +550,14 @@ class TestTailLogs: assert all(s == "Waiting..." for s in lines.strip().split("\n")) print(lines, end="") - wait_for_condition(check_job_failed, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_failed, job_manager=job_manager, job_id=job_id + ) - @pytest.mark.asyncio async def test_stopped_job(self, job_manager): """Test tailing logs for a job that unexpectedly exits.""" with tempfile.TemporaryDirectory() as tmp_dir: - _, _, job_id = _run_hanging_command(job_manager, tmp_dir) + _, _, job_id = await _run_hanging_command(job_manager, tmp_dir) await self._tail_and_assert_logs( job_id, job_manager, expected_log="Waiting...", num_iteration=5 @@ -536,12 +570,13 @@ class TestTailLogs: assert all(s == "Waiting..." 
for s in lines.strip().split("\n")) print(lines, end="") - wait_for_condition( + await async_wait_for_condition( check_job_stopped, job_manager=job_manager, job_id=job_id ) -def test_logs_streaming(job_manager): +@pytest.mark.asyncio +async def test_logs_streaming(job_manager): """Test that logs are streamed during the job, not just at the end.""" stream_logs_script = """ @@ -554,12 +589,15 @@ while True: stream_logs_cmd = f'python -c "{stream_logs_script}"' job_id = job_manager.submit_job(entrypoint=stream_logs_cmd) - wait_for_condition(lambda: "STREAMED" in job_manager.get_job_logs(job_id)) + await async_wait_for_condition( + lambda: "STREAMED" in job_manager.get_job_logs(job_id) + ) job_manager.stop_job(job_id) -def test_bootstrap_address(job_manager, monkeypatch): +@pytest.mark.asyncio +async def test_bootstrap_address(job_manager, monkeypatch): """Ensure we always use bootstrap address in job manager even though ray cluster might be started with http://ip:{dashboard_port} from previous runs. @@ -574,7 +612,9 @@ def test_bootstrap_address(job_manager, monkeypatch): job_id = job_manager.submit_job(entrypoint=print_ray_address_cmd) - wait_for_condition(check_job_succeeded, job_manager=job_manager, job_id=job_id) + await async_wait_for_condition( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) assert "SUCCESS!" in job_manager.get_job_logs(job_id) diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index b5807c306..6ae3ed5cb 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -377,6 +377,34 @@ def wait_for_condition( raise RuntimeError(message) +async def async_wait_for_condition( + condition_predictor, timeout=10, retry_interval_ms=100, **kwargs: Any +): + """Wait until a condition is met or time out with an exception. + + Args: + condition_predictor: A function that predicts the condition. + timeout: Maximum timeout in seconds. + retry_interval_ms: Retry interval in milliseconds. + + Raises: + RuntimeError: If the condition is not met before the timeout expires. + """ + start = time.time() + last_ex = None + while time.time() - start <= timeout: + try: + if condition_predictor(**kwargs): + return + except Exception as ex: + last_ex = ex + await asyncio.sleep(retry_interval_ms / 1000.0) + message = "The condition wasn't met before the timeout expired." + if last_ex is not None: + message += f" Last exception: {last_ex}" + raise RuntimeError(message) + + def wait_until_succeeded_without_exception( func, exceptions, *args, timeout_ms=1000, retry_interval_ms=100, raise_last_ex=False ):
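
The core behavioral change in this diff is the background monitoring loop in `_monitor_job`: poll the supervisor actor's `ping()` on a fixed period until it stops responding, then reconcile the stored job status so a dead supervisor can never leave a job stuck in a non-terminal state. Below is a minimal, standalone asyncio sketch of that pattern. It deliberately avoids Ray: `FakeSupervisor`, the `status` dict, and the bare status strings are hypothetical stand-ins for the `JobSupervisor` actor and the GCS-backed `JobStatusStorageClient`; only the control flow mirrors `_monitor_job` above.

import asyncio

# Hypothetical stand-ins; only the control flow mirrors _monitor_job.
TERMINAL = {"STOPPED", "SUCCEEDED", "FAILED"}


class FakeSupervisor:
    """Pretends to be a job supervisor actor that dies after a few pings."""

    def __init__(self, pings_until_death: int = 3):
        self._remaining = pings_until_death

    async def ping(self):
        if self._remaining <= 0:
            raise RuntimeError("actor died")
        self._remaining -= 1


async def monitor_job(job_id: str, supervisor: FakeSupervisor, status: dict,
                      period_s: float = 0.1):
    """Poll the supervisor until it stops responding, then reconcile status."""
    while True:
        try:
            await supervisor.ping()
            await asyncio.sleep(period_s)
        except Exception:
            if status[job_id] in TERMINAL:
                # Expected: the job already finished, so the supervisor
                # exiting is normal and the stored status is left alone.
                pass
            else:
                # Unexpected death: record a terminal FAILED status so
                # clients never observe a job stuck in PENDING/RUNNING.
                status[job_id] = "FAILED"
            break


async def main():
    status = {"job_1": "RUNNING"}
    await monitor_job("job_1", FakeSupervisor(), status)
    assert status["job_1"] == "FAILED"
    print(status)


if __name__ == "__main__":
    asyncio.run(main())

Design-wise, this loop replaces the old one-shot `actor.ready.remote()._on_completed(callback)` hook and the `_handle_supervisor_startup` error handling: failures (including `RuntimeEnvSetupError`) are now detected for the entire lifetime of the job rather than only at supervisor startup, which is also what lets `get_job_status` become a plain read from storage.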