[Job Submission][refactor 2/N] introduce job agent (#28203)
parent a31be7cef1
commit ce70b8b96e
6 changed files with 661 additions and 135 deletions
dashboard/modules/job/job_agent.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import aiohttp
from aiohttp.web import Request, Response
import dataclasses
import json
import logging
import traceback

import ray.dashboard.optional_utils as optional_utils
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.modules.job.common import (
    JobSubmitRequest,
    JobSubmitResponse,
)
from ray.dashboard.modules.job.job_manager import JobManager
from ray.dashboard.modules.job.utils import parse_and_validate_request, find_job_by_ids


routes = optional_utils.ClassMethodRouteTable
logger = logging.getLogger(__name__)


class JobAgent(dashboard_utils.DashboardAgentModule):
    def __init__(self, dashboard_agent):
        super().__init__(dashboard_agent)
        self._job_manager = None
        self._gcs_job_info_stub = None

    @routes.post("/api/job_agent/jobs/")
    @optional_utils.init_ray_and_catch_exceptions()
    async def submit_job(self, req: Request) -> Response:
        result = await parse_and_validate_request(req, JobSubmitRequest)
        # Request parsing failed, returned with Response object.
        if isinstance(result, Response):
            return result
        else:
            submit_request = result

        request_submission_id = submit_request.submission_id or submit_request.job_id
        try:
            submission_id = await self.get_job_manager().submit_job(
                entrypoint=submit_request.entrypoint,
                submission_id=request_submission_id,
                runtime_env=submit_request.runtime_env,
                metadata=submit_request.metadata,
                _driver_on_current_node=False,
            )

            resp = JobSubmitResponse(job_id=submission_id, submission_id=submission_id)
        except (TypeError, ValueError):
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPBadRequest.status_code,
            )
        except Exception:
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPInternalServerError.status_code,
            )

        return Response(
            text=json.dumps(dataclasses.asdict(resp)),
            content_type="application/json",
            status=aiohttp.web.HTTPOk.status_code,
        )

    @routes.get("/api/job_agent/jobs/{job_or_submission_id}")
    @optional_utils.init_ray_and_catch_exceptions()
    async def get_job_info(self, req: Request) -> Response:
        job_or_submission_id = req.match_info["job_or_submission_id"]

        job = await find_job_by_ids(
            self._dashboard_agent.gcs_aio_client,
            self.get_job_manager(),
            job_or_submission_id,
        )
        if not job:
            return Response(
                text=f"Job {job_or_submission_id} does not exist",
                status=aiohttp.web.HTTPNotFound.status_code,
            )

        return Response(
            text=json.dumps(job.dict()),
            content_type="application/json",
        )

    def get_job_manager(self):
        if not self._job_manager:
            self._job_manager = JobManager(self._dashboard_agent.gcs_aio_client)
        return self._job_manager

    async def run(self, server):
        pass

    @staticmethod
    def is_minimal_module():
        return False
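For reference, the two routes above can be exercised directly over HTTP. The sketch below is illustrative and not part of the diff: the agent address (port 52365 as the default dashboard agent listen port) and a live Ray cluster are assumptions, and the payload simply mirrors the JobSubmitRequest fields consumed by submit_job.

# Illustrative sketch, assuming a running cluster and an agent on port 52365.
import asyncio

import aiohttp


async def main():
    agent = "http://127.0.0.1:52365"
    async with aiohttp.ClientSession() as session:
        # POST /api/job_agent/jobs/ with a JobSubmitRequest-shaped payload.
        async with session.post(
            agent + "/api/job_agent/jobs/", json={"entrypoint": "echo hello"}
        ) as resp:
            submission_id = (await resp.json())["submission_id"]
        # GET /api/job_agent/jobs/{id} returns the serialized JobDetails.
        async with session.get(agent + f"/api/job_agent/jobs/{submission_id}") as resp:
            print(await resp.json())


asyncio.run(main())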
dashboard/modules/job/job_head.py

@@ -2,14 +2,12 @@ import dataclasses
import json
import logging
import traceback
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import aiohttp.web
from aiohttp.web import Request, Response
from aiohttp.client import ClientResponse

import ray
from ray._private import ray_constants
import ray.dashboard.optional_utils as optional_utils
import ray.dashboard.utils as dashboard_utils
from ray._private.runtime_env.packaging import (
@@ -17,28 +15,27 @@ from ray._private.runtime_env.packaging import (
    pin_runtime_env_uri,
    upload_package_to_gcs,
)
from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
from ray.dashboard.modules.job.common import (
    http_uri_components_to_uri,
    JobStatus,
    JobSubmitRequest,
    JobSubmitResponse,
    JobStopResponse,
    JobLogsResponse,
    validate_request_type,
    JOB_ID_METADATA_KEY,
)
from ray.dashboard.modules.job.pydantic_models import (
    DriverInfo,
    JobDetails,
    JobType,
)
from ray.dashboard.modules.job.utils import (
    parse_and_validate_request,
    get_driver_jobs,
    find_job_by_ids,
)
from ray.dashboard.modules.version import (
    CURRENT_VERSION,
    VersionResponse,
)
from ray.dashboard.modules.job.job_manager import JobManager
from ray.runtime_env import RuntimeEnv

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -46,65 +43,61 @@ logger.setLevel(logging.INFO)
routes = optional_utils.ClassMethodRouteTable


+class JobAgentSubmissionClient:
+    """A local client for submitting and interacting with jobs on a specific node
+    in the remote cluster.
+    Submits requests over HTTP to the job agent on the specific node using the REST API.
+    """
+
+    def __init__(
+        self,
+        dashboard_agent_address: str,
+    ):
+        self._address = dashboard_agent_address
+        self._session = aiohttp.ClientSession()
+
+    async def _raise_error(self, r: ClientResponse):
+        status = r.status
+        error_text = await r.text()
+        raise RuntimeError(f"Request failed with status code {status}: {error_text}.")
+
+    async def submit_job_internal(self, req: JobSubmitRequest) -> JobSubmitResponse:
+
+        logger.debug(f"Submitting job with submission_id={req.submission_id}.")
+
+        async with self._session.post(
+            self._address + "/api/job_agent/jobs/", json=dataclasses.asdict(req)
+        ) as resp:
+
+            if resp.status == 200:
+                result_json = await resp.json()
+                return JobSubmitResponse(**result_json)
+            else:
+                await self._raise_error(resp)
+
+    async def get_job_info(self, job_id: str) -> JobDetails:
+        async with self._session.get(
+            self._address + f"/api/job_agent/jobs/{job_id}"
+        ) as resp:
+            if resp.status == 200:
+                result_json = await resp.json()
+                return JobDetails(**result_json)
+            else:
+                await self._raise_error(resp)
+
+    async def close(self, ignore_error=True):
+        try:
+            await self._session.close()
+        except Exception:
+            if not ignore_error:
+                raise
+
+
class JobHead(dashboard_utils.DashboardHeadModule):
    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
        self._dashboard_head = dashboard_head
        self._job_manager = None
        self._gcs_job_info_stub = None

-    async def _parse_and_validate_request(
-        self, req: Request, request_type: dataclass
-    ) -> Any:
-        """Parse request and cast to request type. If parsing failed, return a
-        Response object with status 400 and stacktrace instead.
-        """
-        try:
-            return validate_request_type(await req.json(), request_type)
-        except Exception as e:
-            logger.info(f"Got invalid request type: {e}")
-            return Response(
-                text=traceback.format_exc(),
-                status=aiohttp.web.HTTPBadRequest.status_code,
-            )
-
-    async def find_job_by_ids(self, job_or_submission_id: str) -> Optional[JobDetails]:
-        """
-        Attempts to find the job with a given submission_id or job id.
-        """
-        # First try to find by job_id
-        driver_jobs, submission_job_drivers = await self._get_driver_jobs()
-        job = driver_jobs.get(job_or_submission_id)
-        if job:
-            return job
-        # Try to find a driver with the given id
-        submission_id = next(
-            (
-                id
-                for id, driver in submission_job_drivers.items()
-                if driver.id == job_or_submission_id
-            ),
-            None,
-        )
-
-        if not submission_id:
-            # If we didn't find a driver with the given id,
-            # then lets try to search for a submission with given id
-            submission_id = job_or_submission_id
-
-        job_info = await self._job_manager.get_job_info(submission_id)
-        if job_info:
-            driver = submission_job_drivers.get(submission_id)
-            job = JobDetails(
-                **dataclasses.asdict(job_info),
-                submission_id=submission_id,
-                job_id=driver.id if driver else None,
-                driver_info=driver,
-                type=JobType.SUBMISSION,
-            )
-            return job
-
-        return None
-
    @routes.get("/api/version")
    async def get_version(self, req: Request) -> Response:
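A minimal usage sketch for the client introduced above. The agent address is an assumption here; in the tests added by this commit it is derived from DEFAULT_DASHBOARD_AGENT_LISTEN_PORT.

# Minimal sketch, assuming a reachable job agent; mirrors how the new tests
# drive JobAgentSubmissionClient.
import asyncio

from ray.dashboard.modules.job.common import JobSubmitRequest
from ray.dashboard.modules.job.job_head import JobAgentSubmissionClient


async def main():
    client = JobAgentSubmissionClient("http://127.0.0.1:52365")
    submit_resp = await client.submit_job_internal(
        JobSubmitRequest(entrypoint="echo hello")
    )
    # Fetch the job's details through the agent's info route.
    details = await client.get_job_info(submit_resp.submission_id)
    print(details.status)
    await client.close()


asyncio.run(main())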
@@ -167,7 +160,7 @@ class JobHead(dashboard_utils.DashboardHeadModule):
    @routes.post("/api/jobs/")
    @optional_utils.init_ray_and_catch_exceptions()
    async def submit_job(self, req: Request) -> Response:
-        result = await self._parse_and_validate_request(req, JobSubmitRequest)
+        result = await parse_and_validate_request(req, JobSubmitRequest)
        # Request parsing failed, returned with Response object.
        if isinstance(result, Response):
            return result
@@ -206,7 +199,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
    @optional_utils.init_ray_and_catch_exceptions()
    async def stop_job(self, req: Request) -> Response:
        job_or_submission_id = req.match_info["job_or_submission_id"]
-        job = await self.find_job_by_ids(job_or_submission_id)
+        job = await find_job_by_ids(
+            self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
+        )
        if not job:
            return Response(
                text=f"Job {job_or_submission_id} does not exist",
@@ -235,7 +230,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
    @optional_utils.init_ray_and_catch_exceptions()
    async def get_job_info(self, req: Request) -> Response:
        job_or_submission_id = req.match_info["job_or_submission_id"]
-        job = await self.find_job_by_ids(job_or_submission_id)
+        job = await find_job_by_ids(
+            self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
+        )
        if not job:
            return Response(
                text=f"Job {job_or_submission_id} does not exist",
@@ -250,7 +247,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
    @routes.get("/api/jobs/")
    @optional_utils.init_ray_and_catch_exceptions()
    async def list_jobs(self, req: Request) -> Response:
-        driver_jobs, submission_job_drivers = await self._get_driver_jobs()
+        driver_jobs, submission_job_drivers = await get_driver_jobs(
+            self._dashboard_head.gcs_aio_client
+        )

        submission_jobs = await self._job_manager.list_jobs()
        submission_jobs = [
@@ -275,67 +274,13 @@ class JobHead(dashboard_utils.DashboardHeadModule):
            content_type="application/json",
        )

-    async def _get_driver_jobs(
-        self,
-    ) -> Tuple[Dict[str, JobDetails], Dict[str, DriverInfo]]:
-        """Returns a tuple of dictionaries related to drivers.
-
-        The first dictionary contains all driver jobs and is keyed by the job's id.
-        The second dictionary contains drivers that belong to submission jobs.
-        It's keyed by the submission job's submission id.
-        Only the last driver of a submission job is returned.
-        """
-        request = gcs_service_pb2.GetAllJobInfoRequest()
-        reply = await self._gcs_job_info_stub.GetAllJobInfo(request, timeout=5)
-
-        jobs = {}
-        submission_job_drivers = {}
-        for job_table_entry in reply.job_info_list:
-            if job_table_entry.config.ray_namespace.startswith(
-                ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX
-            ):
-                # Skip jobs in any _ray_internal_ namespace
-                continue
-            job_id = job_table_entry.job_id.hex()
-            metadata = dict(job_table_entry.config.metadata)
-            job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
-            if not job_submission_id:
-                driver = DriverInfo(
-                    id=job_id,
-                    node_ip_address=job_table_entry.driver_ip_address,
-                    pid=job_table_entry.driver_pid,
-                )
-                job = JobDetails(
-                    job_id=job_id,
-                    type=JobType.DRIVER,
-                    status=JobStatus.SUCCEEDED
-                    if job_table_entry.is_dead
-                    else JobStatus.RUNNING,
-                    entrypoint="",
-                    start_time=job_table_entry.start_time,
-                    end_time=job_table_entry.end_time,
-                    metadata=metadata,
-                    runtime_env=RuntimeEnv.deserialize(
-                        job_table_entry.config.runtime_env_info.serialized_runtime_env
-                    ).to_dict(),
-                    driver_info=driver,
-                )
-                jobs[job_id] = job
-            else:
-                driver = DriverInfo(
-                    id=job_id,
-                    node_ip_address=job_table_entry.driver_ip_address,
-                    pid=job_table_entry.driver_pid,
-                )
-                submission_job_drivers[job_submission_id] = driver
-
-        return jobs, submission_job_drivers
-
    @routes.get("/api/jobs/{job_or_submission_id}/logs")
    @optional_utils.init_ray_and_catch_exceptions()
    async def get_job_logs(self, req: Request) -> Response:
        job_or_submission_id = req.match_info["job_or_submission_id"]
-        job = await self.find_job_by_ids(job_or_submission_id)
+        job = await find_job_by_ids(
+            self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
+        )
        if not job:
            return Response(
                text=f"Job {job_or_submission_id} does not exist",
@@ -357,7 +302,9 @@ class JobHead(dashboard_utils.DashboardHeadModule):
    @optional_utils.init_ray_and_catch_exceptions()
    async def tail_job_logs(self, req: Request) -> Response:
        job_or_submission_id = req.match_info["job_or_submission_id"]
-        job = await self.find_job_by_ids(job_or_submission_id)
+        job = await find_job_by_ids(
+            self._dashboard_head.gcs_aio_client, self._job_manager, job_or_submission_id
+        )
        if not job:
            return Response(
                text=f"Job {job_or_submission_id} does not exist",
@@ -380,10 +327,6 @@ class JobHead(dashboard_utils.DashboardHeadModule):
        if not self._job_manager:
            self._job_manager = JobManager(self._dashboard_head.gcs_aio_client)

-        self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
-            self._dashboard_head.aiogrpc_gcs_channel
-        )
-
    @staticmethod
    def is_minimal_module():
        return False
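The /api/jobs/ routes above continue to back the public job SDK, so head-based submission is unchanged by this commit. For contrast with the agent client, a sketch using the public SDK follows; the dashboard URL is an assumption.

# Sketch for contrast: the same submit/inspect flow through the head's
# /api/jobs/ routes via the public SDK. The dashboard URL is an assumption.
from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")
submission_id = client.submit_job(entrypoint="echo hello")
print(client.get_job_info(submission_id).status)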
dashboard/modules/job/job_manager.py

@@ -26,6 +26,7 @@ from ray.dashboard.modules.job.common import (
from ray.dashboard.modules.job.utils import file_tail_iterator
from ray.exceptions import RuntimeEnvSetupError
from ray.job_submission import JobStatus
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

logger = logging.getLogger(__name__)
@@ -469,6 +470,7 @@ class JobManager:
        runtime_env: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, str]] = None,
        _start_signal_actor: Optional[ActorHandle] = None,
+        _driver_on_current_node: bool = True,
    ) -> str:
        """
        Job execution happens asynchronously.
@@ -493,6 +495,8 @@ class JobManager:
            _start_signal_actor: Used in testing only to capture state
                transitions between PENDING -> RUNNING. Regular user shouldn't
                need this.
+            _driver_on_current_node: whether to force the driver to run on
+                the current node. Defaults to True.

        Returns:
            job_id: Generated uuid for further job management. Only valid
@@ -517,15 +521,19 @@ class JobManager:
        # returns immediately and we can catch errors with the actor starting
        # up.
        try:
+            scheduling_strategy = "DEFAULT"
+            if _driver_on_current_node:
+                # When the JobManager is created by the dashboard server on
+                # the head node, schedule job supervisor actors there as well.
+                scheduling_strategy = NodeAffinitySchedulingStrategy(
+                    node_id=ray.get_runtime_context().node_id,
+                    soft=False,
+                )
            supervisor = self._supervisor_actor_cls.options(
                lifetime="detached",
                name=self.JOB_ACTOR_NAME_TEMPLATE.format(job_id=submission_id),
                num_cpus=0,
-                # Currently we assume JobManager is created by dashboard server
-                # running on headnode, same for job supervisor actors scheduled
-                resources={
-                    self._get_current_node_resource_key(): 0.001,
-                },
+                scheduling_strategy=scheduling_strategy,
                runtime_env=self._get_supervisor_runtime_env(runtime_env),
            ).remote(submission_id, entrypoint, metadata or {}, self._gcs_address)
            supervisor.run.remote(_start_signal_actor=_start_signal_actor)
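The scheduling change above is the core of the refactor: when the supervisor actor should no longer be pinned via a fractional custom resource, the explicit NodeAffinitySchedulingStrategy takes over, and agents can opt out entirely with _driver_on_current_node=False. A standalone sketch of the same pattern follows; the toy Supervisor actor is illustrative, not from the diff.

# Standalone sketch of the strategy switch used above: pin an actor to the
# current node with NodeAffinitySchedulingStrategy, or fall back to "DEFAULT".
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()


@ray.remote(num_cpus=0)
class Supervisor:  # illustrative stand-in for the job supervisor actor
    def ping(self):
        return "ok"


pin_to_current_node = True  # plays the role of _driver_on_current_node
scheduling_strategy = "DEFAULT"
if pin_to_current_node:
    scheduling_strategy = NodeAffinitySchedulingStrategy(
        node_id=ray.get_runtime_context().node_id,
        soft=False,  # hard affinity: fail instead of running elsewhere
    )

supervisor = Supervisor.options(scheduling_strategy=scheduling_strategy).remote()
print(ray.get(supervisor.ping.remote()))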
dashboard/modules/job/tests/test_job_agent.py (new file, 318 lines)
@@ -0,0 +1,318 @@
import asyncio
import logging
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path

import pytest
from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig
import yaml

from ray._private.runtime_env.packaging import Protocol, parse_uri
from ray._private.ray_constants import DEFAULT_DASHBOARD_AGENT_LISTEN_PORT
from ray._private.test_utils import (
    chdir,
    format_web_url,
    wait_until_server_available,
)
from ray.dashboard.modules.job.common import JobSubmitRequest
from ray.dashboard.modules.job.utils import validate_request_type
from ray.dashboard.tests.conftest import *  # noqa
from ray.job_submission import JobStatus
from ray.tests.conftest import _ray_start
from ray.dashboard.modules.job.job_head import JobAgentSubmissionClient

# This test requires you have AWS credentials set up (any AWS credentials will
# do, this test only accesses a public bucket).

logger = logging.getLogger(__name__)

DRIVER_SCRIPT_DIR = os.path.join(os.path.dirname(__file__), "subprocess_driver_scripts")
EVENT_LOOP = asyncio.get_event_loop()


@pytest.fixture
def job_sdk_client():
    with _ray_start(include_dashboard=True, num_cpus=1) as ctx:
        ip, port = ctx.address_info["webui_url"].split(":")
        agent_address = f"{ip}:{DEFAULT_DASHBOARD_AGENT_LISTEN_PORT}"
        assert wait_until_server_available(agent_address)
        yield JobAgentSubmissionClient(format_web_url(agent_address))


async def _check_job(
    client: JobAgentSubmissionClient, job_id: str, status: JobStatus, timeout: int = 10
) -> bool:
    async def _check():
        result = await client.get_job_info(job_id)
        return result.status == status

    st = time.time()
    while time.time() <= timeout + st:
        res = await _check()
        if res:
            return True
        await asyncio.sleep(0.1)
    return False


@pytest.fixture(
    scope="module",
    params=[
        "no_working_dir",
        "local_working_dir",
        "s3_working_dir",
        "local_py_modules",
        "working_dir_and_local_py_modules_whl",
        "local_working_dir_zip",
        "pip_txt",
        "conda_yaml",
        "local_py_modules",
    ],
)
def runtime_env_option(request):
    import_in_task_script = """
import ray
ray.init(address="auto")

@ray.remote
def f():
    import pip_install_test

ray.get(f.remote())
"""
    if request.param == "no_working_dir":
        yield {
            "runtime_env": {},
            "entrypoint": "echo hello",
            "expected_logs": "hello\n",
        }
    elif request.param in {
        "local_working_dir",
        "local_working_dir_zip",
        "local_py_modules",
        "working_dir_and_local_py_modules_whl",
    }:
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir)

            hello_file = path / "test.py"
            with hello_file.open(mode="w") as f:
                f.write("from test_module import run_test\n")
                f.write("print(run_test())")

            module_path = path / "test_module"
            module_path.mkdir(parents=True)

            test_file = module_path / "test.py"
            with test_file.open(mode="w") as f:
                f.write("def run_test():\n")
                f.write("    return 'Hello from test_module!'\n")  # noqa: Q000

            init_file = module_path / "__init__.py"
            with init_file.open(mode="w") as f:
                f.write("from test_module.test import run_test\n")

            if request.param == "local_working_dir":
                yield {
                    "runtime_env": {"working_dir": tmp_dir},
                    "entrypoint": "python test.py",
                    "expected_logs": "Hello from test_module!\n",
                }
            elif request.param == "local_working_dir_zip":
                local_zipped_dir = shutil.make_archive(
                    os.path.join(tmp_dir, "test"), "zip", tmp_dir
                )
                yield {
                    "runtime_env": {"working_dir": local_zipped_dir},
                    "entrypoint": "python test.py",
                    "expected_logs": "Hello from test_module!\n",
                }
            elif request.param == "local_py_modules":
                yield {
                    "runtime_env": {"py_modules": [str(Path(tmp_dir) / "test_module")]},
                    "entrypoint": (
                        "python -c 'import test_module;"
                        "print(test_module.run_test())'"
                    ),
                    "expected_logs": "Hello from test_module!\n",
                }
            elif request.param == "working_dir_and_local_py_modules_whl":
                yield {
                    "runtime_env": {
                        "working_dir": "s3://runtime-env-test/script_runtime_env.zip",
                        "py_modules": [
                            Path(os.path.dirname(__file__))
                            / "pip_install_test-0.5-py3-none-any.whl"
                        ],
                    },
                    "entrypoint": (
                        "python script.py && python -c 'import pip_install_test'"
                    ),
                    "expected_logs": (
                        "Executing main() from script.py !!\n"
                        "Good job! You installed a pip module."
                    ),
                }
            else:
                raise ValueError(f"Unexpected pytest fixture option {request.param}")
    elif request.param == "s3_working_dir":
        yield {
            "runtime_env": {
                "working_dir": "s3://runtime-env-test/script_runtime_env.zip",
            },
            "entrypoint": "python script.py",
            "expected_logs": "Executing main() from script.py !!\n",
        }
    elif request.param == "pip_txt":
        with tempfile.TemporaryDirectory() as tmpdir, chdir(tmpdir):
            pip_list = ["pip-install-test==0.5"]
            relative_filepath = "requirements.txt"
            pip_file = Path(relative_filepath)
            pip_file.write_text("\n".join(pip_list))
            runtime_env = {"pip": {"packages": relative_filepath, "pip_check": False}}
            yield {
                "runtime_env": runtime_env,
                "entrypoint": (
                    f"python -c 'import pip_install_test' && "
                    f"python -c '{import_in_task_script}'"
                ),
                "expected_logs": "Good job! You installed a pip module.",
            }
    elif request.param == "conda_yaml":
        with tempfile.TemporaryDirectory() as tmpdir, chdir(tmpdir):
            conda_dict = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]}
            relative_filepath = "environment.yml"
            conda_file = Path(relative_filepath)
            conda_file.write_text(yaml.dump(conda_dict))
            runtime_env = {"conda": relative_filepath}

            yield {
                "runtime_env": runtime_env,
                "entrypoint": f"python -c '{import_in_task_script}'",
                # TODO(architkulkarni): Uncomment after #22968 is fixed.
                # "entrypoint": "python -c 'import pip_install_test'",
                "expected_logs": "Good job! You installed a pip module.",
            }
    else:
        assert False, f"Unrecognized option: {request.param}."


@pytest.mark.asyncio
async def test_submit_job(job_sdk_client, runtime_env_option, monkeypatch):
    # This flag allows for local testing of runtime env conda functionality
    # without needing a built Ray wheel. Rather than insert the link to the
    # wheel into the conda spec, it links to the current Python site.
    monkeypatch.setenv("RAY_RUNTIME_ENV_LOCAL_DEV_MODE", "1")

    client = job_sdk_client

    need_upload = False
    working_dir = runtime_env_option["runtime_env"].get("working_dir", None)
    py_modules = runtime_env_option["runtime_env"].get("py_modules", [])

    def _need_upload(path):
        try:
            protocol, _ = parse_uri(path)
            if protocol == Protocol.GCS:
                return True
        except ValueError:
            # local file, need upload
            return True
        return False

    if working_dir:
        need_upload = need_upload or _need_upload(working_dir)
    if py_modules:
        need_upload = need_upload or any(
            [_need_upload(str(py_module)) for py_module in py_modules]
        )

    # TODO(Catch-Bull): delete this after we implemented
    # `upload package` and `get package`
    if need_upload:
        # not implemented `upload package` yet.
        print("Skip test, because of need upload")
        return

    runtime_env = RuntimeEnv(**runtime_env_option["runtime_env"]).to_dict()
    request = validate_request_type(
        {"runtime_env": runtime_env, "entrypoint": runtime_env_option["entrypoint"]},
        JobSubmitRequest,
    )

    submit_result = await client.submit_job_internal(request)
    job_id = submit_result.submission_id

    check_result = await _check_job(
        client=client, job_id=job_id, status=JobStatus.SUCCEEDED, timeout=120
    )
    assert check_result

    # TODO(Catch-Bull): delete this after we implemented
    # `get_job_logs`
    # not implemented `get_job_logs` yet.
    print("Skip test, because of need get job logs")
    return
    logs = client.get_job_logs(job_id)
    assert runtime_env_option["expected_logs"] in logs


@pytest.mark.asyncio
async def test_timeout(job_sdk_client):
    client = job_sdk_client

    runtime_env = RuntimeEnv(
        pip={
            "packages": ["tensorflow", "requests", "botocore", "torch"],
            "pip_check": False,
            "pip_version": "==22.0.2;python_version=='3.8.11'",
        },
        config=RuntimeEnvConfig(setup_timeout_seconds=1),
    ).to_dict()
    request = validate_request_type(
        {"runtime_env": runtime_env, "entrypoint": "echo hello"},
        JobSubmitRequest,
    )

    submit_result = await client.submit_job_internal(request)
    job_id = submit_result.submission_id

    check_result = await _check_job(
        client=client, job_id=job_id, status=JobStatus.FAILED, timeout=10
    )
    assert check_result

    data = await client.get_job_info(job_id)
    assert "Failed to set up runtime environment" in data.message
    assert "Timeout" in data.message
    assert "consider increasing `setup_timeout_seconds`" in data.message


@pytest.mark.asyncio
async def test_runtime_env_setup_failure(job_sdk_client):
    client = job_sdk_client

    runtime_env = RuntimeEnv(working_dir="s3://does_not_exist.zip").to_dict()
    request = validate_request_type(
        {"runtime_env": runtime_env, "entrypoint": "echo hello"},
        JobSubmitRequest,
    )

    submit_result = await client.submit_job_internal(request)
    job_id = submit_result.submission_id

    check_result = await _check_job(
        client=client, job_id=job_id, status=JobStatus.FAILED, timeout=10
    )
    assert check_result

    data = await client.get_job_info(job_id)
    assert "Failed to set up runtime environment" in data.message


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))
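The file's __main__ hook already invokes pytest on itself; a single fixture variant can be selected with -k, as in the sketch below (paths assume the repo root as the working directory).

# Sketch: run only the `no_working_dir` variant of the suite. pytest's -k
# expression matches against parametrized test ids, so a fixture param works
# as a filter.
import sys

import pytest

sys.exit(
    pytest.main(
        ["-v", "-k", "no_working_dir", "dashboard/modules/job/tests/test_job_agent.py"]
    )
)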
dashboard/modules/job/utils.py

@@ -1,6 +1,40 @@
import dataclasses
import logging
import os
-from typing import Iterator, List, Optional
import traceback
from dataclasses import dataclass
+from typing import Iterator, List, Optional, Any, Dict, Tuple

try:
    # package `aiohttp` is not in ray's minimal dependencies
    import aiohttp
    from aiohttp.web import Request, Response
except Exception:
    aiohttp = None
    Request = None
    Response = None

from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray.dashboard.modules.job.common import validate_request_type

try:
    # package `pydantic` is not in ray's minimal dependencies
    from ray.dashboard.modules.job.pydantic_models import (
        DriverInfo,
        JobDetails,
        JobType,
    )
except Exception:
    DriverInfo = None
    JobDetails = None
    JobType = None

from ray.dashboard.modules.job.common import (
    JobStatus,
    JOB_ID_METADATA_KEY,
)
from ray.runtime_env import RuntimeEnv

logger = logging.getLogger(__name__)
@@ -64,3 +98,118 @@ def file_tail_iterator(path: str) -> Iterator[Optional[List[str]]]:
            lines = []
            chunk_char_count = 0
            curr_line = None


async def parse_and_validate_request(req: Request, request_type: dataclass) -> Any:
    """Parse request and cast to request type. If parsing failed, return a
    Response object with status 400 and stacktrace instead.
    """
    import aiohttp

    try:
        return validate_request_type(await req.json(), request_type)
    except Exception as e:
        logger.info(f"Got invalid request type: {e}")
        return Response(
            text=traceback.format_exc(),
            status=aiohttp.web.HTTPBadRequest.status_code,
        )


async def get_driver_jobs(
    gcs_aio_client: GcsAioClient,
) -> Tuple[Dict[str, JobDetails], Dict[str, DriverInfo]]:
    """Returns a tuple of dictionaries related to drivers.

    The first dictionary contains all driver jobs and is keyed by the job's id.
    The second dictionary contains drivers that belong to submission jobs.
    It's keyed by the submission job's submission id.
    Only the last driver of a submission job is returned.
    """
    reply = await gcs_aio_client.get_all_job_info()

    jobs = {}
    submission_job_drivers = {}
    for job_table_entry in reply.job_info_list:
        if job_table_entry.config.ray_namespace.startswith(
            ray_constants.RAY_INTERNAL_NAMESPACE_PREFIX
        ):
            # Skip jobs in any _ray_internal_ namespace
            continue
        job_id = job_table_entry.job_id.hex()
        metadata = dict(job_table_entry.config.metadata)
        job_submission_id = metadata.get(JOB_ID_METADATA_KEY)
        if not job_submission_id:
            driver = DriverInfo(
                id=job_id,
                node_ip_address=job_table_entry.driver_ip_address,
                pid=job_table_entry.driver_pid,
            )
            job = JobDetails(
                job_id=job_id,
                type=JobType.DRIVER,
                status=JobStatus.SUCCEEDED
                if job_table_entry.is_dead
                else JobStatus.RUNNING,
                entrypoint="",
                start_time=job_table_entry.start_time,
                end_time=job_table_entry.end_time,
                metadata=metadata,
                runtime_env=RuntimeEnv.deserialize(
                    job_table_entry.config.runtime_env_info.serialized_runtime_env
                ).to_dict(),
                driver_info=driver,
            )
            jobs[job_id] = job
        else:
            driver = DriverInfo(
                id=job_id,
                node_ip_address=job_table_entry.driver_ip_address,
                pid=job_table_entry.driver_pid,
            )
            submission_job_drivers[job_submission_id] = driver

    return jobs, submission_job_drivers


async def find_job_by_ids(
    gcs_aio_client: GcsAioClient,
    job_manager: "JobManager",  # noqa: F821
    job_or_submission_id: str,
) -> Optional[JobDetails]:
    """
    Attempts to find the job with a given submission_id or job id.
    """
    # First try to find by job_id
    driver_jobs, submission_job_drivers = await get_driver_jobs(gcs_aio_client)
    job = driver_jobs.get(job_or_submission_id)
    if job:
        return job
    # Try to find a driver with the given id
    submission_id = next(
        (
            id
            for id, driver in submission_job_drivers.items()
            if driver.id == job_or_submission_id
        ),
        None,
    )

    if not submission_id:
        # If we didn't find a driver with the given id,
        # then lets try to search for a submission with given id
        submission_id = job_or_submission_id

    job_info = await job_manager.get_job_info(submission_id)
    if job_info:
        driver = submission_job_drivers.get(submission_id)
        job = JobDetails(
            **dataclasses.asdict(job_info),
            submission_id=submission_id,
            job_id=driver.id if driver else None,
            driver_info=driver,
            type=JobType.SUBMISSION,
        )
        return job

    return None
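With the helpers relocated here, JobHead and JobAgent resolve ids identically: a driver job id first, then a driver belonging to a submission job, then a bare submission id. A sketch of calling them directly follows; the GCS address, the GcsAioClient construction, and the placeholder id are assumptions, since the dashboard modules normally build these objects for you.

# Sketch, assuming a running cluster reachable at `gcs_address`.
import asyncio

from ray._private.gcs_utils import GcsAioClient
from ray.dashboard.modules.job.job_manager import JobManager
from ray.dashboard.modules.job.utils import find_job_by_ids, get_driver_jobs


async def main(gcs_address: str):
    gcs_aio_client = GcsAioClient(address=gcs_address)
    job_manager = JobManager(gcs_aio_client)

    # Driver jobs keyed by job id; submission-job drivers keyed by submission id.
    driver_jobs, submission_job_drivers = await get_driver_jobs(gcs_aio_client)
    print(list(driver_jobs), list(submission_job_drivers))

    # Accepts either kind of id and returns JobDetails or None.
    job = await find_job_by_ids(gcs_aio_client, job_manager, "raysubmit_XYZ")  # placeholder id
    print(job)


asyncio.run(main("127.0.0.1:6379"))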
python/ray/_private/gcs_utils.py

@@ -415,6 +415,9 @@ class GcsAioClient:
        self._heartbeat_info_stub = gcs_service_pb2_grpc.HeartbeatInfoGcsServiceStub(
            self._channel.channel()
        )
+        self._job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
+            self._channel.channel()
+        )

    @_auto_reconnect
    async def check_alive(
@@ -521,6 +524,14 @@ class GcsAioClient:
                f"due to error {reply.status.message}"
            )

+    @_auto_reconnect
+    async def get_all_job_info(
+        self, timeout: Optional[float] = None
+    ) -> gcs_service_pb2.GetAllJobInfoReply:
+        req = gcs_service_pb2.GetAllJobInfoRequest()
+        reply = await self._job_info_stub.GetAllJobInfo(req, timeout=timeout)
+        return reply
+

def use_gcs_for_bootstrap():
    """In the current version of Ray, we always use the GCS to bootstrap.
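The new wrapper gives async callers access to the GCS job table without hand-building a gRPC stub, which is what lets get_driver_jobs() above drop the JobInfoGcsServiceStub. A minimal sketch follows; the GCS address is an assumption.

# Minimal sketch of the new RPC wrapper; assumes a running GCS at this address.
import asyncio

from ray._private.gcs_utils import GcsAioClient


async def main():
    client = GcsAioClient(address="127.0.0.1:6379")
    reply = await client.get_all_job_info(timeout=5)
    for entry in reply.job_info_list:
        # The same fields get_driver_jobs() consumes above.
        print(entry.job_id.hex(), entry.is_dead)


asyncio.run(main())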