[Job Submission][refactor 2/N] introduce job agent (#28203)

Authored by Jialing He on 2022-09-03 18:42:02 +08:00, committed by GitHub
parent a31be7cef1
commit ce70b8b96e
6 changed files with 661 additions and 135 deletions

View file

@@ -0,0 +1,97 @@
import aiohttp
from aiohttp.web import Request, Response
import dataclasses
import json
import logging
import traceback

import ray.dashboard.optional_utils as optional_utils
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.modules.job.common import (
    JobSubmitRequest,
    JobSubmitResponse,
)
from ray.dashboard.modules.job.job_manager import JobManager
from ray.dashboard.modules.job.utils import parse_and_validate_request, find_job_by_ids

routes = optional_utils.ClassMethodRouteTable

logger = logging.getLogger(__name__)


class JobAgent(dashboard_utils.DashboardAgentModule):
    def __init__(self, dashboard_agent):
        super().__init__(dashboard_agent)
        self._job_manager = None
        self._gcs_job_info_stub = None

    @routes.post("/api/job_agent/jobs/")
    @optional_utils.init_ray_and_catch_exceptions()
    async def submit_job(self, req: Request) -> Response:
        result = await parse_and_validate_request(req, JobSubmitRequest)
        # Request parsing failed, returned with Response object.
        if isinstance(result, Response):
            return result
        else:
            submit_request = result

        request_submission_id = submit_request.submission_id or submit_request.job_id

        try:
            submission_id = await self.get_job_manager().submit_job(
                entrypoint=submit_request.entrypoint,
                submission_id=request_submission_id,
                runtime_env=submit_request.runtime_env,
                metadata=submit_request.metadata,
                _driver_on_current_node=False,
            )

            resp = JobSubmitResponse(job_id=submission_id, submission_id=submission_id)
        except (TypeError, ValueError):
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPBadRequest.status_code,
            )
        except Exception:
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPInternalServerError.status_code,
            )

        return Response(
            text=json.dumps(dataclasses.asdict(resp)),
            content_type="application/json",
            status=aiohttp.web.HTTPOk.status_code,
        )

    @routes.get("/api/job_agent/jobs/{job_or_submission_id}")
    @optional_utils.init_ray_and_catch_exceptions()
    async def get_job_info(self, req: Request) -> Response:
        job_or_submission_id = req.match_info["job_or_submission_id"]

        job = await find_job_by_ids(
            self._dashboard_agent.gcs_aio_client,
            self.get_job_manager(),
            job_or_submission_id,
        )
        if not job:
            return Response(
                text=f"Job {job_or_submission_id} does not exist",
                status=aiohttp.web.HTTPNotFound.status_code,
            )

        return Response(
            text=json.dumps(job.dict()),
            content_type="application/json",
        )

    def get_job_manager(self):
        if not self._job_manager:
            self._job_manager = JobManager(self._dashboard_agent.gcs_aio_client)
        return self._job_manager

    async def run(self, server):
        pass

    @staticmethod
    def is_minimal_module():
        return False
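The new JobAgent module above registers two routes on each node's dashboard agent: POST /api/job_agent/jobs/ to submit a job (with the supervisor no longer pinned to the current node) and GET /api/job_agent/jobs/{job_or_submission_id} to fetch its details. Below is a minimal sketch of exercising those routes directly with aiohttp; the address and the agent listen port (52365 is assumed here as the default) are illustrative, and the JobAgentSubmissionClient introduced in the next file wraps the same calls.

import asyncio
import aiohttp


async def main():
    # Assumed agent address; 52365 is assumed to be the dashboard agent listen port.
    agent = "http://127.0.0.1:52365"
    async with aiohttp.ClientSession() as session:
        # The payload mirrors the JobSubmitRequest dataclass; only the
        # entrypoint field is required.
        async with session.post(
            agent + "/api/job_agent/jobs/", json={"entrypoint": "echo hello"}
        ) as resp:
            submission_id = (await resp.json())["submission_id"]
        # Fetch the job's details from the same per-node agent.
        async with session.get(agent + f"/api/job_agent/jobs/{submission_id}") as resp:
            print(await resp.json())


asyncio.run(main())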

View file

@@ -2,14 +2,12 @@ import dataclasses
import json
import logging
import traceback
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import aiohttp.web
from aiohttp.web import Request, Response
from aiohttp.client import ClientResponse
import ray
from ray._private import ray_constants
import ray.dashboard.optional_utils as optional_utils
import ray.dashboard.utils as dashboard_utils
from ray._private.runtime_env.packaging import (
@@ -17,28 +15,27 @@ from ray._private.runtime_env.packaging import (
pin_runtime_env_uri,
upload_package_to_gcs,
)
from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
from ray.dashboard.modules.job.common import (
http_uri_components_to_uri,
JobStatus,
JobSubmitRequest,
JobSubmitResponse,
JobStopResponse,
JobLogsResponse,
validate_request_type,
JOB_ID_METADATA_KEY,
)
from ray.dashboard.modules.job.pydantic_models import (
DriverInfo,
JobDetails,
JobType,
)
from ray.dashboard.modules.job.utils import (
parse_and_validate_request,
get_driver_jobs,
find_job_by_ids,
)
from ray.dashboard.modules.version import (
CURRENT_VERSION,
VersionResponse,
)
from ray.dashboard.modules.job.job_manager import JobManager
from ray.runtime_env import RuntimeEnv
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -46,65 +43,61 @@ logger.setLevel(logging.INFO)
routes = optional_utils.ClassMethodRouteTable
class JobAgentSubmissionClient:
"""A local client for submitting and interacting with jobs on a specific node
in the remote cluster.
Submits requests over HTTP to the job agent on the specific node using the REST API.
"""
def __init__(
self,
dashboard_agent_address: str,
):
self._address = dashboard_agent_address
self._session = aiohttp.ClientSession()
async def _raise_error(self, r: ClientResponse):
status = r.status
error_text = await r.text()
raise RuntimeError(f"Request failed with status code {status}: {error_text}.")
async def submit_job_internal(self, req: JobSubmitRequest) -> JobSubmitResponse:
logger.debug(f"Submitting job with submission_id={req.submission_id}.")
async with self._session.post(
self._address + "/api/job_agent/jobs/", json=dataclasses.asdict(req)
) as resp:
if resp.status == 200:
result_json = await resp.json()
return JobSubmitResponse(**result_json)
else:
await self._raise_error(resp)
async def get_job_info(self, job_id: str) -> JobDetails:
async with self._session.get(
self._address + f"/api/job_agent/jobs/{job_id}"
) as resp:
if resp.status == 200:
result_json = await resp.json()
return JobDetails(**result_json)
else:
await self._raise_error(resp)
async def close(self, ignore_error=True):
try:
await self._session.close()
except Exception:
if not ignore_error:
raise
class JobHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._dashboard_head = dashboard_head
self._job_manager = None
self._gcs_job_info_stub = None
async def _parse_and_validate_request(
self, req: Request, request_type: dataclass
) -> Any:
"""Parse request and cast to request type. If parsing failed, return a
Response object with status 400 and stacktrace instead.
"""
try:
return validate_request_type(await req.json(), request_type)
except Exception as e:
logger.info(f"Got invalid request type: {e}")
return Response(
text=traceback.format_exc(),
status=aiohttp.web.HTTPBadRequest.status_code,
)
async def find_job_by_ids(self, job_or_submission_id: str) -> Optional[JobDetails]:
"""
Attempts to find the job with a given submission_id or job id.
"""
# First try to find by job_id
driver_jobs, submission_job_drivers = await self._get_driver_jobs()
job = driver_jobs.get(job_or_submission_id)
if job:
return job
# Try to find a driver with the given id
submission_id = next(
(
id
for id, driver in submission_job_drivers.items()
if driver.id == job_or_submission_id
),
None,
)
if not submission_id:
# If we didn't find a driver with the given id,
# then let's try to search for a submission with the given id
submission_id = job_or_submission_id
job_info = await self._job_manager.get_job_info(submission_id)
if job_info:
driver = submission_job_drivers.get(submission_id)
job = JobDetails(
**dataclasses.asdict(job_info),
submission_id=submission_id,
job_id=driver.id if driver else None,
driver_info=driver,
type=JobType.SUBMISSION,
)
return job
return None
@routes.get("/api/version")
async def get_version(self, req: Request) -> Response:
@@ -167,7 +160,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
@routes.post("/api/jobs/")
@optional_utils.init_ray_and_catch_exceptions()
async def submit_job(self, req: Request) -> Response:
result = await self._parse_and_validate_request(req, JobSubmitRequest)
result = await parse_and_validate_request(req, JobSubmitRequest)
# Request parsing failed, returned with Response object.
if isinstance(result, Response):
return result
@@ -206,7 +199,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
@optional_utils.init_ray_and_catch_exceptions()
async def stop_job(self, req: Request) -> Response:
job_or_submission_id = req.match_info["job_or_submission_id"]
job = await self.find_job_by_ids(job_or_submission_id)
job = await find_job_by_ids(
self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
)
if not job:
return Response(
text=f"Job {job_or_submission_id} does not exist",
@@ -235,7 +230,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
@optional_utils.init_ray_and_catch_exceptions()
async def get_job_info(self, req: Request) -> Response:
job_or_submission_id = req.match_info["job_or_submission_id"]
job = await self.find_job_by_ids(job_or_submission_id)
job = await find_job_by_ids(
self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
)
if not job:
return Response(
text=f"Job {job_or_submission_id} does not exist",
@@ -250,7 +247,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
@routes.get("/api/jobs/")
@optional_utils.init_ray_and_catch_exceptions()
async def list_jobs(self, req: Request) -> Response:
driver_jobs, submission_job_drivers = await self._get_driver_jobs()
driver_jobs, submission_job_drivers = await get_driver_jobs(
self._dashboard_head.gcs_aio_client
)
submission_jobs = await self._job_manager.list_jobs()
submission_jobs = [
@@ -275,67 +274,13 @@ class JobHead(dashboard_utils.DashboardHeadModule):
content_type="application/json",
)
async def _get_driver_jobs(
self,
) -> Tuple[Dict[str, JobDetails], Dict[str, DriverInfo]]:
"""Returns a tuple of dictionaries related to drivers.
The first dictionary contains all driver jobs and is keyed by the job's id.
The second dictionary contains drivers that belong to submission jobs.
It's keyed by the submission job's submission id.
Only the last driver of a submission job is returned.
"""
request = gcs_service_pb2.GetAllJobInfoRequest()
reply = await self._gcs_job_info_stub.GetAllJobInfo(request, timeout=5)
jobs = {}
submission_job_drivers = {}
for job_table_entry in reply.job_info_list:
if job_table_entry.config.ray_namespace.startswith(
ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX
):
# Skip jobs in any _ray_internal_ namespace
continue
job_id = job_table_entry.job_id.hex()
metadata = dict(job_table_entry.config.metadata)
job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
if not job_submission_id:
driver = DriverInfo(
id=job_id,
node_ip_address=job_table_entry.driver_ip_address,
pid=job_table_entry.driver_pid,
)
job = JobDetails(
job_id=job_id,
type=JobType.DRIVER,
status=JobStatus.SUCCEEDED
if job_table_entry.is_dead
else JobStatus.RUNNING,
entrypoint="",
start_time=job_table_entry.start_time,
end_time=job_table_entry.end_time,
metadata=metadata,
runtime_env=RuntimeEnv.deserialize(
job_table_entry.config.runtime_env_info.serialized_runtime_env
).to_dict(),
driver_info=driver,
)
jobs[job_id] = job
else:
driver = DriverInfo(
id=job_id,
node_ip_address=job_table_entry.driver_ip_address,
pid=job_table_entry.driver_pid,
)
submission_job_drivers[job_submission_id] = driver
return jobs, submission_job_drivers
@routes.get("/api/jobs/{job_or_submission_id}/logs")
@optional_utils.init_ray_and_catch_exceptions()
async def get_job_logs(self, req: Request) -> Response:
job_or_submission_id = req.match_info["job_or_submission_id"]
job = await self.find_job_by_ids(job_or_submission_id)
job = await find_job_by_ids(
self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
)
if not job:
return Response(
text=f"Job {job_or_submission_id} does not exist",
@@ -357,7 +302,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
@optional_utils.init_ray_and_catch_exceptions()
async def tail_job_logs(self, req: Request) -> Response:
job_or_submission_id = req.match_info["job_or_submission_id"]
job = await self.find_job_by_ids(job_or_submission_id)
job = await find_job_by_ids(
self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
)
if not job:
return Response(
text=f"Job {job_or_submission_id} does not exist",
@@ -380,10 +327,6 @@ class JobHead(dashboard_utils.DashboardHeadModule):
if not self._job_manager:
self._job_manager = JobManager(self._dashboard_head.gcs_aio_client)
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
self._dashboard_head.aiogrpc_gcs_channel
)
@staticmethod
def is_minimal_module():
return False
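JobHead keeps the existing /api/jobs/ REST surface but now delegates request parsing and job lookup to the shared helpers in ray.dashboard.modules.job.utils, while the new JobAgentSubmissionClient is a thin async wrapper over the per-node agent routes shown above. A usage sketch follows, assuming the agent address of a running cluster is known:

import asyncio

from ray.dashboard.modules.job.common import JobSubmitRequest
from ray.dashboard.modules.job.job_head import JobAgentSubmissionClient


async def submit_and_wait(agent_address: str):
    # agent_address is an assumption, e.g. "http://127.0.0.1:52365".
    client = JobAgentSubmissionClient(agent_address)
    try:
        resp = await client.submit_job_internal(JobSubmitRequest(entrypoint="echo hello"))
        # get_job_info returns a JobDetails model; poll until a terminal status.
        while True:
            details = await client.get_job_info(resp.submission_id)
            if details.status.is_terminal():
                return details.status
            await asyncio.sleep(0.5)
    finally:
        await client.close()


print(asyncio.run(submit_and_wait("http://127.0.0.1:52365")))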

View file

@@ -26,6 +26,7 @@ from ray.dashboard.modules.job.common import (
from ray.dashboard.modules.job.utils import file_tail_iterator
from ray.exceptions import RuntimeEnvSetupError
from ray.job_submission import JobStatus
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
logger = logging.getLogger(__name__)
@@ -469,6 +470,7 @@ class JobManager:
runtime_env: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, str]] = None,
_start_signal_actor: Optional[ActorHandle] = None,
_driver_on_current_node: bool = True,
) -> str:
"""
Job execution happens asynchronously.
@@ -493,6 +495,8 @@
_start_signal_actor: Used in testing only to capture state
transitions between PENDING -> RUNNING. Regular user shouldn't
need this.
_driver_on_current_node: Whether to force the driver to run on the
current node. Defaults to True.
Returns:
job_id: Generated uuid for further job management. Only valid
@@ -517,15 +521,19 @@
# returns immediately and we can catch errors with the actor starting
# up.
try:
scheduling_strategy = "DEFAULT"
if _driver_on_current_node:
# If the JobManager was created by the dashboard server running on the
# head node, schedule the job supervisor actor on that same node.
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().node_id,
soft=False,
)
supervisor = self._supervisor_actor_cls.options(
lifetime="detached",
name=self.JOB_ACTOR_NAME_TEMPLATE.format(job_id=submission_id),
num_cpus=0,
# Currently we assume JobManager is created by dashboard server
# running on headnode, same for job supervisor actors scheduled
resources={
self._get_current_node_resource_key(): 0.001,
},
scheduling_strategy=scheduling_strategy,
runtime_env=self._get_supervisor_runtime_env(runtime_env),
).remote(submission_id, entrypoint, metadata or {}, self._gcs_address)
supervisor.run.remote(_start_signal_actor=_start_signal_actor)
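The new `_driver_on_current_node` flag decides where the detached JobSupervisor actor lands: the head's JobHead keeps the old behaviour by pinning it to the current node with a hard NodeAffinitySchedulingStrategy (replacing the fractional custom-resource trick removed here), while the JobAgent passes False so the default scheduler may place it anywhere. A small standalone sketch of the same pinning pattern, assuming a running Ray cluster:

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()


@ray.remote(num_cpus=0)
class Supervisor:
    def node(self):
        # Report which node the actor ended up on.
        return ray.get_runtime_context().node_id


# soft=False makes this a hard constraint, like the job supervisor's placement.
pinned = Supervisor.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(
        node_id=ray.get_runtime_context().node_id, soft=False
    )
).remote()
assert ray.get(pinned.node.remote()) == ray.get_runtime_context().node_id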

View file

@@ -0,0 +1,318 @@
import asyncio
import logging
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
import pytest
from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig
import yaml
from ray._private.runtime_env.packaging import Protocol, parse_uri
from ray._private.ray_constants import DEFAULT_DASHBOARD_AGENT_LISTEN_PORT
from ray._private.test_utils import (
chdir,
format_web_url,
wait_until_server_available,
)
from ray.dashboard.modules.job.common import JobSubmitRequest
from ray.dashboard.modules.job.utils import validate_request_type
from ray.dashboard.tests.conftest import * # noqa
from ray.job_submission import JobStatus
from ray.tests.conftest import _ray_start
from ray.dashboard.modules.job.job_head import JobAgentSubmissionClient
# This test requires AWS credentials to be set up (any AWS credentials will
# do; this test only accesses a public bucket).
logger = logging.getLogger(__name__)
DRIVER_SCRIPT_DIR = os.path.join(os.path.dirname(__file__), "subprocess_driver_scripts")
EVENT_LOOP = asyncio.get_event_loop()
@pytest.fixture
def job_sdk_client():
with _ray_start(include_dashboard=True, num_cpus=1) as ctx:
ip, port = ctx.address_info["webui_url"].split(":")
agent_address = f"{ip}:{DEFAULT_DASHBOARD_AGENT_LISTEN_PORT}"
assert wait_until_server_available(agent_address)
yield JobAgentSubmissionClient(format_web_url(agent_address))
async def _check_job(
client: JobAgentSubmissionClient, job_id: str, status: JobStatus, timeout: int = 10
) -> bool:
async def _check():
result = await client.get_job_info(job_id)
return result.status == status
st = time.time()
while time.time() <= timeout + st:
res = await _check()
if res:
return True
await asyncio.sleep(0.1)
return False
@pytest.fixture(
scope="module",
params=[
"no_working_dir",
"local_working_dir",
"s3_working_dir",
"local_py_modules",
"working_dir_and_local_py_modules_whl",
"local_working_dir_zip",
"pip_txt",
"conda_yaml",
"local_py_modules",
],
)
def runtime_env_option(request):
import_in_task_script = """
import ray
ray.init(address="auto")
@ray.remote
def f():
import pip_install_test
ray.get(f.remote())
"""
if request.param == "no_working_dir":
yield {
"runtime_env": {},
"entrypoint": "echo hello",
"expected_logs": "hello\n",
}
elif request.param in {
"local_working_dir",
"local_working_dir_zip",
"local_py_modules",
"working_dir_and_local_py_modules_whl",
}:
with tempfile.TemporaryDirectory() as tmp_dir:
path = Path(tmp_dir)
hello_file = path / "test.py"
with hello_file.open(mode="w") as f:
f.write("from test_module import run_test\n")
f.write("print(run_test())")
module_path = path / "test_module"
module_path.mkdir(parents=True)
test_file = module_path / "test.py"
with test_file.open(mode="w") as f:
f.write("def run_test():\n")
f.write(" return 'Hello from test_module!'\n") # noqa: Q000
init_file = module_path / "__init__.py"
with init_file.open(mode="w") as f:
f.write("from test_module.test import run_test\n")
if request.param == "local_working_dir":
yield {
"runtime_env": {"working_dir": tmp_dir},
"entrypoint": "python test.py",
"expected_logs": "Hello from test_module!\n",
}
elif request.param == "local_working_dir_zip":
local_zipped_dir = shutil.make_archive(
os.path.join(tmp_dir, "test"), "zip", tmp_dir
)
yield {
"runtime_env": {"working_dir": local_zipped_dir},
"entrypoint": "python test.py",
"expected_logs": "Hello from test_module!\n",
}
elif request.param == "local_py_modules":
yield {
"runtime_env": {"py_modules": [str(Path(tmp_dir) / "test_module")]},
"entrypoint": (
"python -c 'import test_module;"
"print(test_module.run_test())'"
),
"expected_logs": "Hello from test_module!\n",
}
elif request.param == "working_dir_and_local_py_modules_whl":
yield {
"runtime_env": {
"working_dir": "s3://runtime-env-test/script_runtime_env.zip",
"py_modules": [
Path(os.path.dirname(__file__))
/ "pip_install_test-0.5-py3-none-any.whl"
],
},
"entrypoint": (
"python script.py && python -c 'import pip_install_test'"
),
"expected_logs": (
"Executing main() from script.py !!\n"
"Good job! You installed a pip module."
),
}
else:
raise ValueError(f"Unexpected pytest fixture option {request.param}")
elif request.param == "s3_working_dir":
yield {
"runtime_env": {
"working_dir": "s3://runtime-env-test/script_runtime_env.zip",
},
"entrypoint": "python script.py",
"expected_logs": "Executing main() from script.py !!\n",
}
elif request.param == "pip_txt":
with tempfile.TemporaryDirectory() as tmpdir, chdir(tmpdir):
pip_list = ["pip-install-test==0.5"]
relative_filepath = "requirements.txt"
pip_file = Path(relative_filepath)
pip_file.write_text("\n".join(pip_list))
runtime_env = {"pip": {"packages": relative_filepath, "pip_check": False}}
yield {
"runtime_env": runtime_env,
"entrypoint": (
f"python -c 'import pip_install_test' && "
f"python -c '{import_in_task_script}'"
),
"expected_logs": "Good job! You installed a pip module.",
}
elif request.param == "conda_yaml":
with tempfile.TemporaryDirectory() as tmpdir, chdir(tmpdir):
conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]}
relative_filepath = "environment.yml"
conda_file = Path(relative_filepath)
conda_file.write_text(yaml.dump(conda_dict))
runtime_env = {"conda": relative_filepath}
yield {
"runtime_env": runtime_env,
"entrypoint": f"python -c '{import_in_task_script}'",
# TODO(architkulkarni): Uncomment after #22968 is fixed.
# "entrypoint": "python -c 'import pip_install_test'",
"expected_logs": "Good job! You installed a pip module.",
}
else:
assert False, f"Unrecognized option: {request.param}."
@pytest.mark.asyncio
async def test_submit_job(job_sdk_client, runtime_env_option, monkeypatch):
# This flag allows for local testing of runtime env conda functionality
# without needing a built Ray wheel. Rather than insert the link to the
# wheel into the conda spec, it links to the current Python site.
monkeypatch.setenv("RAY_RUNTIME_ENV_LOCAL_DEV_MODE", "1")
client = job_sdk_client
need_upload = False
working_dir = runtime_env_option["runtime_env"].get("working_dir", None)
py_modules = runtime_env_option["runtime_env"].get("py_modules", [])
def _need_upload(path):
try:
protocol, _ = parse_uri(path)
if protocol == Protocol.GCS:
return True
except ValueError:
# local file, need upload
return True
return False
if working_dir:
need_upload = need_upload or _need_upload(working_dir)
if py_modules:
need_upload = need_upload or any(
[_need_upload(str(py_module)) for py_module in py_modules]
)
# TODO(Catch-Bull): delete this after `upload package` and
# `get package` are implemented.
if need_upload:
# `upload package` is not implemented yet, so skip the test.
print("Skipping test because the runtime_env requires a package upload")
return
runtime_env = RuntimeEnv(**runtime_env_option["runtime_env"]).to_dict()
request = validate_request_type(
{"runtime_env": runtime_env, "entrypoint": runtime_env_option["entrypoint"]},
JobSubmitRequest,
)
submit_result = await client.submit_job_internal(request)
job_id = submit_result.submission_id
check_result = await _check_job(
client=client, job_id=job_id, status=JobStatus.SUCCEEDED, timeout=120
)
assert check_result
# TODO(Catch-Bull): delete this after `get_job_logs` is implemented.
# `get_job_logs` is not implemented yet, so skip the log check.
print("Skipping log check because `get_job_logs` is not implemented yet")
return
logs = client.get_job_logs(job_id)
assert runtime_env_option["expected_logs"] in logs
@pytest.mark.asyncio
async def test_timeout(job_sdk_client):
client = job_sdk_client
runtime_env = RuntimeEnv(
pip={
"packages": ["tensorflow", "requests", "botocore", "torch"],
"pip_check": False,
"pip_version": "==22.0.2;python_version=='3.8.11'",
},
config=RuntimeEnvConfig(setup_timeout_seconds=1),
).to_dict()
request = validate_request_type(
{"runtime_env": runtime_env, "entrypoint": "echo hello"},
JobSubmitRequest,
)
submit_result = await client.submit_job_internal(request)
job_id = submit_result.submission_id
check_result = await _check_job(
client=client, job_id=job_id, status=JobStatus.FAILED, timeout=10
)
assert check_result
data = await client.get_job_info(job_id)
assert "Failed to set up runtime environment" in data.message
assert "Timeout" in data.message
assert "consider increasing `setup_timeout_seconds`" in data.message
@pytest.mark.asyncio
async def test_runtime_env_setup_failure(job_sdk_client):
client = job_sdk_client
runtime_env = RuntimeEnv(working_dir="s3://does_not_exist.zip").to_dict()
request = validate_request_type(
{"runtime_env": runtime_env, "entrypoint": "echo hello"},
JobSubmitRequest,
)
submit_result = await client.submit_job_internal(request)
job_id = submit_result.submission_id
check_result = await _check_job(
client=client, job_id=job_id, status=JobStatus.FAILED, timeout=10
)
assert check_result
data = await client.get_job_info(job_id)
assert "Failed to set up runtime environment" in data.message
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
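The tests build requests the same way external callers would: serialize a RuntimeEnv, validate the payload into a JobSubmitRequest with validate_request_type, submit it through JobAgentSubmissionClient, and poll with the `_check_job` helper until a terminal status. A condensed sketch of just the request-building step; the pip package and timeout value are illustrative:

from ray.dashboard.modules.job.common import JobSubmitRequest
from ray.dashboard.modules.job.utils import validate_request_type
from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig

# Serialize a runtime environment the same way test_timeout does.
runtime_env = RuntimeEnv(
    pip={"packages": ["pip-install-test==0.5"], "pip_check": False},
    config=RuntimeEnvConfig(setup_timeout_seconds=60),
).to_dict()

# validate_request_type casts the raw dict into the JobSubmitRequest dataclass,
# raising if required fields are missing or of the wrong type.
request = validate_request_type(
    {"runtime_env": runtime_env, "entrypoint": "python -c 'import pip_install_test'"},
    JobSubmitRequest,
)
assert isinstance(request, JobSubmitRequest)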

View file

@@ -1,6 +1,40 @@
import dataclasses
import logging
import os
from typing import Iterator, List, Optional
import traceback
from dataclasses import dataclass
from typing import Iterator, List, Optional, Any, Dict, Tuple
try:
# package `aiohttp` is not in ray's minimal dependencies
import aiohttp
from aiohttp.web import Request, Response
except Exception:
aiohttp = None
Request = None
Response = None
from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray.dashboard.modules.job.common import validate_request_type
try:
# package `pydantic` is not in ray's minimal dependencies
from ray.dashboard.modules.job.pydantic_models import (
DriverInfo,
JobDetails,
JobType,
)
except Exception:
DriverInfo = None
JobDetails = None
JobType = None
from ray.dashboard.modules.job.common import (
JobStatus,
JOB_ID_METADATA_KEY,
)
from ray.runtime_env import RuntimeEnv
logger = logging.getLogger(__name__)
@@ -64,3 +98,118 @@ def file_tail_iterator(path: str) -> Iterator[Optional[List[str]]]:
lines = []
chunk_char_count = 0
curr_line = None
async def parse_and_validate_request(req: Request, request_type: dataclass) -> Any:
"""Parse request and cast to request type. If parsing failed, return a
Response object with status 400 and stacktrace instead.
"""
import aiohttp
try:
return validate_request_type(await req.json(), request_type)
except Exception as e:
logger.info(f"Got invalid request type: {e}")
return Response(
text=traceback.format_exc(),
status=aiohttp.web.HTTPBadRequest.status_code,
)
async def get_driver_jobs(
gcs_aio_client: GcsAioClient,
) -> Tuple[Dict[str, JobDetails], Dict[str, DriverInfo]]:
"""Returns a tuple of dictionaries related to drivers.
The first dictionary contains all driver jobs and is keyed by the job's id.
The second dictionary contains drivers that belong to submission jobs.
It's keyed by the submission job's submission id.
Only the last driver of a submission job is returned.
"""
reply = await gcs_aio_client.get_all_job_info()
jobs = {}
submission_job_drivers = {}
for job_table_entry in reply.job_info_list:
if job_table_entry.config.ray_namespace.startswith(
ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX
):
# Skip jobs in any _ray_internal_ namespace
continue
job_id = job_table_entry.job_id.hex()
metadata = dict(job_table_entry.config.metadata)
job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
if not job_submission_id:
driver = DriverInfo(
id=job_id,
node_ip_address=job_table_entry.driver_ip_address,
pid=job_table_entry.driver_pid,
)
job = JobDetails(
job_id=job_id,
type=JobType.DRIVER,
status=JobStatus.SUCCEEDED
if job_table_entry.is_dead
else JobStatus.RUNNING,
entrypoint="",
start_time=job_table_entry.start_time,
end_time=job_table_entry.end_time,
metadata=metadata,
runtime_env=RuntimeEnv.deserialize(
job_table_entry.config.runtime_env_info.serialized_runtime_env
).to_dict(),
driver_info=driver,
)
jobs[job_id] = job
else:
driver = DriverInfo(
id=job_id,
node_ip_address=job_table_entry.driver_ip_address,
pid=job_table_entry.driver_pid,
)
submission_job_drivers[job_submission_id] = driver
return jobs, submission_job_drivers
async def find_job_by_ids(
gcs_aio_client: GcsAioClient,
job_manager: "JobManager", # noqa: F821
job_or_submission_id: str,
) -> Optional[JobDetails]:
"""
Attempts to find the job with a given submission_id or job id.
"""
# First try to find by job_id
driver_jobs, submission_job_drivers = await get_driver_jobs(gcs_aio_client)
job = driver_jobs.get(job_or_submission_id)
if job:
return job
# Try to find a driver with the given id
submission_id = next(
(
id
for id, driver in submission_job_drivers.items()
if driver.id == job_or_submission_id
),
None,
)
if not submission_id:
# If we didn't find a driver with the given id,
# then let's try to search for a submission with the given id
submission_id = job_or_submission_id
job_info = await job_manager.get_job_info(submission_id)
if job_info:
driver = submission_job_drivers.get(submission_id)
job = JobDetails(
**dataclasses.asdict(job_info),
submission_id=submission_id,
job_id=driver.id if driver else None,
driver_info=driver,
type=JobType.SUBMISSION,
)
return job
return None
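find_job_by_ids resolves an identifier in three steps: as a driver job id from get_driver_jobs, then as the driver id recorded for a submission job, and finally as a submission id looked up through the JobManager; both JobHead and JobAgent now share it. A minimal sketch of calling these helpers from an async context, where the GcsAioClient and JobManager instances are assumed to already exist (as they do inside the dashboard modules):

from typing import Optional

from ray.dashboard.modules.job.job_manager import JobManager
from ray.dashboard.modules.job.pydantic_models import JobDetails
from ray.dashboard.modules.job.utils import find_job_by_ids, get_driver_jobs


async def describe_job(gcs_aio_client, job_manager: JobManager, ident: str) -> Optional[str]:
    # Driver jobs keyed by job id, plus the drivers attached to submission jobs.
    driver_jobs, submission_drivers = await get_driver_jobs(gcs_aio_client)
    print(f"{len(driver_jobs)} driver jobs, {len(submission_drivers)} submission drivers")

    # Accepts either a job id or a submission id and returns a JobDetails model.
    job: Optional[JobDetails] = await find_job_by_ids(gcs_aio_client, job_manager, ident)
    return job.json() if job else None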

View file

@@ -415,6 +415,9 @@ class GcsAioClient:
self._heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
self._channel.channel()
)
self._job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
self._channel.channel()
)
@_auto_reconnect
async def check_alive(
@@ -521,6 +524,14 @@
f"due to error {reply.status.message}"
)
@_auto_reconnect
async def get_all_job_info(
self, timeout: Optional[float] = None
) -> gcs_service_pb2.GetAllJobInfoReply:
req = gcs_service_pb2.GetAllJobInfoRequest()
reply = await self._job_info_stub.GetAllJobInfo(req, timeout=timeout)
return reply
def use_gcs_for_bootstrap():
"""In the current version of Ray, we always use the GCS to bootstrap.