From ce70b8b96ea21581a33475c4a405f1c5d99deaf5 Mon Sep 17 00:00:00 2001 From: Jialing He Date: Sat, 3 Sep 2022 18:42:02 +0800 Subject: [PATCH] [Job Submission][refactor 2/N] introduce job agent (#28203) --- dashboard/modules/job/job_agent.py | 97 ++++++ dashboard/modules/job/job_head.py | 201 ++++------- dashboard/modules/job/job_manager.py | 18 +- dashboard/modules/job/tests/test_job_agent.py | 318 ++++++++++++++++++ dashboard/modules/job/utils.py | 151 ++++++++- python/ray/_private/gcs_utils.py | 11 + 6 files changed, 661 insertions(+), 135 deletions(-) create mode 100644 dashboard/modules/job/job_agent.py create mode 100644 dashboard/modules/job/tests/test_job_agent.py diff --git a/dashboard/modules/job/job_agent.py b/dashboard/modules/job/job_agent.py new file mode 100644 index 000000000..f7e145df1 --- /dev/null +++ b/dashboard/modules/job/job_agent.py @@ -0,0 +1,97 @@ +import aiohttp +from aiohttp.web import Request, Response +import dataclasses +import json +import logging +import traceback + +import ray.dashboard.optional_utils as optional_utils +import ray.dashboard.utils as dashboard_utils +from ray.dashboard.modules.job.common import ( + JobSubmitRequest, + JobSubmitResponse, +) +from ray.dashboard.modules.job.job_manager import JobManager +from ray.dashboard.modules.job.utils import parse_and_validate_request, find_job_by_ids + + +routes = optional_utils.ClassMethodRouteTable +logger = logging.getLogger(__name__) + + +class JobAgent(dashboard_utils.DashboardAgentModule): + def __init__(self, dashboard_agent): + super().__init__(dashboard_agent) + self._job_manager = None + self._gcs_job_info_stub = None + + @routes.post("/api/job_agent/jobs/") + @optional_utils.init_ray_and_catch_exceptions() + async def submit_job(self, req: Request) -> Response: + result = await parse_and_validate_request(req, JobSubmitRequest) + # Request parsing failed, returned with Response object. 
+ if isinstance(result, Response): + return result + else: + submit_request = result + + request_submission_id = submit_request.submission_id or submit_request.job_id + try: + submission_id = await self.get_job_manager().submit_job( + entrypoint=submit_request.entrypoint, + submission_id=request_submission_id, + runtime_env=submit_request.runtime_env, + metadata=submit_request.metadata, + _driver_on_current_node=False, + ) + + resp = JobSubmitResponse(job_id=submission_id, submission_id=submission_id) + except (TypeError, ValueError): + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPBadRequest.status_code, + ) + except Exception: + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPInternalServerError.status_code, + ) + + return Response( + text=json.dumps(dataclasses.asdict(resp)), + content_type="application/json", + status=aiohttp.web.HTTPOk.status_code, + ) + + @routes.get("/api/job_agent/jobs/{job_or_submission_id}") + @optional_utils.init_ray_and_catch_exceptions() + async def get_job_info(self, req: Request) -> Response: + job_or_submission_id = req.match_info["job_or_submission_id"] + + job = await find_job_by_ids( + self._dashboard_agent.gcs_aio_client, + self.get_job_manager(), + job_or_submission_id, + ) + if not job: + return Response( + text=f"Job {job_or_submission_id} does not exist", + status=aiohttp.web.HTTPNotFound.status_code, + ) + + return Response( + text=json.dumps(job.dict()), + content_type="application/json", + ) + + def get_job_manager(self): + if not self._job_manager: + self._job_manager = JobManager(self._dashboard_agent.gcs_aio_client) + return self._job_manager + + async def run(self, server): + pass + + @staticmethod + def is_minimal_module(): + return False diff --git a/dashboard/modules/job/job_head.py b/dashboard/modules/job/job_head.py index 1793f0ea9..8aaa1aa72 100644 --- a/dashboard/modules/job/job_head.py +++ b/dashboard/modules/job/job_head.py @@ -2,14 +2,12 @@ import dataclasses import json import logging import traceback -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple import aiohttp.web from aiohttp.web import Request, Response +from aiohttp.client import ClientResponse import ray -from ray._private import ray_constants import ray.dashboard.optional_utils as optional_utils import ray.dashboard.utils as dashboard_utils from ray._private.runtime_env.packaging import ( @@ -17,28 +15,27 @@ from ray._private.runtime_env.packaging import ( pin_runtime_env_uri, upload_package_to_gcs, ) -from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc from ray.dashboard.modules.job.common import ( http_uri_components_to_uri, - JobStatus, JobSubmitRequest, JobSubmitResponse, JobStopResponse, JobLogsResponse, - validate_request_type, - JOB_ID_METADATA_KEY, ) from ray.dashboard.modules.job.pydantic_models import ( - DriverInfo, JobDetails, JobType, ) +from ray.dashboard.modules.job.utils import ( + parse_and_validate_request, + get_driver_jobs, + find_job_by_ids, +) from ray.dashboard.modules.version import ( CURRENT_VERSION, VersionResponse, ) from ray.dashboard.modules.job.job_manager import JobManager -from ray.runtime_env import RuntimeEnv logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -46,65 +43,61 @@ logger.setLevel(logging.INFO) routes = optional_utils.ClassMethodRouteTable +class JobAgentSubmissionClient: + """A local client for submitting and interacting with jobs on a specific node + in the remote cluster. 
+ Submits requests over HTTP to the job agent on the specific node using the REST API. + """ + + def __init__( + self, + dashboard_agent_address: str, + ): + self._address = dashboard_agent_address + self._session = aiohttp.ClientSession() + + async def _raise_error(self, r: ClientResponse): + status = r.status + error_text = await r.text() + raise RuntimeError(f"Request failed with status code {status}: {error_text}.") + + async def submit_job_internal(self, req: JobSubmitRequest) -> JobSubmitResponse: + + logger.debug(f"Submitting job with submission_id={req.submission_id}.") + + async with self._session.post( + self._address + "/api/job_agent/jobs/", json=dataclasses.asdict(req) + ) as resp: + + if resp.status == 200: + result_json = await resp.json() + return JobSubmitResponse(**result_json) + else: + await self._raise_error(resp) + + async def get_job_info(self, job_id: str) -> JobDetails: + async with self._session.get( + self._address + f"/api/job_agent/jobs/{job_id}" + ) as resp: + if resp.status == 200: + result_json = await resp.json() + return JobDetails(**result_json) + else: + await self._raise_error(resp) + + async def close(self, ignore_error=True): + try: + await self._session.close() + except Exception: + if not ignore_error: + raise + + class JobHead(dashboard_utils.DashboardHeadModule): def __init__(self, dashboard_head): super().__init__(dashboard_head) self._dashboard_head = dashboard_head self._job_manager = None - self._gcs_job_info_stub = None - - async def _parse_and_validate_request( - self, req: Request, request_type: dataclass - ) -> Any: - """Parse request and cast to request type. If parsing failed, return a - Response object with status 400 and stacktrace instead. - """ - try: - return validate_request_type(await req.json(), request_type) - except Exception as e: - logger.info(f"Got invalid request type: {e}") - return Response( - text=traceback.format_exc(), - status=aiohttp.web.HTTPBadRequest.status_code, - ) - - async def find_job_by_ids(self, job_or_submission_id: str) -> Optional[JobDetails]: - """ - Attempts to find the job with a given submission_id or job id. - """ - # First try to find by job_id - driver_jobs, submission_job_drivers = await self._get_driver_jobs() - job = driver_jobs.get(job_or_submission_id) - if job: - return job - # Try to find a driver with the given id - submission_id = next( - ( - id - for id, driver in submission_job_drivers.items() - if driver.id == job_or_submission_id - ), - None, - ) - - if not submission_id: - # If we didn't find a driver with the given id, - # then lets try to search for a submission with given id - submission_id = job_or_submission_id - - job_info = await self._job_manager.get_job_info(submission_id) - if job_info: - driver = submission_job_drivers.get(submission_id) - job = JobDetails( - **dataclasses.asdict(job_info), - submission_id=submission_id, - job_id=driver.id if driver else None, - driver_info=driver, - type=JobType.SUBMISSION, - ) - return job - - return None @routes.get("/api/version") async def get_version(self, req: Request) -> Response: @@ -167,7 +160,7 @@ class JobHead(dashboard_utils.DashboardHeadModule): @routes.post("/api/jobs/") @optional_utils.init_ray_and_catch_exceptions() async def submit_job(self, req: Request) -> Response: - result = await self._parse_and_validate_request(req, JobSubmitRequest) + result = await parse_and_validate_request(req, JobSubmitRequest) # Request parsing failed, returned with Response object. 
if isinstance(result, Response): return result @@ -206,7 +199,9 @@ class JobHead(dashboard_utils.DashboardHeadModule): @optional_utils.init_ray_and_catch_exceptions() async def stop_job(self, req: Request) -> Response: job_or_submission_id = req.match_info["job_or_submission_id"] - job = await self.find_job_by_ids(job_or_submission_id) + job = await find_job_by_ids( + self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id + ) if not job: return Response( text=f"Job {job_or_submission_id} does not exist", @@ -235,7 +230,9 @@ class JobHead(dashboard_utils.DashboardHeadModule): @optional_utils.init_ray_and_catch_exceptions() async def get_job_info(self, req: Request) -> Response: job_or_submission_id = req.match_info["job_or_submission_id"] - job = await self.find_job_by_ids(job_or_submission_id) + job = await find_job_by_ids( + self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id + ) if not job: return Response( text=f"Job {job_or_submission_id} does not exist", @@ -250,7 +247,9 @@ class JobHead(dashboard_utils.DashboardHeadModule): @routes.get("/api/jobs/") @optional_utils.init_ray_and_catch_exceptions() async def list_jobs(self, req: Request) -> Response: - driver_jobs, submission_job_drivers = await self._get_driver_jobs() + driver_jobs, submission_job_drivers = await get_driver_jobs( + self._dashboard_head.gcs_aio_client + ) submission_jobs = await self._job_manager.list_jobs() submission_jobs = [ @@ -275,67 +274,13 @@ class JobHead(dashboard_utils.DashboardHeadModule): content_type="application/json", ) - async def _get_driver_jobs( - self, - ) -> Tuple[Dict[str, JobDetails], Dict[str, DriverInfo]]: - """Returns a tuple of dictionaries related to drivers. - - The first dictionary contains all driver jobs and is keyed by the job's id. - The second dictionary contains drivers that belong to submission jobs. - It's keyed by the submission job's submission id. - Only the last driver of a submission job is returned. 
- """ - request = gcs_service_pb2.GetAllJobInfoRequest() - reply = await self._gcs_job_info_stub.GetAllJobInfo(request, timeout=5) - - jobs = {} - submission_job_drivers = {} - for job_table_entry in reply.job_info_list: - if job_table_entry.config.ray_namespace.startswith( - ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX - ): - # Skip jobs in any _ray_internal_ namespace - continue - job_id = job_table_entry.job_id.hex() - metadata = dict(job_table_entry.config.metadata) - job_submission_id = metadata.get(JOB_ID_METADATA_KEY) - if not job_submission_id: - driver = DriverInfo( - id=job_id, - node_ip_address=job_table_entry.driver_ip_address, - pid=job_table_entry.driver_pid, - ) - job = JobDetails( - job_id=job_id, - type=JobType.DRIVER, - status=JobStatus.SUCCEEDED - if job_table_entry.is_dead - else JobStatus.RUNNING, - entrypoint="", - start_time=job_table_entry.start_time, - end_time=job_table_entry.end_time, - metadata=metadata, - runtime_env=RuntimeEnv.deserialize( - job_table_entry.config.runtime_env_info.serialized_runtime_env - ).to_dict(), - driver_info=driver, - ) - jobs[job_id] = job - else: - driver = DriverInfo( - id=job_id, - node_ip_address=job_table_entry.driver_ip_address, - pid=job_table_entry.driver_pid, - ) - submission_job_drivers[job_submission_id] = driver - - return jobs, submission_job_drivers - @routes.get("/api/jobs/{job_or_submission_id}/logs") @optional_utils.init_ray_and_catch_exceptions() async def get_job_logs(self, req: Request) -> Response: job_or_submission_id = req.match_info["job_or_submission_id"] - job = await self.find_job_by_ids(job_or_submission_id) + job = await find_job_by_ids( + self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id + ) if not job: return Response( text=f"Job {job_or_submission_id} does not exist", @@ -357,7 +302,9 @@ class JobHead(dashboard_utils.DashboardHeadModule): @optional_utils.init_ray_and_catch_exceptions() async def tail_job_logs(self, req: Request) -> Response: job_or_submission_id = req.match_info["job_or_submission_id"] - job = await self.find_job_by_ids(job_or_submission_id) + job = await find_job_by_ids( + self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id + ) if not job: return Response( text=f"Job {job_or_submission_id} does not exist", @@ -380,10 +327,6 @@ class JobHead(dashboard_utils.DashboardHeadModule): if not self._job_manager: self._job_manager = JobManager(self._dashboard_head.gcs_aio_client) - self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub( - self._dashboard_head.aiogrpc_gcs_channel - ) - @staticmethod def is_minimal_module(): return False diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index 6530185af..22ca61e0b 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -26,6 +26,7 @@ from ray.dashboard.modules.job.common import ( from ray.dashboard.modules.job.utils import file_tail_iterator from ray.exceptions import RuntimeEnvSetupError from ray.job_submission import JobStatus +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy logger = logging.getLogger(__name__) @@ -469,6 +470,7 @@ class JobManager: runtime_env: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, str]] = None, _start_signal_actor: Optional[ActorHandle] = None, + _driver_on_current_node: bool = True, ) -> str: """ Job execution happens asynchronously. 
@@ -493,6 +495,8 @@ class JobManager: _start_signal_actor: Used in testing only to capture state transitions between PENDING -> RUNNING. Regular user shouldn't need this. + _driver_on_current_node: whether force driver run on current node, + the default value is True. Returns: job_id: Generated uuid for further job management. Only valid @@ -517,15 +521,19 @@ class JobManager: # returns immediately and we can catch errors with the actor starting # up. try: + scheduling_strategy = "DEFAULT" + if _driver_on_current_node: + # If JobManager is created by dashboard server + # running on headnode, same for job supervisor actors scheduled + scheduling_strategy = NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().node_id, + soft=False, + ) supervisor = self._supervisor_actor_cls.options( lifetime="detached", name=self.JOB_ACTOR_NAME_TEMPLATE.format(job_id=submission_id), num_cpus=0, - # Currently we assume JobManager is created by dashboard server - # running on headnode, same for job supervisor actors scheduled - resources={ - self._get_current_node_resource_key(): 0.001, - }, + scheduling_strategy=scheduling_strategy, runtime_env=self._get_supervisor_runtime_env(runtime_env), ).remote(submission_id, entrypoint, metadata or {}, self._gcs_address) supervisor.run.remote(_start_signal_actor=_start_signal_actor) diff --git a/dashboard/modules/job/tests/test_job_agent.py b/dashboard/modules/job/tests/test_job_agent.py new file mode 100644 index 000000000..1df940ceb --- /dev/null +++ b/dashboard/modules/job/tests/test_job_agent.py @@ -0,0 +1,318 @@ +import asyncio +import logging +import os +import shutil +import sys +import tempfile +import time +from pathlib import Path + +import pytest +from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig +import yaml + +from ray._private.runtime_env.packaging import Protocol, parse_uri +from ray._private.ray_constants import DEFAULT_DASHBOARD_AGENT_LISTEN_PORT +from ray._private.test_utils import ( + chdir, + format_web_url, + wait_until_server_available, +) +from ray.dashboard.modules.job.common import JobSubmitRequest +from ray.dashboard.modules.job.utils import validate_request_type +from ray.dashboard.tests.conftest import * # noqa +from ray.job_submission import JobStatus +from ray.tests.conftest import _ray_start +from ray.dashboard.modules.job.job_head import JobAgentSubmissionClient + +# This test requires you have AWS credentials set up (any AWS credentials will +# do, this test only accesses a public bucket). 
+ +logger = logging.getLogger(__name__) + +DRIVER_SCRIPT_DIR = os.path.join(os.path.dirname(__file__), "subprocess_driver_scripts") +EVENT_LOOP = asyncio.get_event_loop() + + +@pytest.fixture +def job_sdk_client(): + with _ray_start(include_dashboard=True, num_cpus=1) as ctx: + ip, port = ctx.address_info["webui_url"].split(":") + agent_address = f"{ip}:{DEFAULT_DASHBOARD_AGENT_LISTEN_PORT}" + assert wait_until_server_available(agent_address) + yield JobAgentSubmissionClient(format_web_url(agent_address)) + + +async def _check_job( + client: JobAgentSubmissionClient, job_id: str, status: JobStatus, timeout: int = 10 +) -> bool: + async def _check(): + result = await client.get_job_info(job_id) + return result.status == status + + st = time.time() + while time.time() <= timeout + st: + res = await _check() + if res: + return True + await asyncio.sleep(0.1) + return False + + +@pytest.fixture( + scope="module", + params=[ + "no_working_dir", + "local_working_dir", + "s3_working_dir", + "local_py_modules", + "working_dir_and_local_py_modules_whl", + "local_working_dir_zip", + "pip_txt", + "conda_yaml", + "local_py_modules", + ], +) +def runtime_env_option(request): + import_in_task_script = """ +import ray +ray.init(address="auto") + +@ray.remote +def f(): + import pip_install_test + +ray.get(f.remote()) +""" + if request.param == "no_working_dir": + yield { + "runtime_env": {}, + "entrypoint": "echo hello", + "expected_logs": "hello\n", + } + elif request.param in { + "local_working_dir", + "local_working_dir_zip", + "local_py_modules", + "working_dir_and_local_py_modules_whl", + }: + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) + + hello_file = path / "test.py" + with hello_file.open(mode="w") as f: + f.write("from test_module import run_test\n") + f.write("print(run_test())") + + module_path = path / "test_module" + module_path.mkdir(parents=True) + + test_file = module_path / "test.py" + with test_file.open(mode="w") as f: + f.write("def run_test():\n") + f.write(" return 'Hello from test_module!'\n") # noqa: Q000 + + init_file = module_path / "__init__.py" + with init_file.open(mode="w") as f: + f.write("from test_module.test import run_test\n") + + if request.param == "local_working_dir": + yield { + "runtime_env": {"working_dir": tmp_dir}, + "entrypoint": "python test.py", + "expected_logs": "Hello from test_module!\n", + } + elif request.param == "local_working_dir_zip": + local_zipped_dir = shutil.make_archive( + os.path.join(tmp_dir, "test"), "zip", tmp_dir + ) + yield { + "runtime_env": {"working_dir": local_zipped_dir}, + "entrypoint": "python test.py", + "expected_logs": "Hello from test_module!\n", + } + elif request.param == "local_py_modules": + yield { + "runtime_env": {"py_modules": [str(Path(tmp_dir) / "test_module")]}, + "entrypoint": ( + "python -c 'import test_module;" + "print(test_module.run_test())'" + ), + "expected_logs": "Hello from test_module!\n", + } + elif request.param == "working_dir_and_local_py_modules_whl": + yield { + "runtime_env": { + "working_dir": "s3://runtime-env-test/script_runtime_env.zip", + "py_modules": [ + Path(os.path.dirname(__file__)) + / "pip_install_test-0.5-py3-none-any.whl" + ], + }, + "entrypoint": ( + "python script.py && python -c 'import pip_install_test'" + ), + "expected_logs": ( + "Executing main() from script.py !!\n" + "Good job! You installed a pip module." 
+ ), + } + else: + raise ValueError(f"Unexpected pytest fixture option {request.param}") + elif request.param == "s3_working_dir": + yield { + "runtime_env": { + "working_dir": "s3://runtime-env-test/script_runtime_env.zip", + }, + "entrypoint": "python script.py", + "expected_logs": "Executing main() from script.py !!\n", + } + elif request.param == "pip_txt": + with tempfile.TemporaryDirectory() as tmpdir, chdir(tmpdir): + pip_list = ["pip-install-test==0.5"] + relative_filepath = "requirements.txt" + pip_file = Path(relative_filepath) + pip_file.write_text("\n".join(pip_list)) + runtime_env = {"pip": {"packages": relative_filepath, "pip_check": False}} + yield { + "runtime_env": runtime_env, + "entrypoint": ( + f"python -c 'import pip_install_test' && " + f"python -c '{import_in_task_script}'" + ), + "expected_logs": "Good job! You installed a pip module.", + } + elif request.param == "conda_yaml": + with tempfile.TemporaryDirectory() as tmpdir, chdir(tmpdir): + conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]} + relative_filepath = "environment.yml" + conda_file = Path(relative_filepath) + conda_file.write_text(yaml.dump(conda_dict)) + runtime_env = {"conda": relative_filepath} + + yield { + "runtime_env": runtime_env, + "entrypoint": f"python -c '{import_in_task_script}'", + # TODO(architkulkarni): Uncomment after #22968 is fixed. + # "entrypoint": "python -c 'import pip_install_test'", + "expected_logs": "Good job! You installed a pip module.", + } + else: + assert False, f"Unrecognized option: {request.param}." + + +@pytest.mark.asyncio +async def test_submit_job(job_sdk_client, runtime_env_option, monkeypatch): + # This flag allows for local testing of runtime env conda functionality + # without needing a built Ray wheel. Rather than insert the link to the + # wheel into the conda spec, it links to the current Python site. + monkeypatch.setenv("RAY_RUNTIME_ENV_LOCAL_DEV_MODE", "1") + + client = job_sdk_client + + need_upload = False + working_dir = runtime_env_option["runtime_env"].get("working_dir", None) + py_modules = runtime_env_option["runtime_env"].get("py_modules", []) + + def _need_upload(path): + try: + protocol, _ = parse_uri(path) + if protocol == Protocol.GCS: + return True + except ValueError: + # local file, need upload + return True + return False + + if working_dir: + need_upload = need_upload or _need_upload(working_dir) + if py_modules: + need_upload = need_upload or any( + [_need_upload(str(py_module)) for py_module in py_modules] + ) + + # TODO(Catch-Bull): delete this after we implemented + # `upload package` and `get package` + if need_upload: + # not implemented `upload package` yet. + print("Skip test, because of need upload") + return + + runtime_env = RuntimeEnv(**runtime_env_option["runtime_env"]).to_dict() + request = validate_request_type( + {"runtime_env": runtime_env, "entrypoint": runtime_env_option["entrypoint"]}, + JobSubmitRequest, + ) + + submit_result = await client.submit_job_internal(request) + job_id = submit_result.submission_id + + check_result = await _check_job( + client=client, job_id=job_id, status=JobStatus.SUCCEEDED, timeout=120 + ) + assert check_result + + # TODO(Catch-Bull): delete this after we implemented + # `get_job_logs` + # not implemented `get_job_logs` yet. 
+ print("Skip test, because of need get job logs") + return + logs = client.get_job_logs(job_id) + assert runtime_env_option["expected_logs"] in logs + + +@pytest.mark.asyncio +async def test_timeout(job_sdk_client): + client = job_sdk_client + + runtime_env = RuntimeEnv( + pip={ + "packages": ["tensorflow", "requests", "botocore", "torch"], + "pip_check": False, + "pip_version": "==22.0.2;python_version=='3.8.11'", + }, + config=RuntimeEnvConfig(setup_timeout_seconds=1), + ).to_dict() + request = validate_request_type( + {"runtime_env": runtime_env, "entrypoint": "echo hello"}, + JobSubmitRequest, + ) + + submit_result = await client.submit_job_internal(request) + job_id = submit_result.submission_id + + check_result = await _check_job( + client=client, job_id=job_id, status=JobStatus.FAILED, timeout=10 + ) + assert check_result + + data = await client.get_job_info(job_id) + assert "Failed to set up runtime environment" in data.message + assert "Timeout" in data.message + assert "consider increasing `setup_timeout_seconds`" in data.message + + +@pytest.mark.asyncio +async def test_runtime_env_setup_failure(job_sdk_client): + client = job_sdk_client + + runtime_env = RuntimeEnv(working_dir="s3://does_not_exist.zip").to_dict() + request = validate_request_type( + {"runtime_env": runtime_env, "entrypoint": "echo hello"}, + JobSubmitRequest, + ) + + submit_result = await client.submit_job_internal(request) + job_id = submit_result.submission_id + + check_result = await _check_job( + client=client, job_id=job_id, status=JobStatus.FAILED, timeout=10 + ) + assert check_result + + data = await client.get_job_info(job_id) + assert "Failed to set up runtime environment" in data.message + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/dashboard/modules/job/utils.py b/dashboard/modules/job/utils.py index 6da2a6062..dc438508e 100644 --- a/dashboard/modules/job/utils.py +++ b/dashboard/modules/job/utils.py @@ -1,6 +1,40 @@ +import dataclasses import logging import os -from typing import Iterator, List, Optional +import traceback +from dataclasses import dataclass +from typing import Iterator, List, Optional, Any, Dict, Tuple + +try: + # package `aiohttp` is not in ray's minimal dependencies + import aiohttp + from aiohttp.web import Request, Response +except Exception: + aiohttp = None + Request = None + Response = None + +from ray._private import ray_constants +from ray._private.gcs_utils import GcsAioClient +from ray.dashboard.modules.job.common import validate_request_type + +try: + # package `pydantic` is not in ray's minimal dependencies + from ray.dashboard.modules.job.pydantic_models import ( + DriverInfo, + JobDetails, + JobType, + ) +except Exception: + DriverInfo = None + JobDetails = None + JobType = None + +from ray.dashboard.modules.job.common import ( + JobStatus, + JOB_ID_METADATA_KEY, +) +from ray.runtime_env import RuntimeEnv logger = logging.getLogger(__name__) @@ -64,3 +98,118 @@ def file_tail_iterator(path: str) -> Iterator[Optional[List[str]]]: lines = [] chunk_char_count = 0 curr_line = None + + +async def parse_and_validate_request(req: Request, request_type: dataclass) -> Any: + """Parse request and cast to request type. If parsing failed, return a + Response object with status 400 and stacktrace instead. 
+ """ + import aiohttp + + try: + return validate_request_type(await req.json(), request_type) + except Exception as e: + logger.info(f"Got invalid request type: {e}") + return Response( + text=traceback.format_exc(), + status=aiohttp.web.HTTPBadRequest.status_code, + ) + + +async def get_driver_jobs( + gcs_aio_client: GcsAioClient, +) -> Tuple[Dict[str, JobDetails], Dict[str, DriverInfo]]: + """Returns a tuple of dictionaries related to drivers. + + The first dictionary contains all driver jobs and is keyed by the job's id. + The second dictionary contains drivers that belong to submission jobs. + It's keyed by the submission job's submission id. + Only the last driver of a submission job is returned. + """ + reply = await gcs_aio_client.get_all_job_info() + + jobs = {} + submission_job_drivers = {} + for job_table_entry in reply.job_info_list: + if job_table_entry.config.ray_namespace.startswith( + ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX + ): + # Skip jobs in any _ray_internal_ namespace + continue + job_id = job_table_entry.job_id.hex() + metadata = dict(job_table_entry.config.metadata) + job_submission_id = metadata.get(JOB_ID_METADATA_KEY) + if not job_submission_id: + driver = DriverInfo( + id=job_id, + node_ip_address=job_table_entry.driver_ip_address, + pid=job_table_entry.driver_pid, + ) + job = JobDetails( + job_id=job_id, + type=JobType.DRIVER, + status=JobStatus.SUCCEEDED + if job_table_entry.is_dead + else JobStatus.RUNNING, + entrypoint="", + start_time=job_table_entry.start_time, + end_time=job_table_entry.end_time, + metadata=metadata, + runtime_env=RuntimeEnv.deserialize( + job_table_entry.config.runtime_env_info.serialized_runtime_env + ).to_dict(), + driver_info=driver, + ) + jobs[job_id] = job + else: + driver = DriverInfo( + id=job_id, + node_ip_address=job_table_entry.driver_ip_address, + pid=job_table_entry.driver_pid, + ) + submission_job_drivers[job_submission_id] = driver + + return jobs, submission_job_drivers + + +async def find_job_by_ids( + gcs_aio_client: GcsAioClient, + job_manager: "JobManager", # noqa: F821 + job_or_submission_id: str, +) -> Optional[JobDetails]: + """ + Attempts to find the job with a given submission_id or job id. 
+ """ + # First try to find by job_id + driver_jobs, submission_job_drivers = await get_driver_jobs(gcs_aio_client) + job = driver_jobs.get(job_or_submission_id) + if job: + return job + # Try to find a driver with the given id + submission_id = next( + ( + id + for id, driver in submission_job_drivers.items() + if driver.id == job_or_submission_id + ), + None, + ) + + if not submission_id: + # If we didn't find a driver with the given id, + # then lets try to search for a submission with given id + submission_id = job_or_submission_id + + job_info = await job_manager.get_job_info(submission_id) + if job_info: + driver = submission_job_drivers.get(submission_id) + job = JobDetails( + **dataclasses.asdict(job_info), + submission_id=submission_id, + job_id=driver.id if driver else None, + driver_info=driver, + type=JobType.SUBMISSION, + ) + return job + + return None diff --git a/python/ray/_private/gcs_utils.py b/python/ray/_private/gcs_utils.py index afe0a0e12..e16b3c178 100644 --- a/python/ray/_private/gcs_utils.py +++ b/python/ray/_private/gcs_utils.py @@ -415,6 +415,9 @@ class GcsAioClient: self._heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub( self._channel.channel() ) + self._job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub( + self._channel.channel() + ) @_auto_reconnect async def check_alive( @@ -521,6 +524,14 @@ class GcsAioClient: f"due to error {reply.status.message}" ) + @_auto_reconnect + async def get_all_job_info( + self, timeout: Optional[float] = None + ) -> gcs_service_pb2.GetAllJobInfoReply: + req = gcs_service_pb2.GetAllJobInfoRequest() + reply = await self._job_info_stub.GetAllJobInfo(req, timeout=timeout) + return reply + def use_gcs_for_bootstrap(): """In the current version of Ray, we always use the GCS to bootstrap.
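
A minimal usage sketch of the agent-side submission path introduced above, written against the new `JobAgentSubmissionClient` and mirroring the flow in test_job_agent.py. The agent address/port and the `echo hello` entrypoint are illustrative assumptions (the tests derive the address from DEFAULT_DASHBOARD_AGENT_LISTEN_PORT on the local node); this snippet is not part of the diff itself.

import asyncio

from ray.dashboard.modules.job.common import JobSubmitRequest
from ray.dashboard.modules.job.job_head import JobAgentSubmissionClient
from ray.dashboard.modules.job.utils import validate_request_type


async def main():
    # Assumed dashboard-agent HTTP address of the target node.
    client = JobAgentSubmissionClient("http://127.0.0.1:52365")
    try:
        # Build and validate the request body the same way the tests do.
        request = validate_request_type(
            {"entrypoint": "echo hello", "runtime_env": {}}, JobSubmitRequest
        )
        submit_result = await client.submit_job_internal(request)
        # Query the new GET /api/job_agent/jobs/{id} route for job details.
        info = await client.get_job_info(submit_result.submission_id)
        print(info.status)
    finally:
        await client.close()


asyncio.run(main())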
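
A short sketch (helper name hypothetical, not part of this diff) of the supervisor-actor placement decision that replaces the old head-node custom-resource pinning in JobManager.submit_job; the agent-side path passes `_driver_on_current_node=False` so the driver is no longer forced onto the node that created the JobManager.

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy


def pick_supervisor_scheduling_strategy(driver_on_current_node: bool):
    # When the caller does not require the driver on this node (the job
    # agent's case), let Ray place the supervisor actor on any feasible node.
    if not driver_on_current_node:
        return "DEFAULT"
    # Otherwise pin the supervisor to the node running this JobManager with a
    # hard (soft=False) node-affinity constraint, matching the dashboard head.
    return NodeAffinitySchedulingStrategy(
        node_id=ray.get_runtime_context().node_id, soft=False
    )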