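"""Client-side interface to Ray's job submission HTTP API.

`JobSubmissionClient` wraps the dashboard's REST endpoints for uploading
working-directory packages and for submitting jobs, querying their status,
and fetching their logs.
"""
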
from base64 import b64encode
import logging
from pathlib import Path
import tempfile
from typing import Any, Dict, List, Optional, Tuple

try:
    from pydantic import BaseModel
    from pydantic.main import ModelMetaclass
except ImportError:
    # pydantic is optional; fall back to plain `object` so the annotations
    # below still resolve when it isn't installed.
    BaseModel = object
    ModelMetaclass = object

import requests

from ray._private.runtime_env.packaging import (
    create_package, get_uri_for_directory, parse_uri)
from ray._private.job_manager import JobStatus
from ray.dashboard.modules.job.data_types import (
    GetPackageRequest, GetPackageResponse, UploadPackageRequest, JobSpec,
    JobSubmitRequest, JobSubmitResponse, JobStatusRequest, JobStatusResponse,
    JobLogsRequest, JobLogsResponse)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class JobSubmissionClient:
    """Client for submitting and monitoring Ray jobs over the dashboard's
    HTTP job submission API."""

    def __init__(self, address: str):
        self._address: str = address.rstrip("/")
        self._test_connection()

    def _test_connection(self):
        # Issue a cheap request so a bad address fails fast with a clear error.
        try:
            assert not self._package_exists("gcs://FAKE_URI")
        except requests.exceptions.ConnectionError:
            raise ConnectionError(
                f"Failed to connect to Ray at address: {self._address}.")

    def _do_request(
            self,
            method: str,
            endpoint: str,
            data: BaseModel,
            response_type: Optional[ModelMetaclass] = None,
    ) -> Optional[BaseModel]:
        """Send a request to the given dashboard endpoint and parse the
        response into `response_type`, or return None if no type is given."""
        url = f"{self._address}/{endpoint}"
        json_payload = data.dict()
        logger.debug(f"Sending request to {url} with payload {json_payload}.")
        r = requests.request(method, url, json=json_payload)
        r.raise_for_status()

        response_json = r.json()
        if not response_json["result"]:  # Indicates failure.
            raise Exception(response_json["msg"])

        if response_type is None:
            return None
        else:
            # Dashboard "framework" returns double-nested "data" field...
            return response_type(**response_json["data"]["data"])

    def _package_exists(self, package_uri: str) -> bool:
        req = GetPackageRequest(package_uri=package_uri)
        resp = self._do_request(
            "GET", "package", req, response_type=GetPackageResponse)
        return resp.package_exists

    def _upload_package(self,
                        package_uri: str,
                        package_path: str,
                        include_parent_dir: Optional[bool] = False,
                        excludes: Optional[List[str]] = None) -> None:
        # Build the package in a temporary directory, then PUT its
        # base64-encoded bytes to the dashboard under `package_uri`.
        with tempfile.TemporaryDirectory() as tmp_dir:
            package_name = parse_uri(package_uri)[1]
            package_file = Path(tmp_dir) / package_name
            create_package(
                package_path,
                package_file,
                include_parent_dir=include_parent_dir,
                excludes=excludes)
            req = UploadPackageRequest(
                package_uri=package_uri,
                encoded_package_bytes=b64encode(package_file.read_bytes()))
            self._do_request("PUT", "package", req)
            package_file.unlink()

    def _upload_package_if_needed(self,
                                  package_path: str,
                                  excludes: Optional[List[str]] = None) -> str:
        package_uri: str = get_uri_for_directory(
            package_path, excludes=excludes)
        # Only upload if the server doesn't already have a package at this URI.
        if not self._package_exists(package_uri):
            logger.info(f"Uploading package {package_uri}.")
            self._upload_package(package_uri, package_path, excludes=excludes)
        else:
            logger.info(
                f"Package {package_uri} already exists, skipping upload.")

        return package_uri

    def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]):
        # If `working_dir` is a local path rather than an already-uploaded URI,
        # upload it and rewrite the runtime_env entry in place.
        if "working_dir" in runtime_env:
            working_dir = runtime_env["working_dir"]
            try:
                parse_uri(working_dir)
                is_uri = True
                logger.debug("working_dir is already a valid URI.")
            except ValueError:
                is_uri = False

            if not is_uri:
                logger.debug("working_dir is not a URI, attempting to upload.")
                package_uri = self._upload_package_if_needed(
                    working_dir, excludes=runtime_env.get("excludes", None))
                runtime_env["working_dir"] = package_uri

    def submit_job(self,
                   entrypoint: str,
                   runtime_env: Optional[Dict[str, Any]] = None,
                   metadata: Optional[Dict[str, str]] = None) -> str:
        """Submit the entrypoint command as a job and return its job ID."""
        runtime_env = runtime_env or {}
        metadata = metadata or {}

        self._upload_working_dir_if_needed(runtime_env)
        job_spec = JobSpec(
            entrypoint=entrypoint, runtime_env=runtime_env, metadata=metadata)
        req = JobSubmitRequest(job_spec=job_spec)
        resp = self._do_request("POST", "submit", req, JobSubmitResponse)
        return resp.job_id

    def get_job_status(self, job_id: str) -> JobStatus:
        req = JobStatusRequest(job_id=job_id)
        resp = self._do_request("GET", "status", req, JobStatusResponse)
        return resp.job_status

    def get_job_logs(self, job_id: str) -> Tuple[str, str]:
        req = JobLogsRequest(job_id=job_id)
        resp = self._do_request("GET", "logs", req, JobLogsResponse)
        return resp.stdout, resp.stderr
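

# Example usage (illustrative sketch only; the address below assumes the
# default Ray dashboard address and is not defined anywhere in this module):
#
#     client = JobSubmissionClient("http://127.0.0.1:8265")
#     job_id = client.submit_job(
#         entrypoint="python my_script.py",
#         runtime_env={"working_dir": "./"})
#     print(client.get_job_status(job_id))
#     stdout, stderr = client.get_job_logs(job_id)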