Convert job_manager to be async (#27123)

Updates jobs api
Updates snapshot api
Updates state api

Increases jobs api version to 2

Signed-off-by: Alan Guo <aguo@anyscale.com>

Why are these changes needed?
Follow-up to a review comment on #25902.
This commit is contained in:
Alan Guo 2022-08-05 19:33:49 -07:00 committed by GitHub
parent a82af8602c
commit 326b5bd1ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 211 additions and 159 deletions

View file

@ -1,3 +1,4 @@
import asyncio
import pickle import pickle
import time import time
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
@ -6,12 +7,10 @@ from pathlib import Path
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
from ray._private import ray_constants from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray._private.runtime_env.packaging import parse_uri from ray._private.runtime_env.packaging import parse_uri
from ray.experimental.internal_kv import ( from ray.experimental.internal_kv import (
_internal_kv_get,
_internal_kv_initialized, _internal_kv_initialized,
_internal_kv_list,
_internal_kv_put,
) )
# NOTE(edoakes): these constants should be considered a public API because # NOTE(edoakes): these constants should be considered a public API because
@ -97,19 +96,21 @@ class JobInfoStorageClient:
JOB_DATA_KEY_PREFIX = f"{ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX}job_info_" JOB_DATA_KEY_PREFIX = f"{ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX}job_info_"
JOB_DATA_KEY = f"{JOB_DATA_KEY_PREFIX}{{job_id}}" JOB_DATA_KEY = f"{JOB_DATA_KEY_PREFIX}{{job_id}}"
def __init__(self): def __init__(self, gcs_aio_client: GcsAioClient):
self._gcs_aio_client = gcs_aio_client
assert _internal_kv_initialized() assert _internal_kv_initialized()
def put_info(self, job_id: str, data: JobInfo): async def put_info(self, job_id: str, data: JobInfo):
_internal_kv_put( await self._gcs_aio_client.internal_kv_put(
self.JOB_DATA_KEY.format(job_id=job_id), self.JOB_DATA_KEY.format(job_id=job_id).encode(),
pickle.dumps(data), pickle.dumps(data),
True,
namespace=ray_constants.KV_NAMESPACE_JOB, namespace=ray_constants.KV_NAMESPACE_JOB,
) )
def get_info(self, job_id: str) -> Optional[JobInfo]: async def get_info(self, job_id: str) -> Optional[JobInfo]:
pickled_info = _internal_kv_get( pickled_info = await self._gcs_aio_client.internal_kv_get(
self.JOB_DATA_KEY.format(job_id=job_id), self.JOB_DATA_KEY.format(job_id=job_id).encode(),
namespace=ray_constants.KV_NAMESPACE_JOB, namespace=ray_constants.KV_NAMESPACE_JOB,
) )
if pickled_info is None: if pickled_info is None:
@ -117,10 +118,12 @@ class JobInfoStorageClient:
else: else:
return pickle.loads(pickled_info) return pickle.loads(pickled_info)
def put_status(self, job_id: str, status: JobStatus, message: Optional[str] = None): async def put_status(
self, job_id: str, status: JobStatus, message: Optional[str] = None
):
"""Puts or updates job status. Sets end_time if status is terminal.""" """Puts or updates job status. Sets end_time if status is terminal."""
old_info = self.get_info(job_id) old_info = await self.get_info(job_id)
if old_info is not None: if old_info is not None:
if status != old_info.status and old_info.status.is_terminal(): if status != old_info.status and old_info.status.is_terminal():
@ -134,18 +137,18 @@ class JobInfoStorageClient:
if status.is_terminal(): if status.is_terminal():
new_info.end_time = int(time.time() * 1000) new_info.end_time = int(time.time() * 1000)
self.put_info(job_id, new_info) await self.put_info(job_id, new_info)
def get_status(self, job_id: str) -> Optional[JobStatus]: async def get_status(self, job_id: str) -> Optional[JobStatus]:
job_info = self.get_info(job_id) job_info = await self.get_info(job_id)
if job_info is None: if job_info is None:
return None return None
else: else:
return job_info.status return job_info.status
def get_all_jobs(self) -> Dict[str, JobInfo]: async def get_all_jobs(self) -> Dict[str, JobInfo]:
raw_job_ids_with_prefixes = _internal_kv_list( raw_job_ids_with_prefixes = await self._gcs_aio_client.internal_kv_keys(
self.JOB_DATA_KEY_PREFIX, namespace=ray_constants.KV_NAMESPACE_JOB self.JOB_DATA_KEY_PREFIX.encode(), namespace=ray_constants.KV_NAMESPACE_JOB
) )
job_ids_with_prefixes = [ job_ids_with_prefixes = [
job_id.decode() for job_id in raw_job_ids_with_prefixes job_id.decode() for job_id in raw_job_ids_with_prefixes
@ -156,7 +159,17 @@ class JobInfoStorageClient:
self.JOB_DATA_KEY_PREFIX self.JOB_DATA_KEY_PREFIX
), "Unexpected format for internal_kv key for Job submission" ), "Unexpected format for internal_kv key for Job submission"
job_ids.append(job_id_with_prefix[len(self.JOB_DATA_KEY_PREFIX) :]) job_ids.append(job_id_with_prefix[len(self.JOB_DATA_KEY_PREFIX) :])
return {job_id: self.get_info(job_id) for job_id in job_ids}
async def get_job_info(job_id: str):
job_info = await self.get_info(job_id)
return job_id, job_info
return {
job_id: job_info
for job_id, job_info in await asyncio.gather(
*[get_job_info(job_id) for job_id in job_ids]
)
}
def uri_to_http_components(package_uri: str) -> Tuple[str, str]: def uri_to_http_components(package_uri: str) -> Tuple[str, str]:

View file

@ -1,5 +1,3 @@
import asyncio
import concurrent
import dataclasses import dataclasses
import json import json
import logging import logging
@ -54,7 +52,6 @@ class JobHead(dashboard_utils.DashboardHeadModule):
self._dashboard_head = dashboard_head self._dashboard_head = dashboard_head
self._job_manager = None self._job_manager = None
self._gcs_job_info_stub = None self._gcs_job_info_stub = None
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
async def _parse_and_validate_request( async def _parse_and_validate_request(
self, req: Request, request_type: dataclass self, req: Request, request_type: dataclass
@ -95,9 +92,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
# then lets try to search for a submission with given id # then lets try to search for a submission with given id
submission_id = job_or_submission_id submission_id = job_or_submission_id
job_info = await asyncio.get_event_loop().run_in_executor( job_info = await self._job_manager.get_job_info(submission_id)
self._executor, lambda: self._job_manager.get_job_info(submission_id)
)
if job_info: if job_info:
driver = submission_job_drivers.get(submission_id) driver = submission_job_drivers.get(submission_id)
job = JobDetails( job = JobDetails(
@ -182,7 +177,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
request_submission_id = submit_request.submission_id or submit_request.job_id request_submission_id = submit_request.submission_id or submit_request.job_id
try: try:
submission_id = self._job_manager.submit_job( submission_id = await self._job_manager.submit_job(
entrypoint=submit_request.entrypoint, entrypoint=submit_request.entrypoint,
submission_id=request_submission_id, submission_id=request_submission_id,
runtime_env=submit_request.runtime_env, runtime_env=submit_request.runtime_env,
@ -257,10 +252,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
async def list_jobs(self, req: Request) -> Response: async def list_jobs(self, req: Request) -> Response:
driver_jobs, submission_job_drivers = await self._get_driver_jobs() driver_jobs, submission_job_drivers = await self._get_driver_jobs()
# TODO(aguo): convert _job_manager.list_jobs to an async function. submission_jobs = await self._job_manager.list_jobs()
submission_jobs = await asyncio.get_event_loop().run_in_executor(
self._executor, self._job_manager.list_jobs
)
submission_jobs = [ submission_jobs = [
JobDetails( JobDetails(
**dataclasses.asdict(job), **dataclasses.asdict(job),
@ -386,7 +378,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
async def run(self, server): async def run(self, server):
if not self._job_manager: if not self._job_manager:
self._job_manager = JobManager() self._job_manager = JobManager(self._dashboard_head.gcs_aio_client)
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub( self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel self._dashboard_head.aiogrpc_gcs_channel

View file

@ -13,6 +13,7 @@ from collections import deque
from typing import Any, Dict, Iterator, Optional, Tuple from typing import Any, Dict, Iterator, Optional, Tuple
import ray import ray
from ray._private.gcs_utils import GcsAioClient
import ray._private.ray_constants as ray_constants import ray._private.ray_constants as ray_constants
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
from ray.actor import ActorHandle from ray.actor import ActorHandle
@ -103,9 +104,16 @@ class JobSupervisor:
SUBPROCESS_POLL_PERIOD_S = 0.1 SUBPROCESS_POLL_PERIOD_S = 0.1
def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str, str]): def __init__(
self,
job_id: str,
entrypoint: str,
user_metadata: Dict[str, str],
gcs_address: str,
):
self._job_id = job_id self._job_id = job_id
self._job_info_client = JobInfoStorageClient() gcs_aio_client = GcsAioClient(address=gcs_address)
self._job_info_client = JobInfoStorageClient(gcs_aio_client)
self._log_client = JobLogStorageClient() self._log_client = JobLogStorageClient()
self._driver_runtime_env = self._get_driver_runtime_env() self._driver_runtime_env = self._get_driver_runtime_env()
self._entrypoint = entrypoint self._entrypoint = entrypoint
@ -227,14 +235,14 @@ class JobSupervisor:
variables. variables.
3) Handle concurrent events of driver execution and 3) Handle concurrent events of driver execution and
""" """
curr_status = self._job_info_client.get_status(self._job_id) curr_status = await self._job_info_client.get_status(self._job_id)
assert curr_status == JobStatus.PENDING, "Run should only be called once." assert curr_status == JobStatus.PENDING, "Run should only be called once."
if _start_signal_actor: if _start_signal_actor:
# Block in PENDING state until start signal received. # Block in PENDING state until start signal received.
await _start_signal_actor.wait.remote() await _start_signal_actor.wait.remote()
self._job_info_client.put_status(self._job_id, JobStatus.RUNNING) await self._job_info_client.put_status(self._job_id, JobStatus.RUNNING)
try: try:
# Configure environment variables for the child process. These # Configure environment variables for the child process. These
@ -257,7 +265,7 @@ class JobSupervisor:
polling_task.cancel() polling_task.cancel()
# TODO (jiaodong): Improve this with SIGTERM then SIGKILL # TODO (jiaodong): Improve this with SIGTERM then SIGKILL
child_process.kill() child_process.kill()
self._job_info_client.put_status(self._job_id, JobStatus.STOPPED) await self._job_info_client.put_status(self._job_id, JobStatus.STOPPED)
else: else:
# Child process finished execution and no stop event is set # Child process finished execution and no stop event is set
# at the same time # at the same time
@ -265,7 +273,9 @@ class JobSupervisor:
[child_process_task] = finished [child_process_task] = finished
return_code = child_process_task.result() return_code = child_process_task.result()
if return_code == 0: if return_code == 0:
self._job_info_client.put_status(self._job_id, JobStatus.SUCCEEDED) await self._job_info_client.put_status(
self._job_id, JobStatus.SUCCEEDED
)
else: else:
log_tail = self._log_client.get_last_n_log_lines(self._job_id) log_tail = self._log_client.get_last_n_log_lines(self._job_id)
if log_tail is not None and log_tail != "": if log_tail is not None and log_tail != "":
@ -275,7 +285,7 @@ class JobSupervisor:
) )
else: else:
message = None message = None
self._job_info_client.put_status( await self._job_info_client.put_status(
self._job_id, JobStatus.FAILED, message=message self._job_id, JobStatus.FAILED, message=message
) )
except Exception: except Exception:
@ -307,20 +317,22 @@ class JobManager:
LOG_TAIL_SLEEP_S = 1 LOG_TAIL_SLEEP_S = 1
JOB_MONITOR_LOOP_PERIOD_S = 1 JOB_MONITOR_LOOP_PERIOD_S = 1
def __init__(self): def __init__(self, gcs_aio_client: GcsAioClient):
self._job_info_client = JobInfoStorageClient() self._gcs_aio_client = gcs_aio_client
self._job_info_client = JobInfoStorageClient(gcs_aio_client)
self._gcs_address = gcs_aio_client._channel._gcs_address
self._log_client = JobLogStorageClient() self._log_client = JobLogStorageClient()
self._supervisor_actor_cls = ray.remote(JobSupervisor) self._supervisor_actor_cls = ray.remote(JobSupervisor)
self._recover_running_jobs() create_task(self._recover_running_jobs())
def _recover_running_jobs(self): async def _recover_running_jobs(self):
"""Recovers all running jobs from the status client. """Recovers all running jobs from the status client.
For each job, we will spawn a coroutine to monitor it. For each job, we will spawn a coroutine to monitor it.
Each will be added to self._running_jobs and reconciled. Each will be added to self._running_jobs and reconciled.
""" """
all_jobs = self._job_info_client.get_all_jobs() all_jobs = await self._job_info_client.get_all_jobs()
for job_id, job_info in all_jobs.items(): for job_id, job_info in all_jobs.items():
if not job_info.status.is_terminal(): if not job_info.status.is_terminal():
create_task(self._monitor_job(job_id)) create_task(self._monitor_job(job_id))
@ -345,7 +357,7 @@ class JobManager:
if job_supervisor is None: if job_supervisor is None:
logger.error(f"Failed to get job supervisor for job {job_id}.") logger.error(f"Failed to get job supervisor for job {job_id}.")
self._job_info_client.put_status( await self._job_info_client.put_status(
job_id, job_id,
JobStatus.FAILED, JobStatus.FAILED,
message="Unexpected error occurred: Failed to get job supervisor.", message="Unexpected error occurred: Failed to get job supervisor.",
@ -358,13 +370,14 @@ class JobManager:
await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S) await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S)
except Exception as e: except Exception as e:
is_alive = False is_alive = False
if self._job_info_client.get_status(job_id).is_terminal(): job_status = await self._job_info_client.get_status(job_id)
if job_status.is_terminal():
# If the job is already in a terminal state, then the actor # If the job is already in a terminal state, then the actor
# exiting is expected. # exiting is expected.
pass pass
elif isinstance(e, RuntimeEnvSetupError): elif isinstance(e, RuntimeEnvSetupError):
logger.info(f"Failed to set up runtime_env for job {job_id}.") logger.info(f"Failed to set up runtime_env for job {job_id}.")
self._job_info_client.put_status( await self._job_info_client.put_status(
job_id, job_id,
JobStatus.FAILED, JobStatus.FAILED,
message=f"runtime_env setup failed: {e}", message=f"runtime_env setup failed: {e}",
@ -373,7 +386,7 @@ class JobManager:
logger.warning( logger.warning(
f"Job supervisor for job {job_id} failed unexpectedly: {e}." f"Job supervisor for job {job_id} failed unexpectedly: {e}."
) )
self._job_info_client.put_status( await self._job_info_client.put_status(
job_id, job_id,
JobStatus.FAILED, JobStatus.FAILED,
message=f"Unexpected error occurred: {e}", message=f"Unexpected error occurred: {e}",
@ -413,7 +426,6 @@ class JobManager:
def _get_supervisor_runtime_env( def _get_supervisor_runtime_env(
self, user_runtime_env: Dict[str, Any] self, user_runtime_env: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Configure and return the runtime_env for the supervisor actor.""" """Configure and return the runtime_env for the supervisor actor."""
# Make a copy to avoid mutating passed runtime_env. # Make a copy to avoid mutating passed runtime_env.
@ -434,7 +446,7 @@ class JobManager:
runtime_env["env_vars"] = env_vars runtime_env["env_vars"] = env_vars
return runtime_env return runtime_env
def submit_job( async def submit_job(
self, self,
*, *,
entrypoint: str, entrypoint: str,
@ -473,7 +485,7 @@ class JobManager:
""" """
if submission_id is None: if submission_id is None:
submission_id = generate_job_id() submission_id = generate_job_id()
elif self._job_info_client.get_status(submission_id) is not None: elif await self._job_info_client.get_status(submission_id) is not None:
raise RuntimeError(f"Job {submission_id} already exists.") raise RuntimeError(f"Job {submission_id} already exists.")
logger.info(f"Starting job with submission_id: {submission_id}") logger.info(f"Starting job with submission_id: {submission_id}")
@ -484,7 +496,7 @@ class JobManager:
metadata=metadata, metadata=metadata,
runtime_env=runtime_env, runtime_env=runtime_env,
) )
self._job_info_client.put_info(submission_id, job_info) await self._job_info_client.put_info(submission_id, job_info)
# Wait for the actor to start up asynchronously so this call always # Wait for the actor to start up asynchronously so this call always
# returns immediately and we can catch errors with the actor starting # returns immediately and we can catch errors with the actor starting
@ -500,14 +512,14 @@ class JobManager:
self._get_current_node_resource_key(): 0.001, self._get_current_node_resource_key(): 0.001,
}, },
runtime_env=self._get_supervisor_runtime_env(runtime_env), runtime_env=self._get_supervisor_runtime_env(runtime_env),
).remote(submission_id, entrypoint, metadata or {}) ).remote(submission_id, entrypoint, metadata or {}, self._gcs_address)
supervisor.run.remote(_start_signal_actor=_start_signal_actor) supervisor.run.remote(_start_signal_actor=_start_signal_actor)
# Monitor the job in the background so we can detect errors without # Monitor the job in the background so we can detect errors without
# requiring a client to poll. # requiring a client to poll.
create_task(self._monitor_job(submission_id, job_supervisor=supervisor)) create_task(self._monitor_job(submission_id, job_supervisor=supervisor))
except Exception as e: except Exception as e:
self._job_info_client.put_status( await self._job_info_client.put_status(
submission_id, submission_id,
JobStatus.FAILED, JobStatus.FAILED,
message=f"Failed to start job supervisor: {e}.", message=f"Failed to start job supervisor: {e}.",
@ -529,17 +541,17 @@ class JobManager:
else: else:
return False return False
def get_job_status(self, job_id: str) -> Optional[JobStatus]: async def get_job_status(self, job_id: str) -> Optional[JobStatus]:
"""Get latest status of a job.""" """Get latest status of a job."""
return self._job_info_client.get_status(job_id) return await self._job_info_client.get_status(job_id)
def get_job_info(self, job_id: str) -> Optional[JobInfo]: async def get_job_info(self, job_id: str) -> Optional[JobInfo]:
"""Get latest info of a job.""" """Get latest info of a job."""
return self._job_info_client.get_info(job_id) return await self._job_info_client.get_info(job_id)
def list_jobs(self) -> Dict[str, JobInfo]: async def list_jobs(self) -> Dict[str, JobInfo]:
"""Get info for all jobs.""" """Get info for all jobs."""
return self._job_info_client.get_all_jobs() return await self._job_info_client.get_all_jobs()
def get_job_logs(self, job_id: str) -> str: def get_job_logs(self, job_id: str) -> str:
"""Get all logs produced by a job.""" """Get all logs produced by a job."""
@ -547,13 +559,13 @@ class JobManager:
async def tail_job_logs(self, job_id: str) -> Iterator[str]: async def tail_job_logs(self, job_id: str) -> Iterator[str]:
"""Return an iterator following the logs of a job.""" """Return an iterator following the logs of a job."""
if self.get_job_status(job_id) is None: if await self.get_job_status(job_id) is None:
raise RuntimeError(f"Job '{job_id}' does not exist.") raise RuntimeError(f"Job '{job_id}' does not exist.")
for line in self._log_client.tail_logs(job_id): for line in self._log_client.tail_logs(job_id):
if line is None: if line is None:
# Return if the job has exited and there are no new log lines. # Return if the job has exited and there are no new log lines.
status = self.get_job_status(job_id) status = await self.get_job_status(job_id)
if status not in {JobStatus.PENDING, JobStatus.RUNNING}: if status not in {JobStatus.PENDING, JobStatus.RUNNING}:
return return

View file

@ -10,8 +10,13 @@ import psutil
import pytest import pytest
import ray import ray
from ray._private.gcs_utils import GcsAioClient
from ray._private.ray_constants import RAY_ADDRESS_ENVIRONMENT_VARIABLE from ray._private.ray_constants import RAY_ADDRESS_ENVIRONMENT_VARIABLE
from ray._private.test_utils import SignalActor, async_wait_for_condition from ray._private.test_utils import (
SignalActor,
async_wait_for_condition,
async_wait_for_condition_async_predicate,
)
from ray.dashboard.modules.job.common import JOB_ID_METADATA_KEY, JOB_NAME_METADATA_KEY from ray.dashboard.modules.job.common import JOB_ID_METADATA_KEY, JOB_NAME_METADATA_KEY
from ray.dashboard.modules.job.job_manager import JobManager, generate_job_id from ray.dashboard.modules.job.job_manager import JobManager, generate_job_id
from ray.job_submission import JobStatus from ray.job_submission import JobStatus
@ -29,8 +34,11 @@ TEST_NAMESPACE = "jobs_test_namespace"
async def test_submit_no_ray_address(call_ray_start): # noqa: F811 async def test_submit_no_ray_address(call_ray_start): # noqa: F811
"""Test that a job script with an unspecified Ray address works.""" """Test that a job script with an unspecified Ray address works."""
ray.init(address=call_ray_start) address_info = ray.init(address=call_ray_start)
job_manager = JobManager() gcs_aio_client = GcsAioClient(
address=address_info["gcs_address"], nums_reconnect_retry=0
)
job_manager = JobManager(gcs_aio_client)
init_ray_no_address_script = """ init_ray_no_address_script = """
import ray import ray
@ -46,11 +54,11 @@ assert ray.cluster_resources().get('TestResourceKey') == 123
# The job script should work even if RAY_ADDRESS is not set on the cluster. # The job script should work even if RAY_ADDRESS is not set on the cluster.
os.environ.pop(RAY_ADDRESS_ENVIRONMENT_VARIABLE, None) os.environ.pop(RAY_ADDRESS_ENVIRONMENT_VARIABLE, None)
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint=f"""python -c "{init_ray_no_address_script}" """ entrypoint=f"""python -c "{init_ray_no_address_script}" """
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
@ -68,9 +76,14 @@ def shared_ray_instance():
os.environ[RAY_ADDRESS_ENVIRONMENT_VARIABLE] = old_ray_address os.environ[RAY_ADDRESS_ENVIRONMENT_VARIABLE] = old_ray_address
@pytest.mark.asyncio
@pytest.fixture @pytest.fixture
def job_manager(shared_ray_instance): async def job_manager(shared_ray_instance):
yield JobManager() address_info = shared_ray_instance
gcs_aio_client = GcsAioClient(
address=address_info["gcs_address"], nums_reconnect_retry=0
)
yield JobManager(gcs_aio_client)
def _driver_script_path(file_name: str) -> str: def _driver_script_path(file_name: str) -> str:
@ -90,11 +103,11 @@ async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
"do echo 'Waiting...' && sleep 1; " "do echo 'Waiting...' && sleep 1; "
"done" "done"
) )
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor
) )
status = job_manager.get_job_status(job_id) status = await job_manager.get_job_status(job_id)
if start_signal_actor: if start_signal_actor:
for _ in range(10): for _ in range(10):
assert status == JobStatus.PENDING assert status == JobStatus.PENDING
@ -102,7 +115,7 @@ async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
assert logs == "" assert logs == ""
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
else: else:
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_running, job_manager=job_manager, job_id=job_id check_job_running, job_manager=job_manager, job_id=job_id
) )
await async_wait_for_condition( await async_wait_for_condition(
@ -112,8 +125,8 @@ async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
return pid_file, tmp_file, job_id return pid_file, tmp_file, job_id
def check_job_succeeded(job_manager, job_id): async def check_job_succeeded(job_manager, job_id):
data = job_manager.get_job_info(job_id) data = await job_manager.get_job_info(job_id)
status = data.status status = data.status
if status == JobStatus.FAILED: if status == JobStatus.FAILED:
raise RuntimeError(f"Job failed! {data.message}") raise RuntimeError(f"Job failed! {data.message}")
@ -121,20 +134,20 @@ def check_job_succeeded(job_manager, job_id):
return status == JobStatus.SUCCEEDED return status == JobStatus.SUCCEEDED
def check_job_failed(job_manager, job_id): async def check_job_failed(job_manager, job_id):
status = job_manager.get_job_status(job_id) status = await job_manager.get_job_status(job_id)
assert status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED} assert status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED}
return status == JobStatus.FAILED return status == JobStatus.FAILED
def check_job_stopped(job_manager, job_id): async def check_job_stopped(job_manager, job_id):
status = job_manager.get_job_status(job_id) status = await job_manager.get_job_status(job_id)
assert status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED} assert status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED}
return status == JobStatus.STOPPED return status == JobStatus.STOPPED
def check_job_running(job_manager, job_id): async def check_job_running(job_manager, job_id):
status = job_manager.get_job_status(job_id) status = await job_manager.get_job_status(job_id)
assert status in {JobStatus.PENDING, JobStatus.RUNNING} assert status in {JobStatus.PENDING, JobStatus.RUNNING}
return status == JobStatus.RUNNING return status == JobStatus.RUNNING
@ -158,29 +171,30 @@ def test_generate_job_id():
# NOTE(architkulkarni): This test must be run first in order for the job # NOTE(architkulkarni): This test must be run first in order for the job
# submission history of the shared Ray runtime to be empty. # submission history of the shared Ray runtime to be empty.
def test_list_jobs_empty(job_manager: JobManager): @pytest.mark.asyncio
assert job_manager.list_jobs() == dict() async def test_list_jobs_empty(job_manager: JobManager):
assert await job_manager.list_jobs() == dict()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_jobs(job_manager: JobManager): async def test_list_jobs(job_manager: JobManager):
job_manager.submit_job(entrypoint="echo hi", submission_id="1") await job_manager.submit_job(entrypoint="echo hi", submission_id="1")
runtime_env = {"env_vars": {"TEST": "123"}} runtime_env = {"env_vars": {"TEST": "123"}}
metadata = {"foo": "bar"} metadata = {"foo": "bar"}
job_manager.submit_job( await job_manager.submit_job(
entrypoint="echo hello", entrypoint="echo hello",
submission_id="2", submission_id="2",
runtime_env=runtime_env, runtime_env=runtime_env,
metadata=metadata, metadata=metadata,
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id="1" check_job_succeeded, job_manager=job_manager, job_id="1"
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id="2" check_job_succeeded, job_manager=job_manager, job_id="2"
) )
jobs_info = job_manager.list_jobs() jobs_info = await job_manager.list_jobs()
assert "1" in jobs_info assert "1" in jobs_info
assert jobs_info["1"].status == JobStatus.SUCCEEDED assert jobs_info["1"].status == JobStatus.SUCCEEDED
@ -196,43 +210,45 @@ async def test_list_jobs(job_manager: JobManager):
async def test_pass_job_id(job_manager): async def test_pass_job_id(job_manager):
submission_id = "my_custom_id" submission_id = "my_custom_id"
returned_id = job_manager.submit_job( returned_id = await job_manager.submit_job(
entrypoint="echo hello", submission_id=submission_id entrypoint="echo hello", submission_id=submission_id
) )
assert returned_id == submission_id assert returned_id == submission_id
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=submission_id check_job_succeeded, job_manager=job_manager, job_id=submission_id
) )
# Check that the same job_id is rejected. # Check that the same job_id is rejected.
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
job_manager.submit_job(entrypoint="echo hello", submission_id=submission_id) await job_manager.submit_job(
entrypoint="echo hello", submission_id=submission_id
)
@pytest.mark.asyncio @pytest.mark.asyncio
class TestShellScriptExecution: class TestShellScriptExecution:
async def test_submit_basic_echo(self, job_manager): async def test_submit_basic_echo(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo hello") job_id = await job_manager.submit_job(entrypoint="echo hello")
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert job_manager.get_job_logs(job_id) == "hello\n" assert job_manager.get_job_logs(job_id) == "hello\n"
async def test_submit_stderr(self, job_manager): async def test_submit_stderr(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo error 1>&2") job_id = await job_manager.submit_job(entrypoint="echo error 1>&2")
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert job_manager.get_job_logs(job_id) == "error\n" assert job_manager.get_job_logs(job_id) == "error\n"
async def test_submit_ls_grep(self, job_manager): async def test_submit_ls_grep(self, job_manager):
grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py" grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
job_id = job_manager.submit_job(entrypoint=grep_cmd) job_id = await job_manager.submit_job(entrypoint=grep_cmd)
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n" assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"
@ -246,10 +262,10 @@ class TestShellScriptExecution:
4) Empty logs 4) Empty logs
""" """
run_cmd = f"python {_driver_script_path('script_with_exception.py')}" run_cmd = f"python {_driver_script_path('script_with_exception.py')}"
job_id = job_manager.submit_job(entrypoint=run_cmd) job_id = await job_manager.submit_job(entrypoint=run_cmd)
def cleaned_up(): async def cleaned_up():
data = job_manager.get_job_info(job_id) data = await job_manager.get_job_info(job_id)
if data.status != JobStatus.FAILED: if data.status != JobStatus.FAILED:
return False return False
if "Exception: Script failed with exception !" not in data.message: if "Exception: Script failed with exception !" not in data.message:
@ -257,15 +273,15 @@ class TestShellScriptExecution:
return job_manager._get_actor_for_job(job_id) is None return job_manager._get_actor_for_job(job_id) is None
await async_wait_for_condition(cleaned_up) await async_wait_for_condition_async_predicate(cleaned_up)
async def test_submit_with_s3_runtime_env(self, job_manager): async def test_submit_with_s3_runtime_env(self, job_manager):
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint="python script.py", entrypoint="python script.py",
runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"}, runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"},
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert ( assert (
@ -278,11 +294,11 @@ class TestShellScriptExecution:
"https://runtime-env-test.s3.amazonaws.com/script_runtime_env.zip", "https://runtime-env-test.s3.amazonaws.com/script_runtime_env.zip",
filename=f.name, filename=f.name,
) )
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint="python script.py", entrypoint="python script.py",
runtime_env={"working_dir": "file://" + filename}, runtime_env={"working_dir": "file://" + filename},
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert ( assert (
@ -297,26 +313,26 @@ class TestRuntimeEnv:
"""Test we can pass env vars in the subprocess that executes job's """Test we can pass env vars in the subprocess that executes job's
driver script. driver script.
""" """
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint="echo $TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR", entrypoint="echo $TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR",
runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}}, runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}},
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert job_manager.get_job_logs(job_id) == "233\n" assert job_manager.get_job_logs(job_id) == "233\n"
async def test_multiple_runtime_envs(self, job_manager): async def test_multiple_runtime_envs(self, job_manager):
# Test that you can run two jobs in different envs without conflict. # Test that you can run two jobs in different envs without conflict.
job_id_1 = job_manager.submit_job( job_id_1 = await job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}", entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
runtime_env={ runtime_env={
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"} "env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
}, },
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id_1 check_job_succeeded, job_manager=job_manager, job_id=job_id_1
) )
logs = job_manager.get_job_logs(job_id_1) logs = job_manager.get_job_logs(job_id_1)
@ -324,14 +340,14 @@ class TestRuntimeEnv:
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs "{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs
) # noqa: E501 ) # noqa: E501
job_id_2 = job_manager.submit_job( job_id_2 = await job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}", entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
runtime_env={ runtime_env={
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"} "env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"}
}, },
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id_2 check_job_succeeded, job_manager=job_manager, job_id=job_id_2
) )
logs = job_manager.get_job_logs(job_id_2) logs = job_manager.get_job_logs(job_id_2)
@ -343,14 +359,14 @@ class TestRuntimeEnv:
"""Ensure we got error message from worker.py and job logs """Ensure we got error message from worker.py and job logs
if user provided runtime_env in both driver script and submit() if user provided runtime_env in both driver script and submit()
""" """
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint=f"python {_driver_script_path('override_env_var.py')}", entrypoint=f"python {_driver_script_path('override_env_var.py')}",
runtime_env={ runtime_env={
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"} "env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
}, },
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
logs = job_manager.get_job_logs(job_id) logs = job_manager.get_job_logs(job_id)
@ -365,11 +381,11 @@ class TestRuntimeEnv:
runtime_env. runtime_env.
""" """
run_cmd = f"python {_driver_script_path('override_env_var.py')}" run_cmd = f"python {_driver_script_path('override_env_var.py')}"
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"} entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"}
) )
data = job_manager.get_job_info(job_id) data = await job_manager.get_job_info(job_id)
assert data.status == JobStatus.FAILED assert data.status == JobStatus.FAILED
assert "path_not_exist is not a valid URI" in data.message assert "path_not_exist is not a valid URI" in data.message
@ -378,15 +394,15 @@ class TestRuntimeEnv:
runtime_env that fails to be set up. runtime_env that fails to be set up.
""" """
run_cmd = f"python {_driver_script_path('override_env_var.py')}" run_cmd = f"python {_driver_script_path('override_env_var.py')}"
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"} entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"}
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_failed, job_manager=job_manager, job_id=job_id check_job_failed, job_manager=job_manager, job_id=job_id
) )
data = job_manager.get_job_info(job_id) data = await job_manager.get_job_info(job_id)
assert "runtime_env setup failed" in data.message assert "runtime_env setup failed" in data.message
async def test_pass_metadata(self, job_manager): async def test_pass_metadata(self, job_manager):
@ -403,9 +419,9 @@ class TestRuntimeEnv:
) )
# Check that we default to only the job ID and job name. # Check that we default to only the job ID and job name.
job_id = job_manager.submit_job(entrypoint=print_metadata_cmd) job_id = await job_manager.submit_job(entrypoint=print_metadata_cmd)
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert dict_to_str( assert dict_to_str(
@ -413,11 +429,11 @@ class TestRuntimeEnv:
) in job_manager.get_job_logs(job_id) ) in job_manager.get_job_logs(job_id)
# Check that we can pass custom metadata. # Check that we can pass custom metadata.
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"} entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert ( assert (
@ -433,12 +449,12 @@ class TestRuntimeEnv:
) )
# Check that we can override job name. # Check that we can override job name.
job_id = job_manager.submit_job( job_id = await job_manager.submit_job(
entrypoint=print_metadata_cmd, entrypoint=print_metadata_cmd,
metadata={JOB_NAME_METADATA_KEY: "custom_name"}, metadata={JOB_NAME_METADATA_KEY: "custom_name"},
) )
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert dict_to_str( assert dict_to_str(
@ -459,9 +475,11 @@ class TestRuntimeEnv:
""" """
run_cmd = f"python {_driver_script_path('check_cuda_devices.py')}" run_cmd = f"python {_driver_script_path('check_cuda_devices.py')}"
runtime_env = {"env_vars": env_vars} runtime_env = {"env_vars": env_vars}
job_id = job_manager.submit_job(entrypoint=run_cmd, runtime_env=runtime_env) job_id = await job_manager.submit_job(
entrypoint=run_cmd, runtime_env=runtime_env
)
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
@ -481,7 +499,7 @@ class TestAsyncAPI:
with open(tmp_file, "w") as f: with open(tmp_file, "w") as f:
print("hello", file=f) print("hello", file=f)
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
# Ensure driver subprocess gets cleaned up after job reached # Ensure driver subprocess gets cleaned up after job reached
@ -493,7 +511,7 @@ class TestAsyncAPI:
_, _, job_id = await _run_hanging_command(job_manager, tmp_dir) _, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
assert job_manager.stop_job(job_id) is True assert job_manager.stop_job(job_id) is True
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_stopped, job_manager=job_manager, job_id=job_id check_job_stopped, job_manager=job_manager, job_id=job_id
) )
# Assert re-stopping a stopped job also returns False # Assert re-stopping a stopped job also returns False
@ -520,7 +538,7 @@ class TestAsyncAPI:
actor = job_manager._get_actor_for_job(job_id) actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True) ray.kill(actor, no_restart=True)
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_failed, job_manager=job_manager, job_id=job_id check_job_failed, job_manager=job_manager, job_id=job_id
) )
@ -548,7 +566,7 @@ class TestAsyncAPI:
assert job_manager.stop_job(job_id) is True assert job_manager.stop_job(job_id) is True
# Send run signal to unblock run function # Send run signal to unblock run function
ray.get(start_signal_actor.send.remote()) ray.get(start_signal_actor.send.remote())
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_stopped, job_manager=job_manager, job_id=job_id check_job_stopped, job_manager=job_manager, job_id=job_id
) )
@ -572,7 +590,7 @@ class TestAsyncAPI:
actor = job_manager._get_actor_for_job(job_id) actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True) ray.kill(actor, no_restart=True)
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_failed, job_manager=job_manager, job_id=job_id check_job_failed, job_manager=job_manager, job_id=job_id
) )
@ -590,7 +608,7 @@ class TestAsyncAPI:
assert psutil.pid_exists(pid), "driver subprocess should be running" assert psutil.pid_exists(pid), "driver subprocess should be running"
assert job_manager.stop_job(job_id) is True assert job_manager.stop_job(job_id) is True
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_stopped, job_manager=job_manager, job_id=job_id check_job_stopped, job_manager=job_manager, job_id=job_id
) )
@ -628,7 +646,8 @@ class TestTailLogs:
# TODO(edoakes): check we get no logs before actor starts (not sure # TODO(edoakes): check we get no logs before actor starts (not sure
# how to timeout the iterator call). # how to timeout the iterator call).
assert job_manager.get_job_status(job_id) == JobStatus.PENDING job_status = await job_manager.get_job_status(job_id)
assert job_status == JobStatus.PENDING
# Signal job to start. # Signal job to start.
ray.get(start_signal_actor.send.remote()) ray.get(start_signal_actor.send.remote())
@ -645,7 +664,7 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n")) assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="") print(lines, end="")
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
@ -666,7 +685,7 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n")) assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="") print(lines, end="")
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_failed, job_manager=job_manager, job_id=job_id check_job_failed, job_manager=job_manager, job_id=job_id
) )
@ -686,7 +705,7 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n")) assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="") print(lines, end="")
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_stopped, job_manager=job_manager, job_id=job_id check_job_stopped, job_manager=job_manager, job_id=job_id
) )
@ -704,7 +723,7 @@ while True:
stream_logs_cmd = f'python -c "{stream_logs_script}"' stream_logs_cmd = f'python -c "{stream_logs_script}"'
job_id = job_manager.submit_job(entrypoint=stream_logs_cmd) job_id = await job_manager.submit_job(entrypoint=stream_logs_cmd)
await async_wait_for_condition( await async_wait_for_condition(
lambda: "STREAMED" in job_manager.get_job_logs(job_id) lambda: "STREAMED" in job_manager.get_job_logs(job_id)
) )
@ -726,9 +745,9 @@ async def test_bootstrap_address(job_manager, monkeypatch):
'python -c"' "import os;" "import ray;" "ray.init();" "print('SUCCESS!');" '"' 'python -c"' "import os;" "import ray;" "ray.init();" "print('SUCCESS!');" '"'
) )
job_id = job_manager.submit_job(entrypoint=print_ray_address_cmd) job_id = await job_manager.submit_job(entrypoint=print_ray_address_cmd)
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
assert "SUCCESS!" in job_manager.get_job_logs(job_id) assert "SUCCESS!" in job_manager.get_job_logs(job_id)
@ -751,8 +770,8 @@ async def test_job_runs_with_no_resources_available(job_manager):
# Check that the job starts up properly even with no CPUs available. # Check that the job starts up properly even with no CPUs available.
# The job won't exit until it has a CPU available because it waits for # The job won't exit until it has a CPU available because it waits for
# a task. # a task.
job_id = job_manager.submit_job(entrypoint=f"python {script_path}") job_id = await job_manager.submit_job(entrypoint=f"python {script_path}")
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_running, job_manager=job_manager, job_id=job_id check_job_running, job_manager=job_manager, job_id=job_id
) )
await async_wait_for_condition( await async_wait_for_condition(
@ -763,7 +782,7 @@ async def test_job_runs_with_no_resources_available(job_manager):
ray.get(hang_signal_actor.send.remote()) ray.get(hang_signal_actor.send.remote())
# Check the job succeeds now that resources are available. # Check the job succeeds now that resources are available.
await async_wait_for_condition( await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id check_job_succeeded, job_manager=job_manager, job_id=job_id
) )
await async_wait_for_condition( await async_wait_for_condition(

View file

@ -88,7 +88,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
self._gcs_actor_info_stub = None self._gcs_actor_info_stub = None
self._dashboard_head = dashboard_head self._dashboard_head = dashboard_head
assert _internal_kv_initialized() assert _internal_kv_initialized()
self._job_info_client = JobInfoStorageClient() self._job_info_client = None
# For offloading CPU intensive work. # For offloading CPU intensive work.
self._thread_pool = concurrent.futures.ThreadPoolExecutor( self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2, thread_name_prefix="api_head" max_workers=2, thread_name_prefix="api_head"
@ -269,11 +269,11 @@ class APIHead(dashboard_utils.DashboardHeadModule):
timestamp=datetime.now().timestamp(), timestamp=datetime.now().timestamp(),
) )
def _get_job_info(self, metadata: Dict[str, str]) -> Optional[JobInfo]: async def _get_job_info(self, metadata: Dict[str, str]) -> Optional[JobInfo]:
# If a job submission ID has been added to a job, the status is # If a job submission ID has been added to a job, the status is
# guaranteed to be returned. # guaranteed to be returned.
job_submission_id = metadata.get(JOB_ID_METADATA_KEY) job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
return self._job_info_client.get_info(job_submission_id) return await self._job_info_client.get_info(job_submission_id)
async def get_job_info(self): async def get_job_info(self):
"""Return info for each job. Here a job is a Ray driver.""" """Return info for each job. Here a job is a Ray driver."""
@ -291,7 +291,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
job_table_entry.config.runtime_env_info.serialized_runtime_env job_table_entry.config.runtime_env_info.serialized_runtime_env
), ),
} }
info = self._get_job_info(metadata) info = await self._get_job_info(metadata)
entry = { entry = {
"status": None if info is None else info.status, "status": None if info is None else info.status,
"status_message": None if info is None else info.message, "status_message": None if info is None else info.message,
@ -308,8 +308,11 @@ class APIHead(dashboard_utils.DashboardHeadModule):
"""Info for Ray job submission. Here a job can have 0 or many drivers.""" """Info for Ray job submission. Here a job can have 0 or many drivers."""
jobs = {} jobs = {}
fetched_jobs = await self._job_info_client.get_all_jobs()
for job_submission_id, job_info in self._job_info_client.get_all_jobs().items(): for (
job_submission_id,
job_info,
) in fetched_jobs.items():
if job_info is not None: if job_info is not None:
entry = { entry = {
"job_submission_id": job_submission_id, "job_submission_id": job_submission_id,
@ -428,6 +431,12 @@ class APIHead(dashboard_utils.DashboardHeadModule):
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub( self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel self._dashboard_head.aiogrpc_gcs_channel
) )
# Lazily constructed because dashboard_head's gcs_aio_client
# is lazily constructed
if not self._job_info_client:
self._job_info_client = JobInfoStorageClient(
self._dashboard_head.gcs_aio_client
)
@staticmethod @staticmethod
def is_minimal_module(): def is_minimal_module():

View file

@ -267,7 +267,7 @@ class StateHead(dashboard_utils.DashboardHeadModule, RateLimitedModule):
@RateLimitedModule.enforce_max_concurrent_calls @RateLimitedModule.enforce_max_concurrent_calls
async def list_jobs(self, req: aiohttp.web.Request) -> aiohttp.web.Response: async def list_jobs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
try: try:
result = self._state_api.list_jobs(option=self._options_from_req(req)) result = await self._state_api.list_jobs(option=self._options_from_req(req))
return self._reply( return self._reply(
success=True, success=True,
error_message="", error_message="",
@ -432,7 +432,9 @@ class StateHead(dashboard_utils.DashboardHeadModule, RateLimitedModule):
async def run(self, server): async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._state_api_data_source_client = StateDataSourceClient(gcs_channel) self._state_api_data_source_client = StateDataSourceClient(
gcs_channel, self._dashboard_head.gcs_aio_client
)
self._state_api = StateAPIManager(self._state_api_data_source_client) self._state_api = StateAPIManager(self._state_api_data_source_client)
self._log_api = LogsManager(self._state_api_data_source_client) self._log_api = LogsManager(self._state_api_data_source_client)

View file

@ -330,11 +330,11 @@ class StateAPIManager:
num_filtered=num_filtered, num_filtered=num_filtered,
) )
def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse: async def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
# TODO(sang): Support limit & timeout & async calls. # TODO(sang): Support limit & timeout & async calls.
try: try:
result = [] result = []
job_info = self._client.get_job_info() job_info = await self._client.get_job_info()
for job_id, data in job_info.items(): for job_id, data in job_info.items():
data = asdict(data) data = asdict(data)
data["job_id"] = job_id data["job_id"] = job_id

View file

@ -10,6 +10,7 @@ from grpc.aio._call import UnaryStreamCall
import ray import ray
import ray.dashboard.modules.log.log_consts as log_consts import ray.dashboard.modules.log.log_consts as log_consts
from ray._private import ray_constants from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray.core.generated import gcs_service_pb2_grpc from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_service_pb2 import ( from ray.core.generated.gcs_service_pb2 import (
GetAllActorInfoReply, GetAllActorInfoReply,
@ -138,12 +139,12 @@ class StateDataSourceClient:
- throw a ValueError if it cannot find the source. - throw a ValueError if it cannot find the source.
""" """
def __init__(self, gcs_channel: grpc.aio.Channel): def __init__(self, gcs_channel: grpc.aio.Channel, gcs_aio_client: GcsAioClient):
self.register_gcs_client(gcs_channel) self.register_gcs_client(gcs_channel)
self._raylet_stubs = {} self._raylet_stubs = {}
self._runtime_env_agent_stub = {} self._runtime_env_agent_stub = {}
self._log_agent_stub = {} self._log_agent_stub = {}
self._job_client = JobInfoStorageClient() self._job_client = JobInfoStorageClient(gcs_aio_client)
self._id_id_map = IdToIpMap() self._id_id_map = IdToIpMap()
def register_gcs_client(self, gcs_channel: grpc.aio.Channel): def register_gcs_client(self, gcs_channel: grpc.aio.Channel):
@ -256,11 +257,11 @@ class StateDataSourceClient:
) )
return reply return reply
def get_job_info(self) -> Optional[Dict[str, JobInfo]]: async def get_job_info(self) -> Optional[Dict[str, JobInfo]]:
# Cannot use @handle_grpc_network_errors because async def is not supported yet. # Cannot use @handle_grpc_network_errors because async def is not supported yet.
# TODO(sang): Support timeout & make it async # TODO(sang): Support timeout & make it async
try: try:
return self._job_client.get_all_jobs() return await self._job_client.get_all_jobs()
except grpc.aio.AioRpcError as e: except grpc.aio.AioRpcError as e:
if ( if (
e.code == grpc.StatusCode.DEADLINE_EXCEEDED e.code == grpc.StatusCode.DEADLINE_EXCEEDED

View file

@ -6,6 +6,7 @@ from typing import List, Tuple
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from ray._private.gcs_utils import GcsAioClient
import yaml import yaml
from click.testing import CliRunner from click.testing import CliRunner
@ -97,7 +98,7 @@ from ray.experimental.state.state_manager import IdToIpMap, StateDataSourceClien
from ray.job_submission import JobSubmissionClient from ray.job_submission import JobSubmissionClient
from ray.runtime_env import RuntimeEnv from ray.runtime_env import RuntimeEnv
if sys.version_info > (3, 7, 0): if sys.version_info >= (3, 8, 0):
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
else: else:
from asyncmock import AsyncMock from asyncmock import AsyncMock
@ -1096,7 +1097,8 @@ async def test_state_data_source_client(ray_start_cluster):
gcs_channel = ray._private.utils.init_grpc_channel( gcs_channel = ray._private.utils.init_grpc_channel(
cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True
) )
client = StateDataSourceClient(gcs_channel) gcs_aio_client = GcsAioClient(address=cluster.address, nums_reconnect_retry=0)
client = StateDataSourceClient(gcs_channel, gcs_aio_client)
""" """
Test actor Test actor
@ -1132,7 +1134,7 @@ async def test_state_data_source_client(ray_start_cluster):
# Entrypoint shell command to execute # Entrypoint shell command to execute
entrypoint="ls", entrypoint="ls",
) )
result = client.get_job_info() result = await client.get_job_info()
assert list(result.keys())[0] == job_id assert list(result.keys())[0] == job_id
assert isinstance(result, dict) assert isinstance(result, dict)
@ -1248,7 +1250,8 @@ async def test_state_data_source_client_limit_gcs_source(ray_start_cluster):
gcs_channel = ray._private.utils.init_grpc_channel( gcs_channel = ray._private.utils.init_grpc_channel(
cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True
) )
client = StateDataSourceClient(gcs_channel) gcs_aio_client = GcsAioClient(address=cluster.address, nums_reconnect_retry=0)
client = StateDataSourceClient(gcs_channel, gcs_aio_client)
""" """
Test actor Test actor
@ -1299,7 +1302,8 @@ async def test_state_data_source_client_limit_distributed_sources(ray_start_clus
gcs_channel = ray._private.utils.init_grpc_channel( gcs_channel = ray._private.utils.init_grpc_channel(
cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True
) )
client = StateDataSourceClient(gcs_channel) gcs_aio_client = GcsAioClient(address=cluster.address, nums_reconnect_retry=0)
client = StateDataSourceClient(gcs_channel, gcs_aio_client)
for node in ray.nodes(): for node in ray.nodes():
node_id = node["NodeID"] node_id = node["NodeID"]
ip = node["NodeManagerAddress"] ip = node["NodeManagerAddress"]