Convert job_manager to be async (#27123)

Updates the jobs API
Updates the snapshot API
Updates the state API

Increases the jobs API version to 2

Signed-off-by: Alan Guo <aguo@anyscale.com>

Why are these changes needed?
Follow-up for #25902 (comment).
9 changed files with 211 additions and 159 deletions

File 1 of 9: job submission common module (JobInfoStorageClient)

@ -1,3 +1,4 @@
import asyncio
import pickle
import time
from dataclasses import dataclass, replace
@ -6,12 +7,10 @@ from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray._private.runtime_env.packaging import parse_uri
from ray.experimental.internal_kv import (
_internal_kv_get,
_internal_kv_initialized,
_internal_kv_list,
_internal_kv_put,
)
# NOTE(edoakes): these constants should be considered a public API because
@ -97,19 +96,21 @@ class JobInfoStorageClient:
JOB_DATA_KEY_PREFIX = f"{ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX}job_info_"
JOB_DATA_KEY = f"{JOB_DATA_KEY_PREFIX}{{job_id}}"
def __init__(self):
def __init__(self, gcs_aio_client: GcsAioClient):
self._gcs_aio_client = gcs_aio_client
assert _internal_kv_initialized()
def put_info(self, job_id: str, data: JobInfo):
_internal_kv_put(
self.JOB_DATA_KEY.format(job_id=job_id),
async def put_info(self, job_id: str, data: JobInfo):
await self._gcs_aio_client.internal_kv_put(
self.JOB_DATA_KEY.format(job_id=job_id).encode(),
pickle.dumps(data),
True,
namespace=ray_constants.KV_NAMESPACE_JOB,
)
def get_info(self, job_id: str) -> Optional[JobInfo]:
pickled_info = _internal_kv_get(
self.JOB_DATA_KEY.format(job_id=job_id),
async def get_info(self, job_id: str) -> Optional[JobInfo]:
pickled_info = await self._gcs_aio_client.internal_kv_get(
self.JOB_DATA_KEY.format(job_id=job_id).encode(),
namespace=ray_constants.KV_NAMESPACE_JOB,
)
if pickled_info is None:
@ -117,10 +118,12 @@ class JobInfoStorageClient:
else:
return pickle.loads(pickled_info)
def put_status(self, job_id: str, status: JobStatus, message: Optional[str] = None):
async def put_status(
self, job_id: str, status: JobStatus, message: Optional[str] = None
):
"""Puts or updates job status. Sets end_time if status is terminal."""
old_info = self.get_info(job_id)
old_info = await self.get_info(job_id)
if old_info is not None:
if status != old_info.status and old_info.status.is_terminal():
@ -134,18 +137,18 @@ class JobInfoStorageClient:
if status.is_terminal():
new_info.end_time = int(time.time() * 1000)
self.put_info(job_id, new_info)
await self.put_info(job_id, new_info)
def get_status(self, job_id: str) -> Optional[JobStatus]:
job_info = self.get_info(job_id)
async def get_status(self, job_id: str) -> Optional[JobStatus]:
job_info = await self.get_info(job_id)
if job_info is None:
return None
else:
return job_info.status
def get_all_jobs(self) -> Dict[str, JobInfo]:
raw_job_ids_with_prefixes = _internal_kv_list(
self.JOB_DATA_KEY_PREFIX, namespace=ray_constants.KV_NAMESPACE_JOB
async def get_all_jobs(self) -> Dict[str, JobInfo]:
raw_job_ids_with_prefixes = await self._gcs_aio_client.internal_kv_keys(
self.JOB_DATA_KEY_PREFIX.encode(), namespace=ray_constants.KV_NAMESPACE_JOB
)
job_ids_with_prefixes = [
job_id.decode() for job_id in raw_job_ids_with_prefixes
@ -156,7 +159,17 @@ class JobInfoStorageClient:
self.JOB_DATA_KEY_PREFIX
), "Unexpected format for internal_kv key for Job submission"
job_ids.append(job_id_with_prefix[len(self.JOB_DATA_KEY_PREFIX) :])
return {job_id: self.get_info(job_id) for job_id in job_ids}
async def get_job_info(job_id: str):
job_info = await self.get_info(job_id)
return job_id, job_info
return {
job_id: job_info
for job_id, job_info in await asyncio.gather(
*[get_job_info(job_id) for job_id in job_ids]
)
}
def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
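
The get_all_jobs rewrite above is the one structural change in this file: rather than reading each job's record from the internal KV sequentially, it fans all of the reads out at once with asyncio.gather and zips the (job_id, job_info) pairs back into a dict. Below is a minimal, self-contained sketch of the same fan-out pattern; the fake_kv store and fetch_one helper are illustrative stand-ins for GcsAioClient:

import asyncio
from typing import Dict, Optional

# Illustrative stand-in for the GCS internal KV.
fake_kv = {"job_a": b"info_a", "job_b": b"info_b"}

async def fetch_one(job_id: str) -> tuple:
    # In the real client this would be `await gcs.internal_kv_get(...)`.
    await asyncio.sleep(0)
    return job_id, fake_kv.get(job_id)

async def fetch_all(job_ids) -> Dict[str, Optional[bytes]]:
    # Issue every read concurrently; gather preserves input order.
    pairs = await asyncio.gather(*[fetch_one(j) for j in job_ids])
    return dict(pairs)

print(asyncio.run(fetch_all(["job_a", "job_b"])))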

File 2 of 9: job dashboard head (JobHead)

@ -1,5 +1,3 @@
import asyncio
import concurrent
import dataclasses
import json
import logging
@ -54,7 +52,6 @@ class JobHead(dashboard_utils.DashboardHeadModule):
self._dashboard_head = dashboard_head
self._job_manager = None
self._gcs_job_info_stub = None
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
async def _parse_and_validate_request(
self, req: Request, request_type: dataclass
@ -95,9 +92,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
# then lets try to search for a submission with given id
submission_id = job_or_submission_id
job_info = await asyncio.get_event_loop().run_in_executor(
self._executor, lambda: self._job_manager.get_job_info(submission_id)
)
job_info = await self._job_manager.get_job_info(submission_id)
if job_info:
driver = submission_job_drivers.get(submission_id)
job = JobDetails(
@ -182,7 +177,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
request_submission_id = submit_request.submission_id or submit_request.job_id
try:
submission_id = self._job_manager.submit_job(
submission_id = await self._job_manager.submit_job(
entrypoint=submit_request.entrypoint,
submission_id=request_submission_id,
runtime_env=submit_request.runtime_env,
@ -257,10 +252,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
async def list_jobs(self, req: Request) -> Response:
driver_jobs, submission_job_drivers = await self._get_driver_jobs()
# TODO(aguo): convert _job_manager.list_jobs to an async function.
submission_jobs = await asyncio.get_event_loop().run_in_executor(
self._executor, self._job_manager.list_jobs
)
submission_jobs = await self._job_manager.list_jobs()
submission_jobs = [
JobDetails(
**dataclasses.asdict(job),
@ -386,7 +378,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
async def run(self, server):
if not self._job_manager:
self._job_manager = JobManager()
self._job_manager = JobManager(self._dashboard_head.gcs_aio_client)
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel
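
Removing the ThreadPoolExecutor is the payoff of the conversion in this module: the handlers used to push JobManager's blocking calls onto worker threads, and now they simply await the coroutine versions. A before/after sketch of that call-site change, with the handler bodies trimmed down:

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def list_jobs_before(job_manager, executor: ThreadPoolExecutor):
    # Old shape: hop to a worker thread so the blocking call
    # does not stall the dashboard's event loop.
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(executor, job_manager.list_jobs)

async def list_jobs_after(job_manager):
    # New shape: list_jobs is a coroutine, so it runs directly
    # on the loop with no thread hop.
    return await job_manager.list_jobs()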

File 3 of 9: job manager (JobSupervisor, JobManager)

@ -13,6 +13,7 @@ from collections import deque
from typing import Any, Dict, Iterator, Optional, Tuple
import ray
from ray._private.gcs_utils import GcsAioClient
import ray._private.ray_constants as ray_constants
from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR
from ray.actor import ActorHandle
@ -103,9 +104,16 @@ class JobSupervisor:
SUBPROCESS_POLL_PERIOD_S = 0.1
def __init__(self, job_id: str, entrypoint: str, user_metadata: Dict[str, str]):
def __init__(
self,
job_id: str,
entrypoint: str,
user_metadata: Dict[str, str],
gcs_address: str,
):
self._job_id = job_id
self._job_info_client = JobInfoStorageClient()
gcs_aio_client = GcsAioClient(address=gcs_address)
self._job_info_client = JobInfoStorageClient(gcs_aio_client)
self._log_client = JobLogStorageClient()
self._driver_runtime_env = self._get_driver_runtime_env()
self._entrypoint = entrypoint
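
Note that the supervisor takes a plain gcs_address string and builds its own GcsAioClient rather than receiving one: JobSupervisor is a Ray actor in a separate process, so a live client, which holds an open connection, presumably cannot be shipped across the actor boundary. The same pattern in isolation, with ConnectedClient standing in for GcsAioClient:

class ConnectedClient:
    # Stand-in for a client that opens a connection on construction
    # and therefore cannot be pickled and sent to another process.
    def __init__(self, address: str):
        self.address = address

class Supervisor:
    def __init__(self, gcs_address: str):
        # Reconstruct the client locally from the address string
        # instead of accepting a live client object as an argument.
        self._client = ConnectedClient(address=gcs_address)
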
@ -227,14 +235,14 @@ class JobSupervisor:
variables.
3) Handle concurrent events of driver execution and
"""
curr_status = self._job_info_client.get_status(self._job_id)
curr_status = await self._job_info_client.get_status(self._job_id)
assert curr_status == JobStatus.PENDING, "Run should only be called once."
if _start_signal_actor:
# Block in PENDING state until start signal received.
await _start_signal_actor.wait.remote()
self._job_info_client.put_status(self._job_id, JobStatus.RUNNING)
await self._job_info_client.put_status(self._job_id, JobStatus.RUNNING)
try:
# Configure environment variables for the child process. These
@ -257,7 +265,7 @@ class JobSupervisor:
polling_task.cancel()
# TODO (jiaodong): Improve this with SIGTERM then SIGKILL
child_process.kill()
self._job_info_client.put_status(self._job_id, JobStatus.STOPPED)
await self._job_info_client.put_status(self._job_id, JobStatus.STOPPED)
else:
# Child process finished execution and no stop event is set
# at the same time
@ -265,7 +273,9 @@ class JobSupervisor:
[child_process_task] = finished
return_code = child_process_task.result()
if return_code == 0:
self._job_info_client.put_status(self._job_id, JobStatus.SUCCEEDED)
await self._job_info_client.put_status(
self._job_id, JobStatus.SUCCEEDED
)
else:
log_tail = self._log_client.get_last_n_log_lines(self._job_id)
if log_tail is not None and log_tail != "":
@ -275,7 +285,7 @@ class JobSupervisor:
)
else:
message = None
self._job_info_client.put_status(
await self._job_info_client.put_status(
self._job_id, JobStatus.FAILED, message=message
)
except Exception:
@ -307,20 +317,22 @@ class JobManager:
LOG_TAIL_SLEEP_S = 1
JOB_MONITOR_LOOP_PERIOD_S = 1
def __init__(self):
self._job_info_client = JobInfoStorageClient()
def __init__(self, gcs_aio_client: GcsAioClient):
self._gcs_aio_client = gcs_aio_client
self._job_info_client = JobInfoStorageClient(gcs_aio_client)
self._gcs_address = gcs_aio_client._channel._gcs_address
self._log_client = JobLogStorageClient()
self._supervisor_actor_cls = ray.remote(JobSupervisor)
self._recover_running_jobs()
create_task(self._recover_running_jobs())
def _recover_running_jobs(self):
async def _recover_running_jobs(self):
"""Recovers all running jobs from the status client.
For each job, we will spawn a coroutine to monitor it.
Each will be added to self._running_jobs and reconciled.
"""
all_jobs = self._job_info_client.get_all_jobs()
all_jobs = await self._job_info_client.get_all_jobs()
for job_id, job_info in all_jobs.items():
if not job_info.status.is_terminal():
create_task(self._monitor_job(job_id))
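
_recover_running_jobs is now a coroutine, but JobManager.__init__ cannot await, so recovery is scheduled as a background task on the already-running event loop and proceeds concurrently with the constructor's caller. A condensed sketch of that kickoff; holding a reference to the task, as shown here, also keeps it from being garbage-collected mid-flight:

import asyncio

class Manager:
    def __init__(self):
        # __init__ cannot be async, so schedule recovery instead of
        # awaiting it; this requires a running event loop.
        self._recovery_task = asyncio.get_event_loop().create_task(self._recover())

    async def _recover(self):
        await asyncio.sleep(0)  # stand-in for the async KV scan
        print("recovered running jobs")

async def main():
    m = Manager()
    await m._recovery_task  # callers may await the handle if they need ordering

asyncio.run(main())
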
@ -345,7 +357,7 @@ class JobManager:
if job_supervisor is None:
logger.error(f"Failed to get job supervisor for job {job_id}.")
self._job_info_client.put_status(
await self._job_info_client.put_status(
job_id,
JobStatus.FAILED,
message="Unexpected error occurred: Failed to get job supervisor.",
@ -358,13 +370,14 @@ class JobManager:
await asyncio.sleep(self.JOB_MONITOR_LOOP_PERIOD_S)
except Exception as e:
is_alive = False
if self._job_info_client.get_status(job_id).is_terminal():
job_status = await self._job_info_client.get_status(job_id)
if job_status.is_terminal():
# If the job is already in a terminal state, then the actor
# exiting is expected.
pass
elif isinstance(e, RuntimeEnvSetupError):
logger.info(f"Failed to set up runtime_env for job {job_id}.")
self._job_info_client.put_status(
await self._job_info_client.put_status(
job_id,
JobStatus.FAILED,
message=f"runtime_env setup failed: {e}",
@ -373,7 +386,7 @@ class JobManager:
logger.warning(
f"Job supervisor for job {job_id} failed unexpectedly: {e}."
)
self._job_info_client.put_status(
await self._job_info_client.put_status(
job_id,
JobStatus.FAILED,
message=f"Unexpected error occurred: {e}",
@ -413,7 +426,6 @@ class JobManager:
def _get_supervisor_runtime_env(
self, user_runtime_env: Dict[str, Any]
) -> Dict[str, Any]:
"""Configure and return the runtime_env for the supervisor actor."""
# Make a copy to avoid mutating passed runtime_env.
@ -434,7 +446,7 @@ class JobManager:
runtime_env["env_vars"] = env_vars
return runtime_env
def submit_job(
async def submit_job(
self,
*,
entrypoint: str,
@ -473,7 +485,7 @@ class JobManager:
"""
if submission_id is None:
submission_id = generate_job_id()
elif self._job_info_client.get_status(submission_id) is not None:
elif await self._job_info_client.get_status(submission_id) is not None:
raise RuntimeError(f"Job {submission_id} already exists.")
logger.info(f"Starting job with submission_id: {submission_id}")
@ -484,7 +496,7 @@ class JobManager:
metadata=metadata,
runtime_env=runtime_env,
)
self._job_info_client.put_info(submission_id, job_info)
await self._job_info_client.put_info(submission_id, job_info)
# Wait for the actor to start up asynchronously so this call always
# returns immediately and we can catch errors with the actor starting
@ -500,14 +512,14 @@ class JobManager:
self._get_current_node_resource_key(): 0.001,
},
runtime_env=self._get_supervisor_runtime_env(runtime_env),
).remote(submission_id, entrypoint, metadata or {})
).remote(submission_id, entrypoint, metadata or {}, self._gcs_address)
supervisor.run.remote(_start_signal_actor=_start_signal_actor)
# Monitor the job in the background so we can detect errors without
# requiring a client to poll.
create_task(self._monitor_job(submission_id, job_supervisor=supervisor))
except Exception as e:
self._job_info_client.put_status(
await self._job_info_client.put_status(
submission_id,
JobStatus.FAILED,
message=f"Failed to start job supervisor: {e}.",
@ -529,17 +541,17 @@ class JobManager:
else:
return False
def get_job_status(self, job_id: str) -> Optional[JobStatus]:
async def get_job_status(self, job_id: str) -> Optional[JobStatus]:
"""Get latest status of a job."""
return self._job_info_client.get_status(job_id)
return await self._job_info_client.get_status(job_id)
def get_job_info(self, job_id: str) -> Optional[JobInfo]:
async def get_job_info(self, job_id: str) -> Optional[JobInfo]:
"""Get latest info of a job."""
return self._job_info_client.get_info(job_id)
return await self._job_info_client.get_info(job_id)
def list_jobs(self) -> Dict[str, JobInfo]:
async def list_jobs(self) -> Dict[str, JobInfo]:
"""Get info for all jobs."""
return self._job_info_client.get_all_jobs()
return await self._job_info_client.get_all_jobs()
def get_job_logs(self, job_id: str) -> str:
"""Get all logs produced by a job."""
@ -547,13 +559,13 @@ class JobManager:
async def tail_job_logs(self, job_id: str) -> Iterator[str]:
"""Return an iterator following the logs of a job."""
if self.get_job_status(job_id) is None:
if await self.get_job_status(job_id) is None:
raise RuntimeError(f"Job '{job_id}' does not exist.")
for line in self._log_client.tail_logs(job_id):
if line is None:
# Return if the job has exited and there are no new log lines.
status = self.get_job_status(job_id)
status = await self.get_job_status(job_id)
if status not in {JobStatus.PENDING, JobStatus.RUNNING}:
return
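
tail_job_logs keeps its synchronous log iterator but now awaits the status check on each idle pass, so a long-lived tail no longer blocks the event loop while it polls for new output. A condensed sketch of that loop shape; get_status and read_new_lines are hypothetical callables:

import asyncio

async def tail_logs(job_id, get_status, read_new_lines, poll_s: float = 1.0):
    # get_status: coroutine returning the job's current status string.
    # read_new_lines: sync callable returning a chunk of lines, or None.
    while True:
        lines = read_new_lines(job_id)
        if lines is None:
            # No new output: stop once the job has reached a terminal state,
            # otherwise yield the loop to other tasks and poll again.
            if await get_status(job_id) not in {"PENDING", "RUNNING"}:
                return
            await asyncio.sleep(poll_s)
        else:
            yield lines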

File 4 of 9: job manager tests

@ -10,8 +10,13 @@ import psutil
import pytest
import ray
from ray._private.gcs_utils import GcsAioClient
from ray._private.ray_constants import RAY_ADDRESS_ENVIRONMENT_VARIABLE
from ray._private.test_utils import SignalActor, async_wait_for_condition
from ray._private.test_utils import (
SignalActor,
async_wait_for_condition,
async_wait_for_condition_async_predicate,
)
from ray.dashboard.modules.job.common import JOB_ID_METADATA_KEY, JOB_NAME_METADATA_KEY
from ray.dashboard.modules.job.job_manager import JobManager, generate_job_id
from ray.job_submission import JobStatus
@ -29,8 +34,11 @@ TEST_NAMESPACE = "jobs_test_namespace"
async def test_submit_no_ray_address(call_ray_start): # noqa: F811
"""Test that a job script with an unspecified Ray address works."""
ray.init(address=call_ray_start)
job_manager = JobManager()
address_info = ray.init(address=call_ray_start)
gcs_aio_client = GcsAioClient(
address=address_info["gcs_address"], nums_reconnect_retry=0
)
job_manager = JobManager(gcs_aio_client)
init_ray_no_address_script = """
import ray
@ -46,11 +54,11 @@ assert ray.cluster_resources().get('TestResourceKey') == 123
# The job script should work even if RAY_ADDRESS is not set on the cluster.
os.environ.pop(RAY_ADDRESS_ENVIRONMENT_VARIABLE, None)
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint=f"""python -c "{init_ray_no_address_script}" """
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
@ -68,9 +76,14 @@ def shared_ray_instance():
os.environ[RAY_ADDRESS_ENVIRONMENT_VARIABLE] = old_ray_address
@pytest.mark.asyncio
@pytest.fixture
def job_manager(shared_ray_instance):
yield JobManager()
async def job_manager(shared_ray_instance):
address_info = shared_ray_instance
gcs_aio_client = GcsAioClient(
address=address_info["gcs_address"], nums_reconnect_retry=0
)
yield JobManager(gcs_aio_client)
def _driver_script_path(file_name: str) -> str:
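
The job_manager fixture becomes an async generator so that it can build the GcsAioClient and be resolved on the test's event loop by pytest-asyncio. A minimal self-contained example of the same fixture shape, assuming the pytest-asyncio plugin; the resource here is hypothetical:

import pytest

@pytest.fixture
async def resource():
    # Async fixtures may await setup work before yielding...
    value = {"ready": True}
    yield value
    # ...and await teardown after the test completes.

@pytest.mark.asyncio
async def test_uses_resource(resource):
    assert resource["ready"]
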
@ -90,11 +103,11 @@ async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
"do echo 'Waiting...' && sleep 1; "
"done"
)
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint=wait_for_file_cmd, _start_signal_actor=start_signal_actor
)
status = job_manager.get_job_status(job_id)
status = await job_manager.get_job_status(job_id)
if start_signal_actor:
for _ in range(10):
assert status == JobStatus.PENDING
@ -102,7 +115,7 @@ async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
assert logs == ""
await asyncio.sleep(0.01)
else:
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_running, job_manager=job_manager, job_id=job_id
)
await async_wait_for_condition(
@ -112,8 +125,8 @@ async def _run_hanging_command(job_manager, tmp_dir, start_signal_actor=None):
return pid_file, tmp_file, job_id
def check_job_succeeded(job_manager, job_id):
data = job_manager.get_job_info(job_id)
async def check_job_succeeded(job_manager, job_id):
data = await job_manager.get_job_info(job_id)
status = data.status
if status == JobStatus.FAILED:
raise RuntimeError(f"Job failed! {data.message}")
@ -121,20 +134,20 @@ def check_job_succeeded(job_manager, job_id):
return status == JobStatus.SUCCEEDED
def check_job_failed(job_manager, job_id):
status = job_manager.get_job_status(job_id)
async def check_job_failed(job_manager, job_id):
status = await job_manager.get_job_status(job_id)
assert status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.FAILED}
return status == JobStatus.FAILED
def check_job_stopped(job_manager, job_id):
status = job_manager.get_job_status(job_id)
async def check_job_stopped(job_manager, job_id):
status = await job_manager.get_job_status(job_id)
assert status in {JobStatus.PENDING, JobStatus.RUNNING, JobStatus.STOPPED}
return status == JobStatus.STOPPED
def check_job_running(job_manager, job_id):
status = job_manager.get_job_status(job_id)
async def check_job_running(job_manager, job_id):
status = await job_manager.get_job_status(job_id)
assert status in {JobStatus.PENDING, JobStatus.RUNNING}
return status == JobStatus.RUNNING
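
All four check_job_* predicates are coroutines now, which is why the tests switch from async_wait_for_condition to async_wait_for_condition_async_predicate: the retry loop has to await the predicate rather than call it. A plausible sketch of such a helper, assuming the real one differs in details:

import asyncio
import time

async def wait_for_async_predicate(predicate, timeout: float = 10.0,
                                   retry_s: float = 0.1, **kwargs):
    # Await the coroutine predicate until it returns truthy or time runs out.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if await predicate(**kwargs):
            return True
        await asyncio.sleep(retry_s)
    raise TimeoutError("predicate never returned True")
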
@ -158,29 +171,30 @@ def test_generate_job_id():
# NOTE(architkulkarni): This test must be run first in order for the job
# submission history of the shared Ray runtime to be empty.
def test_list_jobs_empty(job_manager: JobManager):
assert job_manager.list_jobs() == dict()
@pytest.mark.asyncio
async def test_list_jobs_empty(job_manager: JobManager):
assert await job_manager.list_jobs() == dict()
@pytest.mark.asyncio
async def test_list_jobs(job_manager: JobManager):
job_manager.submit_job(entrypoint="echo hi", submission_id="1")
await job_manager.submit_job(entrypoint="echo hi", submission_id="1")
runtime_env = {"env_vars": {"TEST": "123"}}
metadata = {"foo": "bar"}
job_manager.submit_job(
await job_manager.submit_job(
entrypoint="echo hello",
submission_id="2",
runtime_env=runtime_env,
metadata=metadata,
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id="1"
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id="2"
)
jobs_info = job_manager.list_jobs()
jobs_info = await job_manager.list_jobs()
assert "1" in jobs_info
assert jobs_info["1"].status == JobStatus.SUCCEEDED
@ -196,43 +210,45 @@ async def test_list_jobs(job_manager: JobManager):
async def test_pass_job_id(job_manager):
submission_id = "my_custom_id"
returned_id = job_manager.submit_job(
returned_id = await job_manager.submit_job(
entrypoint="echo hello", submission_id=submission_id
)
assert returned_id == submission_id
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=submission_id
)
# Check that the same job_id is rejected.
with pytest.raises(RuntimeError):
job_manager.submit_job(entrypoint="echo hello", submission_id=submission_id)
await job_manager.submit_job(
entrypoint="echo hello", submission_id=submission_id
)
@pytest.mark.asyncio
class TestShellScriptExecution:
async def test_submit_basic_echo(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo hello")
job_id = await job_manager.submit_job(entrypoint="echo hello")
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert job_manager.get_job_logs(job_id) == "hello\n"
async def test_submit_stderr(self, job_manager):
job_id = job_manager.submit_job(entrypoint="echo error 1>&2")
job_id = await job_manager.submit_job(entrypoint="echo error 1>&2")
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert job_manager.get_job_logs(job_id) == "error\n"
async def test_submit_ls_grep(self, job_manager):
grep_cmd = f"ls {os.path.dirname(__file__)} | grep test_job_manager.py"
job_id = job_manager.submit_job(entrypoint=grep_cmd)
job_id = await job_manager.submit_job(entrypoint=grep_cmd)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert job_manager.get_job_logs(job_id) == "test_job_manager.py\n"
@ -246,10 +262,10 @@ class TestShellScriptExecution:
4) Empty logs
"""
run_cmd = f"python {_driver_script_path('script_with_exception.py')}"
job_id = job_manager.submit_job(entrypoint=run_cmd)
job_id = await job_manager.submit_job(entrypoint=run_cmd)
def cleaned_up():
data = job_manager.get_job_info(job_id)
async def cleaned_up():
data = await job_manager.get_job_info(job_id)
if data.status != JobStatus.FAILED:
return False
if "Exception: Script failed with exception !" not in data.message:
@ -257,15 +273,15 @@ class TestShellScriptExecution:
return job_manager._get_actor_for_job(job_id) is None
await async_wait_for_condition(cleaned_up)
await async_wait_for_condition_async_predicate(cleaned_up)
async def test_submit_with_s3_runtime_env(self, job_manager):
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint="python script.py",
runtime_env={"working_dir": "s3://runtime-env-test/script_runtime_env.zip"},
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert (
@ -278,11 +294,11 @@ class TestShellScriptExecution:
"https://runtime-env-test.s3.amazonaws.com/script_runtime_env.zip",
filename=f.name,
)
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint="python script.py",
runtime_env={"working_dir": "file://" + filename},
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert (
@ -297,26 +313,26 @@ class TestRuntimeEnv:
"""Test we can pass env vars in the subprocess that executes job's
driver script.
"""
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint="echo $TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR",
runtime_env={"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "233"}},
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert job_manager.get_job_logs(job_id) == "233\n"
async def test_multiple_runtime_envs(self, job_manager):
# Test that you can run two jobs in different envs without conflict.
job_id_1 = job_manager.submit_job(
job_id_1 = await job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
runtime_env={
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
},
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id_1
)
logs = job_manager.get_job_logs(job_id_1)
@ -324,14 +340,14 @@ class TestRuntimeEnv:
"{'env_vars': {'TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR': 'JOB_1_VAR'}}" in logs
) # noqa: E501
job_id_2 = job_manager.submit_job(
job_id_2 = await job_manager.submit_job(
entrypoint=f"python {_driver_script_path('print_runtime_env.py')}",
runtime_env={
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_2_VAR"}
},
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id_2
)
logs = job_manager.get_job_logs(job_id_2)
@ -343,14 +359,14 @@ class TestRuntimeEnv:
"""Ensure we got error message from worker.py and job logs
if user provided runtime_env in both driver script and submit()
"""
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint=f"python {_driver_script_path('override_env_var.py')}",
runtime_env={
"env_vars": {"TEST_SUBPROCESS_JOB_CONFIG_ENV_VAR": "JOB_1_VAR"}
},
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
logs = job_manager.get_job_logs(job_id)
@ -365,11 +381,11 @@ class TestRuntimeEnv:
runtime_env.
"""
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint=run_cmd, runtime_env={"working_dir": "path_not_exist"}
)
data = job_manager.get_job_info(job_id)
data = await job_manager.get_job_info(job_id)
assert data.status == JobStatus.FAILED
assert "path_not_exist is not a valid URI" in data.message
@ -378,15 +394,15 @@ class TestRuntimeEnv:
runtime_env that fails to be set up.
"""
run_cmd = f"python {_driver_script_path('override_env_var.py')}"
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint=run_cmd, runtime_env={"working_dir": "s3://does_not_exist.zip"}
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_failed, job_manager=job_manager, job_id=job_id
)
data = job_manager.get_job_info(job_id)
data = await job_manager.get_job_info(job_id)
assert "runtime_env setup failed" in data.message
async def test_pass_metadata(self, job_manager):
@ -403,9 +419,9 @@ class TestRuntimeEnv:
)
# Check that we default to only the job ID and job name.
job_id = job_manager.submit_job(entrypoint=print_metadata_cmd)
job_id = await job_manager.submit_job(entrypoint=print_metadata_cmd)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert dict_to_str(
@ -413,11 +429,11 @@ class TestRuntimeEnv:
) in job_manager.get_job_logs(job_id)
# Check that we can pass custom metadata.
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint=print_metadata_cmd, metadata={"key1": "val1", "key2": "val2"}
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert (
@ -433,12 +449,12 @@ class TestRuntimeEnv:
)
# Check that we can override job name.
job_id = job_manager.submit_job(
job_id = await job_manager.submit_job(
entrypoint=print_metadata_cmd,
metadata={JOB_NAME_METADATA_KEY: "custom_name"},
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert dict_to_str(
@ -459,9 +475,11 @@ class TestRuntimeEnv:
"""
run_cmd = f"python {_driver_script_path('check_cuda_devices.py')}"
runtime_env = {"env_vars": env_vars}
job_id = job_manager.submit_job(entrypoint=run_cmd, runtime_env=runtime_env)
job_id = await job_manager.submit_job(
entrypoint=run_cmd, runtime_env=runtime_env
)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
@ -481,7 +499,7 @@ class TestAsyncAPI:
with open(tmp_file, "w") as f:
print("hello", file=f)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
# Ensure driver subprocess gets cleaned up after job reached
@ -493,7 +511,7 @@ class TestAsyncAPI:
_, _, job_id = await _run_hanging_command(job_manager, tmp_dir)
assert job_manager.stop_job(job_id) is True
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_stopped, job_manager=job_manager, job_id=job_id
)
# Assert re-stopping a stopped job also returns False
@ -520,7 +538,7 @@ class TestAsyncAPI:
actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_failed, job_manager=job_manager, job_id=job_id
)
@ -548,7 +566,7 @@ class TestAsyncAPI:
assert job_manager.stop_job(job_id) is True
# Send run signal to unblock run function
ray.get(start_signal_actor.send.remote())
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_stopped, job_manager=job_manager, job_id=job_id
)
@ -572,7 +590,7 @@ class TestAsyncAPI:
actor = job_manager._get_actor_for_job(job_id)
ray.kill(actor, no_restart=True)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_failed, job_manager=job_manager, job_id=job_id
)
@ -590,7 +608,7 @@ class TestAsyncAPI:
assert psutil.pid_exists(pid), "driver subprocess should be running"
assert job_manager.stop_job(job_id) is True
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_stopped, job_manager=job_manager, job_id=job_id
)
@ -628,7 +646,8 @@ class TestTailLogs:
# TODO(edoakes): check we get no logs before actor starts (not sure
# how to timeout the iterator call).
assert job_manager.get_job_status(job_id) == JobStatus.PENDING
job_status = await job_manager.get_job_status(job_id)
assert job_status == JobStatus.PENDING
# Signal job to start.
ray.get(start_signal_actor.send.remote())
@ -645,7 +664,7 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
@ -666,7 +685,7 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_failed, job_manager=job_manager, job_id=job_id
)
@ -686,7 +705,7 @@ class TestTailLogs:
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
print(lines, end="")
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_stopped, job_manager=job_manager, job_id=job_id
)
@ -704,7 +723,7 @@ while True:
stream_logs_cmd = f'python -c "{stream_logs_script}"'
job_id = job_manager.submit_job(entrypoint=stream_logs_cmd)
job_id = await job_manager.submit_job(entrypoint=stream_logs_cmd)
await async_wait_for_condition(
lambda: "STREAMED" in job_manager.get_job_logs(job_id)
)
@ -726,9 +745,9 @@ async def test_bootstrap_address(job_manager, monkeypatch):
'python -c"' "import os;" "import ray;" "ray.init();" "print('SUCCESS!');" '"'
)
job_id = job_manager.submit_job(entrypoint=print_ray_address_cmd)
job_id = await job_manager.submit_job(entrypoint=print_ray_address_cmd)
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
assert "SUCCESS!" in job_manager.get_job_logs(job_id)
@ -751,8 +770,8 @@ async def test_job_runs_with_no_resources_available(job_manager):
# Check that the job starts up properly even with no CPUs available.
# The job won't exit until it has a CPU available because it waits for
# a task.
job_id = job_manager.submit_job(entrypoint=f"python {script_path}")
await async_wait_for_condition(
job_id = await job_manager.submit_job(entrypoint=f"python {script_path}")
await async_wait_for_condition_async_predicate(
check_job_running, job_manager=job_manager, job_id=job_id
)
await async_wait_for_condition(
@ -763,7 +782,7 @@ async def test_job_runs_with_no_resources_available(job_manager):
ray.get(hang_signal_actor.send.remote())
# Check the job succeeds now that resources are available.
await async_wait_for_condition(
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)
await async_wait_for_condition(

File 5 of 9: snapshot API head (APIHead)

@ -88,7 +88,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
self._gcs_actor_info_stub = None
self._dashboard_head = dashboard_head
assert _internal_kv_initialized()
self._job_info_client = JobInfoStorageClient()
self._job_info_client = None
# For offloading CPU intensive work.
self._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2, thread_name_prefix="api_head"
@ -269,11 +269,11 @@ class APIHead(dashboard_utils.DashboardHeadModule):
timestamp=datetime.now().timestamp(),
)
def _get_job_info(self, metadata: Dict[str, str]) -> Optional[JobInfo]:
async def _get_job_info(self, metadata: Dict[str, str]) -> Optional[JobInfo]:
# If a job submission ID has been added to a job, the status is
# guaranteed to be returned.
job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
return self._job_info_client.get_info(job_submission_id)
return await self._job_info_client.get_info(job_submission_id)
async def get_job_info(self):
"""Return info for each job. Here a job is a Ray driver."""
@ -291,7 +291,7 @@ class APIHead(dashboard_utils.DashboardHeadModule):
job_table_entry.config.runtime_env_info.serialized_runtime_env
),
}
info = self._get_job_info(metadata)
info = await self._get_job_info(metadata)
entry = {
"status": None if info is None else info.status,
"status_message": None if info is None else info.message,
@ -308,8 +308,11 @@ class APIHead(dashboard_utils.DashboardHeadModule):
"""Info for Ray job submission. Here a job can have 0 or many drivers."""
jobs = {}
for job_submission_id, job_info in self._job_info_client.get_all_jobs().items():
fetched_jobs = await self._job_info_client.get_all_jobs()
for (
job_submission_id,
job_info,
) in fetched_jobs.items():
if job_info is not None:
entry = {
"job_submission_id": job_submission_id,
@ -428,6 +431,12 @@ class APIHead(dashboard_utils.DashboardHeadModule):
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel
)
# Lazily constructed because dashboard_head's gcs_aio_client
# is lazily constructed
if not self._job_info_client:
self._job_info_client = JobInfoStorageClient(
self._dashboard_head.gcs_aio_client
)
@staticmethod
def is_minimal_module():
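
APIHead can no longer construct its JobInfoStorageClient in __init__, because the dashboard head's gcs_aio_client is itself created lazily; the client is therefore built on the first run() call, as the added comment notes. The guard-on-first-use pattern in isolation, where make_client is a hypothetical factory:

class LazyHead:
    def __init__(self, make_client):
        self._make_client = make_client
        self._client = None  # dependency may not exist yet at construction time

    async def run(self):
        if self._client is None:
            # First call constructs the client; later calls reuse it.
            self._client = self._make_client()
        return self._client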

File 6 of 9: state head (StateHead)

@ -267,7 +267,7 @@ class StateHead(dashboard_utils.DashboardHeadModule, RateLimitedModule):
@RateLimitedModule.enforce_max_concurrent_calls
async def list_jobs(self, req: aiohttp.web.Request) -> aiohttp.web.Response:
try:
result = self._state_api.list_jobs(option=self._options_from_req(req))
result = await self._state_api.list_jobs(option=self._options_from_req(req))
return self._reply(
success=True,
error_message="",
@ -432,7 +432,9 @@ class StateHead(dashboard_utils.DashboardHeadModule, RateLimitedModule):
async def run(self, server):
gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._state_api_data_source_client = StateDataSourceClient(gcs_channel)
self._state_api_data_source_client = StateDataSourceClient(
gcs_channel, self._dashboard_head.gcs_aio_client
)
self._state_api = StateAPIManager(self._state_api_data_source_client)
self._log_api = LogsManager(self._state_api_data_source_client)

File 7 of 9: state API manager (StateAPIManager)

@ -330,11 +330,11 @@ class StateAPIManager:
num_filtered=num_filtered,
)
def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
async def list_jobs(self, *, option: ListApiOptions) -> ListApiResponse:
# TODO(sang): Support limit & timeout & async calls.
try:
result = []
job_info = self._client.get_job_info()
job_info = await self._client.get_job_info()
for job_id, data in job_info.items():
data = asdict(data)
data["job_id"] = job_id

File 8 of 9: state data source client (StateDataSourceClient)

@ -10,6 +10,7 @@ from grpc.aio._call import UnaryStreamCall
import ray
import ray.dashboard.modules.log.log_consts as log_consts
from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_service_pb2 import (
GetAllActorInfoReply,
@ -138,12 +139,12 @@ class StateDataSourceClient:
- throw a ValueError if it cannot find the source.
"""
def __init__(self, gcs_channel: grpc.aio.Channel):
def __init__(self, gcs_channel: grpc.aio.Channel, gcs_aio_client: GcsAioClient):
self.register_gcs_client(gcs_channel)
self._raylet_stubs = {}
self._runtime_env_agent_stub = {}
self._log_agent_stub = {}
self._job_client = JobInfoStorageClient()
self._job_client = JobInfoStorageClient(gcs_aio_client)
self._id_id_map = IdToIpMap()
def register_gcs_client(self, gcs_channel: grpc.aio.Channel):
@ -256,11 +257,11 @@ class StateDataSourceClient:
)
return reply
def get_job_info(self) -> Optional[Dict[str, JobInfo]]:
async def get_job_info(self) -> Optional[Dict[str, JobInfo]]:
# Cannot use @handle_grpc_network_errors because async def is not supported yet.
# TODO(sang): Support timeout & make it async
try:
return self._job_client.get_all_jobs()
return await self._job_client.get_all_jobs()
except grpc.aio.AioRpcError as e:
if (
e.code == grpc.StatusCode.DEADLINE_EXCEEDED
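
get_job_info keeps a hand-rolled try/except because, per the in-line comment, the existing @handle_grpc_network_errors decorator cannot wrap coroutine functions yet. A hedged sketch of what an async-aware variant could look like; this decorator is hypothetical, not Ray's API:

import functools

def handle_async_network_errors(func):
    # Wrap a coroutine function and translate transport failures into
    # a sentinel result; purely illustrative error policy.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except ConnectionError:  # stand-in for grpc.aio.AioRpcError
            return None
    return wrapper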

File 9 of 9: state API tests

@ -6,6 +6,7 @@ from typing import List, Tuple
from unittest.mock import MagicMock
import pytest
from ray._private.gcs_utils import GcsAioClient
import yaml
from click.testing import CliRunner
@ -97,7 +98,7 @@ from ray.experimental.state.state_manager import IdToIpMap, StateDataSourceClien
from ray.job_submission import JobSubmissionClient
from ray.runtime_env import RuntimeEnv
if sys.version_info > (3, 7, 0):
if sys.version_info >= (3, 8, 0):
from unittest.mock import AsyncMock
else:
from asyncmock import AsyncMock
@ -1096,7 +1097,8 @@ async def test_state_data_source_client(ray_start_cluster):
gcs_channel = ray._private.utils.init_grpc_channel(
cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True
)
client = StateDataSourceClient(gcs_channel)
gcs_aio_client = GcsAioClient(address=cluster.address, nums_reconnect_retry=0)
client = StateDataSourceClient(gcs_channel, gcs_aio_client)
"""
Test actor
@ -1132,7 +1134,7 @@ async def test_state_data_source_client(ray_start_cluster):
# Entrypoint shell command to execute
entrypoint="ls",
)
result = client.get_job_info()
result = await client.get_job_info()
assert list(result.keys())[0] == job_id
assert isinstance(result, dict)
@ -1248,7 +1250,8 @@ async def test_state_data_source_client_limit_gcs_source(ray_start_cluster):
gcs_channel = ray._private.utils.init_grpc_channel(
cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True
)
client = StateDataSourceClient(gcs_channel)
gcs_aio_client = GcsAioClient(address=cluster.address, nums_reconnect_retry=0)
client = StateDataSourceClient(gcs_channel, gcs_aio_client)
"""
Test actor
@ -1299,7 +1302,8 @@ async def test_state_data_source_client_limit_distributed_sources(ray_start_cluster):
gcs_channel = ray._private.utils.init_grpc_channel(
cluster.address, GRPC_CHANNEL_OPTIONS, asynchronous=True
)
client = StateDataSourceClient(gcs_channel)
gcs_aio_client = GcsAioClient(address=cluster.address, nums_reconnect_retry=0)
client = StateDataSourceClient(gcs_channel, gcs_aio_client)
for node in ray.nodes():
node_id = node["NodeID"]
ip = node["NodeManagerAddress"]