mirror of
https://github.com/vale981/ray
synced 2025-03-09 12:56:46 -04:00
161 lines
4.7 KiB
Python
161 lines
4.7 KiB
Python
import dataclasses
|
|
import logging
|
|
from typing import Any, Dict, Iterator, Optional
|
|
|
|
try:
|
|
import aiohttp
|
|
import requests
|
|
except ImportError:
|
|
aiohttp = None
|
|
requests = None
|
|
|
|
from ray.dashboard.modules.job.common import (
|
|
JobStatus,
|
|
JobSubmitRequest,
|
|
JobSubmitResponse,
|
|
JobStopResponse,
|
|
JobInfo,
|
|
JobLogsResponse,
|
|
)
|
|
from ray.dashboard.modules.dashboard_sdk import SubmissionClient
|
|
|
|
from ray.runtime_env import RuntimeEnv
|
|
|
|
from ray.util.annotations import PublicAPI
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
class JobSubmissionClient(SubmissionClient):
|
|
def __init__(
|
|
self,
|
|
address: str,
|
|
create_cluster_if_needed=False,
|
|
cookies: Optional[Dict[str, Any]] = None,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
headers: Optional[Dict[str, Any]] = None,
|
|
):
|
|
if requests is None:
|
|
raise RuntimeError(
|
|
"The Ray jobs CLI & SDK require the ray[default] "
|
|
"installation: `pip install 'ray[default']``"
|
|
)
|
|
super().__init__(
|
|
address=address,
|
|
create_cluster_if_needed=create_cluster_if_needed,
|
|
cookies=cookies,
|
|
metadata=metadata,
|
|
headers=headers,
|
|
)
|
|
self._check_connection_and_version(
|
|
min_version="1.9",
|
|
version_error_message="Jobs API is not supported on the Ray "
|
|
"cluster. Please ensure the cluster is "
|
|
"running Ray 1.9 or higher.",
|
|
)
|
|
|
|
@PublicAPI(stability="beta")
|
|
def submit_job(
|
|
self,
|
|
*,
|
|
entrypoint: str,
|
|
job_id: Optional[str] = None,
|
|
runtime_env: Optional[Dict[str, Any]] = None,
|
|
metadata: Optional[Dict[str, str]] = None,
|
|
) -> str:
|
|
runtime_env = runtime_env or {}
|
|
metadata = metadata or {}
|
|
metadata.update(self._default_metadata)
|
|
|
|
self._upload_working_dir_if_needed(runtime_env)
|
|
self._upload_py_modules_if_needed(runtime_env)
|
|
|
|
# Run the RuntimeEnv constructor to parse local pip/conda requirements files.
|
|
runtime_env = RuntimeEnv(**runtime_env).to_dict()
|
|
|
|
req = JobSubmitRequest(
|
|
entrypoint=entrypoint,
|
|
job_id=job_id,
|
|
runtime_env=runtime_env,
|
|
metadata=metadata,
|
|
)
|
|
|
|
logger.debug(f"Submitting job with job_id={job_id}.")
|
|
r = self._do_request("POST", "/api/jobs/", json_data=dataclasses.asdict(req))
|
|
|
|
if r.status_code == 200:
|
|
return JobSubmitResponse(**r.json()).job_id
|
|
else:
|
|
self._raise_error(r)
|
|
|
|
@PublicAPI(stability="beta")
|
|
def stop_job(
|
|
self,
|
|
job_id: str,
|
|
) -> bool:
|
|
logger.debug(f"Stopping job with job_id={job_id}.")
|
|
r = self._do_request("POST", f"/api/jobs/{job_id}/stop")
|
|
|
|
if r.status_code == 200:
|
|
return JobStopResponse(**r.json()).stopped
|
|
else:
|
|
self._raise_error(r)
|
|
|
|
@PublicAPI(stability="beta")
|
|
def get_job_info(
|
|
self,
|
|
job_id: str,
|
|
) -> JobInfo:
|
|
r = self._do_request("GET", f"/api/jobs/{job_id}")
|
|
|
|
if r.status_code == 200:
|
|
return JobInfo(**r.json())
|
|
else:
|
|
self._raise_error(r)
|
|
|
|
@PublicAPI(stability="beta")
|
|
def list_jobs(self) -> Dict[str, JobInfo]:
|
|
r = self._do_request("GET", "/api/jobs/")
|
|
|
|
if r.status_code == 200:
|
|
jobs_info_json = r.json()
|
|
jobs_info = {
|
|
job_id: JobInfo(**job_info_json)
|
|
for job_id, job_info_json in jobs_info_json.items()
|
|
}
|
|
return jobs_info
|
|
else:
|
|
self._raise_error(r)
|
|
|
|
@PublicAPI(stability="beta")
|
|
def get_job_status(self, job_id: str) -> JobStatus:
|
|
return self.get_job_info(job_id).status
|
|
|
|
@PublicAPI(stability="beta")
|
|
def get_job_logs(self, job_id: str) -> str:
|
|
r = self._do_request("GET", f"/api/jobs/{job_id}/logs")
|
|
|
|
if r.status_code == 200:
|
|
return JobLogsResponse(**r.json()).logs
|
|
else:
|
|
self._raise_error(r)
|
|
|
|
@PublicAPI(stability="beta")
|
|
async def tail_job_logs(self, job_id: str) -> Iterator[str]:
|
|
async with aiohttp.ClientSession(
|
|
cookies=self._cookies, headers=self._headers
|
|
) as session:
|
|
ws = await session.ws_connect(
|
|
f"{self._address}/api/jobs/{job_id}/logs/tail"
|
|
)
|
|
|
|
while True:
|
|
msg = await ws.receive()
|
|
|
|
if msg.type == aiohttp.WSMsgType.TEXT:
|
|
yield msg.data
|
|
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
|
break
|
|
elif msg.type == aiohttp.WSMsgType.ERROR:
|
|
pass
|