mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[jobs] Initial http jobs server on head node (#19657)
This commit is contained in:
parent
d656b3a6d7
commit
e53fecfbd5
10 changed files with 235 additions and 880 deletions
64
dashboard/modules/job/data_types.py
Normal file
64
dashboard/modules/job/data_types.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
    """States a job can be reported in.

    Members also subclass ``str`` and each value equals the member's name.
    """
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    STOPPED = "STOPPED"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"
|
||||
|
||||
|
||||
class JobSpec(BaseModel):
    """Describes what a job runs and the environment it runs in."""
    # Dict used to set up the execution environment; a dedicated schema for
    # this would be better.
    runtime_env: dict
    # Command that starts execution, e.g. "python script.py".
    entrypoint: str
    # Metadata passed in to configure job behavior or to use as tags.
    # Required by the Anyscale product and already supported in Ray drivers.
    metadata: dict
    # More fields will likely be needed later for different apps, but this
    # should stay minimal, with policies delegated to the job manager.
|
||||
|
||||
|
||||
# ==== Job Submit ====
|
||||
|
||||
|
||||
class JobSubmitRequest(BaseModel):
    """Request body for submitting a job."""
    # What to run and in which environment.
    job_spec: JobSpec
    # Globally unique job id. It's recommended to generate this id in an
    # external job manager first and then pass it into this API.
    # If the job server never had a job running with the given id:
    #   - Start a new job execution.
    # Else if the job server has a running job with the given id:
    #   - Fail; deployment update and reconfiguration should happen in the
    #     job manager.
    # NOTE(review): annotated as `str` but defaults to None, so the field is
    # presumably meant to be optional -- confirm how a missing id is handled.
    job_id: str = None
|
||||
|
||||
|
||||
class JobSubmitResponse(BaseModel):
    """Response body for a job submission."""
    # Id of the job the response refers to.
    job_id: str
|
||||
|
||||
|
||||
# ==== Job Status ====
|
||||
|
||||
|
||||
class JobStatusRequest(BaseModel):
    """Request body for querying the status of one job."""
    # Id of the job to look up.
    job_id: str
|
||||
|
||||
|
||||
class JobStatusResponse(BaseModel):
    """Response body carrying the status of one job."""
    # One of the JobStatus members.
    job_status: JobStatus
|
||||
|
||||
|
||||
# ==== Job Logs ====
|
||||
|
||||
|
||||
class JobLogsRequest(BaseModel):
    """Request body for fetching the logs of one job."""
    # Id of the job whose logs are requested.
    job_id: str
|
||||
|
||||
|
||||
# TODO(jiaodong): Support log streaming #19415
class JobLogsResponse(BaseModel):
    """Response body carrying a job's captured output as text."""
    # Text of the job's standard output.
    stdout: str
    # Text of the job's standard error.
    stderr: str
|
|
@ -1,320 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
import itertools
|
||||
import subprocess
|
||||
import sys
|
||||
import secrets
|
||||
import uuid
|
||||
import traceback
|
||||
from abc import abstractmethod
|
||||
from typing import Union
|
||||
|
||||
import attr
|
||||
from attr.validators import instance_of
|
||||
import ray.dashboard.utils as dashboard_utils
|
||||
from ray.dashboard.utils import create_task
|
||||
from ray.dashboard.modules.job import job_consts
|
||||
from ray.dashboard.modules.job.job_description import JobDescription
|
||||
from ray.core.generated import job_agent_pb2
|
||||
from ray.core.generated import job_agent_pb2_grpc
|
||||
from ray.core.generated import agent_manager_pb2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(kw_only=True, slots=True)
class JobInfo(JobDescription):
    """A job description extended with the agent-side state of one job."""
    # TODO(fyrestone): We should use job id instead of unique id.
    unique_id = attr.ib(type=str, validator=instance_of(str))
    # The temp directory.
    temp_dir = attr.ib(type=str, validator=instance_of(str))
    # The log directory.
    log_dir = attr.ib(type=str, validator=instance_of(str))
    # The driver process instance; None until a driver has been started.
    driver = attr.ib(
        type=Union[None, asyncio.subprocess.Process], default=None)

    def __attrs_post_init__(self):
        """Turn every non-string env value into its JSON text."""
        # Support json values for env.
        encoded_env = {}
        for name, value in self.env.items():
            if isinstance(value, str):
                encoded_env[name] = value
            else:
                encoded_env[name] = json.dumps(value)
        self.env = encoded_env
|
||||
|
||||
|
||||
class JobProcessor:
    """Wraps the job info and provides common utils to download packages,
    start drivers, etc.

    Args:
        job_info (JobInfo): The job info.
    """
    # Class-level counter shared by all processors; it numbers each download
    # and command so that its log lines can be matched up.
    _cmd_index_gen = itertools.count(1)

    def __init__(self, job_info):
        assert isinstance(job_info, JobInfo)
        self._job_info = job_info

    async def _download_package(self, http_session, url, filename):
        """Stream the content at ``url`` into the local file ``filename``.

        Args:
            http_session: The client session used to issue the GET request.
            url (str): The location of the job package.
            filename (str): The path of the file to write.

        Raises:
            Whatever ``response.raise_for_status()`` raises when the server
            answers with an error status.
        """
        unique_id = self._job_info.unique_id
        cmd_index = next(self._cmd_index_gen)
        logger.info("[%s] Start download[%s] %s to %s", unique_id, cmd_index,
                    url, filename)
        # NOTE(review): ssl=False is kept from the original call; it looks
        # like it turns off certificate verification -- confirm it is wanted.
        async with http_session.get(url, ssl=False) as response:
            # Fail on an HTTP error instead of saving the error body to disk
            # as if it were the job package.
            response.raise_for_status()
            with open(filename, "wb") as f:
                while True:
                    chunk = await response.content.read(
                        job_consts.DOWNLOAD_BUFFER_SIZE)
                    if not chunk:
                        break
                    f.write(chunk)
        logger.info("[%s] Finished download[%s] %s to %s", unique_id,
                    cmd_index, url, filename)

    async def _unpack_package(self, filename, path):
        """Unpack the archive ``filename`` into ``path``.

        The unpacking is done by ``shutil.unpack_archive`` in a separate
        Python process started with the current interpreter.
        """
        code = f"import shutil; " \
               f"shutil.unpack_archive({repr(filename)}, {repr(path)})"
        unzip_cmd = [self._get_current_python(), "-c", code]
        await self._check_output_cmd(unzip_cmd)

    async def _check_output_cmd(self, cmd):
        """Run ``cmd`` to completion and return its decoded stdout.

        Args:
            cmd (list): The program and its arguments.

        Returns:
            The stdout of the command decoded as UTF-8.

        Raises:
            subprocess.CalledProcessError: If the command exits with a
                non-zero return code.
        """
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE)
        unique_id = self._job_info.unique_id
        cmd_index = next(self._cmd_index_gen)
        proc.cmd_index = cmd_index
        logger.info("[%s] Run cmd[%s] %s", unique_id, cmd_index, repr(cmd))
        stdout, stderr = await proc.communicate()
        stdout = stdout.decode("utf-8")
        logger.info("[%s] Output of cmd[%s]: %s", unique_id, cmd_index, stdout)
        if proc.returncode != 0:
            stderr = stderr.decode("utf-8")
            logger.error("[%s] Error of cmd[%s]: %s", unique_id, cmd_index,
                         stderr)
            raise subprocess.CalledProcessError(
                proc.returncode, cmd, output=stdout, stderr=stderr)
        return stdout

    async def _start_driver(self, cmd, stdout, stderr, env):
        """Start the driver process inside the unpacked job package dir.

        Args:
            cmd (list): The driver program and its arguments.
            stdout: Passed through as the stdout of the new process.
            stderr: Passed through as the stderr of the new process.
            env (dict): Variables laid over a copy of ``os.environ``.

        Returns:
            The started process; it is not awaited here.
        """
        unique_id = self._job_info.unique_id
        job_package_dir = job_consts.JOB_UNPACK_DIR.format(
            temp_dir=self._job_info.temp_dir, unique_id=unique_id)
        cmd_str = subprocess.list2cmdline(cmd)
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=stdout,
            stderr=stderr,
            env={
                **os.environ,
                **env,
            },
            cwd=job_package_dir,
        )
        logger.info("[%s] Start driver cmd %s with pid %s", unique_id,
                    repr(cmd_str), proc.pid)
        return proc

    @staticmethod
    def _get_current_python():
        """Return the path of the running Python interpreter."""
        return sys.executable

    @staticmethod
    def _new_log_files(log_dir, filename):
        """Open ``<filename>.out`` and ``<filename>.err`` under ``log_dir``.

        Both files are opened line-buffered in append mode.

        Args:
            log_dir (str): The directory for the files, or None.
            filename (str): The file name without extension.

        Returns:
            ``(stdout, stderr)`` file objects, or ``(None, None)`` when
            ``log_dir`` is None.
        """
        if log_dir is None:
            return None, None
        stdout = open(
            os.path.join(log_dir, filename + ".out"), "a", buffering=1)
        try:
            stderr = open(
                os.path.join(log_dir, filename + ".err"), "a", buffering=1)
        except BaseException:
            # Do not leak the handle that was already opened.
            stdout.close()
            raise
        return stdout, stderr

    # NOTE: the class does not use ABCMeta, so this marker is documentation
    # only; subclasses are expected to override run().
    @abstractmethod
    async def run(self):
        pass
|
||||
|
||||
|
||||
class DownloadPackage(JobProcessor):
    """ Download the job package.

    Args:
        job_info (JobInfo): The job info.
        http_session (aiohttp.ClientSession): The client session.
    """

    def __init__(self, job_info, http_session):
        super().__init__(job_info)
        self._http_session = http_session

    async def run(self):
        """Fetch the archive named by the job's working_dir and unpack it."""
        job = self._job_info
        # Both target paths are derived from the same two placeholders.
        placeholders = dict(temp_dir=job.temp_dir, unique_id=job.unique_id)
        archive_path = job_consts.DOWNLOAD_PACKAGE_FILE.format(**placeholders)
        package_dir = job_consts.JOB_UNPACK_DIR.format(**placeholders)
        package_url = job.runtime_env.working_dir
        await self._download_package(self._http_session, package_url,
                                     archive_path)
        await self._unpack_package(archive_path, package_dir)
|
||||
|
||||
|
||||
class StartPythonDriver(JobProcessor):
    """ Start the driver for Python job.

    Args:
        job_info (JobInfo): The job info.
        redis_address (tuple): The (ip, port) of redis.
        redis_password (str): The password of redis.
    """

    # Source of the generated driver script; the placeholders are filled in
    # by _gen_driver_code().
    _template = """import sys
sys.path.append({import_path})
import ray
from ray._private.utils import hex_to_binary
ray.init(ignore_reinit_error=True,
         address={redis_address},
         _redis_password={redis_password},
         job_config=ray.job_config.JobConfig({job_config_args}),
)
import {driver_entry}
{driver_entry}.main({driver_args})
# If the driver exits normally, we invoke Ray.shutdown() again
# here, in case the user code forgot to invoke it.
ray.shutdown()
"""

    def __init__(self, job_info, redis_address, redis_password):
        super().__init__(job_info)
        self._redis_address = redis_address
        self._redis_password = redis_password

    def _gen_driver_code(self):
        """Render the driver template into a new file and return its path."""
        temp_dir = self._job_info.temp_dir
        unique_id = self._job_info.unique_id
        job_package_dir = job_consts.JOB_UNPACK_DIR.format(
            temp_dir=temp_dir, unique_id=unique_id)
        # A fresh uuid goes into the file name of every generated script.
        driver_entry_file = job_consts.JOB_DRIVER_ENTRY_FILE.format(
            temp_dir=temp_dir, unique_id=unique_id, uuid=uuid.uuid4())
        ip, port = self._redis_address

        # Per job config
        job_config_items = {
            "runtime_env": {
                "env_vars": self._job_info.env
            },
            "code_search_path": [job_package_dir],
        }

        # Values are embedded through repr() so that the generated script
        # contains them as Python literals.
        job_config_args = ", ".join(f"{key}={repr(value)}"
                                    for key, value in job_config_items.items()
                                    if value is not None)
        driver_args = ", ".join([repr(x) for x in self._job_info.driver_args])
        driver_code = self._template.format(
            job_config_args=job_config_args,
            import_path=repr(job_package_dir),
            redis_address=repr(ip + ":" + str(port)),
            redis_password=repr(self._redis_password),
            driver_entry=self._job_info.driver_entry,
            driver_args=driver_args)
        with open(driver_entry_file, "w") as fp:
            fp.write(driver_code)
        return driver_entry_file

    async def run(self):
        """Generate the driver script and start it with the current Python.

        Returns:
            The started driver process.
        """
        python = self._get_current_python()
        driver_file = self._gen_driver_code()
        driver_cmd = [python, "-u", driver_file]
        stdout_file, stderr_file = self._new_log_files(
            self._job_info.log_dir, f"driver-{self._job_info.unique_id}")
        try:
            return await self._start_driver(driver_cmd, stdout_file,
                                            stderr_file, self._job_info.env)
        finally:
            # Once the spawn call has returned (or failed) this process no
            # longer needs its own handles on the log files; closing them
            # avoids leaking two descriptors per started job.
            for log_file in (stdout_file, stderr_file):
                if log_file is not None:
                    log_file.close()
|
||||
|
||||
|
||||
class JobAgent(dashboard_utils.DashboardAgentModule,
               job_agent_pb2_grpc.JobAgentServiceServicer):
    """ The JobAgentService defined in job_agent.proto for initializing /
    cleaning job environments.
    """

    async def InitializeJobEnv(self, request, context):
        """Download the job package and start the driver for one job.

        Every failure is reported through the reply rather than raised.

        Returns:
            An InitializeJobEnvReply that carries the driver pid on success,
            or the FAILED status with an error message otherwise.
        """
        # TODO(fyrestone): Handle duplicated InitializeJobEnv requests
        # when initializing job environment.
        # TODO(fyrestone): Support reinitialize job environment.

        # TODO(fyrestone): Use job id instead of unique id.
        unique_id = secrets.token_hex(6)

        def _failed(message):
            # All error paths answer with the same reply shape.
            return job_agent_pb2.InitializeJobEnvReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message=message)

        # Parse the job description from the request.
        try:
            description = json.loads(request.job_description)
            job_info = JobInfo(
                unique_id=unique_id,
                temp_dir=self._dashboard_agent.temp_dir,
                log_dir=self._dashboard_agent.log_dir,
                **description)
        except json.JSONDecodeError as ex:
            error_message = f"{ex}, job_payload:\n{request.job_description}"
            logger.error("[%s] Initialize job environment failed, %s.",
                         unique_id, error_message)
            return _failed(error_message)
        except Exception as ex:
            logger.exception(ex)
            return _failed(traceback.format_exc())

        async def _initialize_job_env():
            job_dir = job_consts.JOB_DIR.format(
                temp_dir=job_info.temp_dir, unique_id=unique_id)
            os.makedirs(job_dir, exist_ok=True)
            # Download the job package.
            downloader = DownloadPackage(job_info,
                                         self._dashboard_agent.http_session)
            await downloader.run()
            # Start the driver.
            logger.info("[%s] Starting driver.", unique_id)
            language = job_info.language
            if language != job_consts.PYTHON:
                raise Exception(f"Unsupported language type: {language}")
            starter = StartPythonDriver(job_info,
                                        self._dashboard_agent.redis_address,
                                        self._dashboard_agent.redis_password)
            job_info.driver = await starter.run()

        initialize_task = create_task(_initialize_job_env())

        try:
            await initialize_task
        except asyncio.CancelledError:
            logger.error("[%s] Initialize job environment has been cancelled.",
                         unique_id)
            return _failed("InitializeJobEnv has been cancelled, "
                           "did you call CleanJobEnv?")
        except Exception as ex:
            logger.exception(ex)
            return _failed(traceback.format_exc())

        # 0 is reported when no driver process was recorded.
        driver_pid = job_info.driver.pid if job_info.driver else 0

        logger.info(
            "[%s] Job environment initialized, "
            "the driver (pid=%s) started.", unique_id, driver_pid)
        return job_agent_pb2.InitializeJobEnvReply(
            status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
            driver_pid=driver_pid)

    async def run(self, server):
        """Register this servicer on the given server."""
        job_agent_pb2_grpc.add_JobAgentServiceServicer_to_server(self, server)
|
|
@ -1,18 +0,0 @@
|
|||
import os
|
||||
from ray.core.generated import common_pb2
|
||||
|
||||
# Job agent consts
# The path constants below are str.format templates; {temp_dir},
# {unique_id} and {uuid} are filled in by the code that uses them.
# TODO(fyrestone): We should use job id instead of unique_id.
JOB_DIR = "{temp_dir}/job/{unique_id}/"
JOB_UNPACK_DIR = os.path.join(JOB_DIR, "package")
JOB_DRIVER_ENTRY_FILE = os.path.join(JOB_DIR, "driver-{uuid}.py")
# Downloader constants
DOWNLOAD_BUFFER_SIZE = 10 * 1024 * 1024  # 10MB
DOWNLOAD_PACKAGE_FILE = os.path.join(JOB_DIR, "package.zip")
# Redis key
JOB_CHANNEL = "JOB"
RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS = 2
# Languages: the names of the corresponding Language enum values.
PYTHON = common_pb2.Language.Name(common_pb2.Language.PYTHON)
JAVA = common_pb2.Language.Name(common_pb2.Language.JAVA)
CPP = common_pb2.Language.Name(common_pb2.Language.CPP)
|
|
@ -1,58 +0,0 @@
|
|||
import attr
|
||||
from attr.validators import instance_of, in_, deep_mapping
|
||||
from ray.core.generated import common_pb2
|
||||
|
||||
|
||||
@attr.s(repr=False, slots=True, hash=True)
class _AnyValidator(object):
    """Validator that accepts every value."""

    def __repr__(self):
        return "<any validator for any type>"

    def __call__(self, inst, name, value):
        # Nothing is checked, so nothing can be rejected.
        return None
|
||||
|
||||
|
||||
def any_():
    """Return a validator that accepts any value."""
    validator = _AnyValidator()
    return validator
|
||||
|
||||
|
||||
@attr.s(kw_only=True, slots=True)
class RuntimeEnv:
    """The runtime environment section of a job description."""
    # The url to download the job package archive. The archive format is
    # one of "zip", "tar", "gztar", "bztar", or "xztar". Please refer to
    # https://docs.python.org/3/library/shutil.html#shutil.unpack_archive
    working_dir = attr.ib(type=str, validator=instance_of(str))
|
||||
|
||||
|
||||
@attr.s(kw_only=True, slots=True)
class JobDescription:
    """Validated description of a job: language, package, entry and args."""
    # The job driver language, this field determines how to start the
    # driver. The value is one of the names of enum Language defined in
    # common.proto, e.g. PYTHON
    language = attr.ib(type=str, validator=in_(common_pb2.Language.keys()))
    # The runtime_env (RuntimeEnvDict) for the job config. The given mapping
    # is converted into a RuntimeEnv instance.
    runtime_env = attr.ib(
        type=RuntimeEnv, converter=lambda kw: RuntimeEnv(**kw))
    # The entry to start the driver.
    # PYTHON:
    #   - The basename of driver filename without extension in the job
    #     package archive.
    # JAVA:
    #   - The driver class full name in the job package archive.
    driver_entry = attr.ib(type=str, validator=instance_of(str))
    # The driver arguments in list.
    # PYTHON:
    #   - The arguments to pass to the main() function in driver entry.
    #     e.g. [1, False, 3.14, "abc"]
    # JAVA:
    #   - The arguments to pass to the driver command line.
    #     e.g. ["-custom-arg", "abc"]
    # A factory is used so that each instance gets its own list instead of
    # all instances sharing one mutable default.
    driver_args = attr.ib(
        type=list, validator=instance_of(list), default=attr.Factory(list))
    # The environment vars to pass to job config, type of keys should be str.
    # A factory is used here as well so the default dict is not shared.
    env = attr.ib(
        type=dict,
        validator=deep_mapping(
            key_validator=instance_of(str),
            value_validator=any_(),
            mapping_validator=instance_of(dict)),
        default=attr.Factory(dict))
|
|
@ -1,157 +1,92 @@
|
|||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
import aiohttp.web
|
||||
from aioredis.pubsub import Receiver
|
||||
from functools import wraps
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import ray._private.utils
|
||||
import ray._private.gcs_utils as gcs_utils
|
||||
import ray
|
||||
import ray.dashboard.utils as dashboard_utils
|
||||
from ray.dashboard.modules.job import job_consts
|
||||
from ray.dashboard.modules.job.job_description import JobDescription
|
||||
from ray.core.generated import agent_manager_pb2
|
||||
from ray.core.generated import gcs_service_pb2
|
||||
from ray.core.generated import gcs_service_pb2_grpc
|
||||
from ray.core.generated import job_agent_pb2
|
||||
from ray.core.generated import job_agent_pb2_grpc
|
||||
from ray.dashboard.datacenter import (
|
||||
DataSource,
|
||||
GlobalSignals,
|
||||
)
|
||||
|
||||
from ray._private.job_manager import JobManager
|
||||
from ray.dashboard.modules.job.data_types import (
|
||||
JobStatus, JobSubmitRequest, JobSubmitResponse, JobStatusRequest,
|
||||
JobStatusResponse, JobLogsRequest, JobLogsResponse)
|
||||
from ray.experimental.internal_kv import (_initialize_internal_kv,
|
||||
_internal_kv_initialized)
|
||||
logger = logging.getLogger(__name__)
|
||||
routes = dashboard_utils.ClassMethodRouteTable
|
||||
|
||||
RAY_INTERNAL_JOBS_NAMESPACE = "_ray_internal_jobs_"
|
||||
|
||||
def job_table_data_to_dict(message):
    """Convert a job table message to a dict, default-valued fields included.

    The "jobId" and "rayletId" keys are handed to the converter as the keys
    to decode.
    """
    return dashboard_utils.message_to_dict(
        message, {"jobId", "rayletId"}, including_default_value_fields=True)
|
||||
|
||||
def _ensure_ray_initialized(f: Callable) -> Callable:
    """Decorate a method so that Ray is initialized before it is called.

    If Ray is not initialized yet, the wrapper connects with address "auto"
    in the internal jobs namespace. It then returns whatever ``f`` returns.
    """

    @wraps(f)
    def wrapper(self, *args, **kwargs):
        needs_init = not ray.is_initialized()
        if needs_init:
            ray.init(address="auto", namespace=RAY_INTERNAL_JOBS_NAMESPACE)
        result = f(self, *args, **kwargs)
        return result

    return wrapper
|
||||
|
||||
|
||||
class JobHead(dashboard_utils.DashboardHeadModule):
|
||||
def __init__(self, dashboard_head):
|
||||
super().__init__(dashboard_head)
|
||||
# JobInfoGcsServiceStub
|
||||
self._gcs_job_info_stub = None
|
||||
|
||||
@routes.post("/jobs")
|
||||
async def submit_job(self, req) -> aiohttp.web.Response:
|
||||
job_description_data = dict(await req.json())
|
||||
# Validate the job description data.
|
||||
try:
|
||||
JobDescription(**job_description_data)
|
||||
except Exception as ex:
|
||||
return dashboard_utils.rest_response(
|
||||
success=False, message=f"Failed to submit job: {ex}")
|
||||
# Initialize internal KV to be used by the working_dir setup code.
|
||||
_initialize_internal_kv(dashboard_head.gcs_client)
|
||||
assert _internal_kv_initialized()
|
||||
|
||||
# TODO(fyrestone): Choose a random agent to start the driver
|
||||
# for this job.
|
||||
node_id, ports = next(iter(DataSource.agents.items()))
|
||||
ip = DataSource.node_id_to_ip[node_id]
|
||||
address = f"{ip}:{ports[1]}"
|
||||
options = (("grpc.enable_http_proxy", 0), )
|
||||
channel = ray._private.utils.init_grpc_channel(
|
||||
address, options, asynchronous=True)
|
||||
self._job_manager = None
|
||||
|
||||
stub = job_agent_pb2_grpc.JobAgentServiceStub(channel)
|
||||
request = job_agent_pb2.InitializeJobEnvRequest(
|
||||
job_description=json.dumps(job_description_data))
|
||||
# TODO(fyrestone): It's better not to wait the RPC InitializeJobEnv.
|
||||
reply = await stub.InitializeJobEnv(request)
|
||||
# TODO(fyrestone): We should reply a job id for the submitted job.
|
||||
if reply.status == agent_manager_pb2.AGENT_RPC_STATUS_OK:
|
||||
logger.info("Succeeded to submit job.")
|
||||
return dashboard_utils.rest_response(
|
||||
success=True, message="Job submitted.")
|
||||
else:
|
||||
logger.info("Failed to submit job.")
|
||||
return dashboard_utils.rest_response(
|
||||
success=False,
|
||||
message=f"Failed to submit job: {reply.error_message}")
|
||||
@routes.post("/submit")
|
||||
@_ensure_ray_initialized
|
||||
async def submit(self, req) -> aiohttp.web.Response:
|
||||
req_data = dict(await req.json())
|
||||
submit_request = JobSubmitRequest(**req_data)
|
||||
self._job_manager.submit_job(submit_request.job_id,
|
||||
submit_request.job_spec.entrypoint,
|
||||
submit_request.job_spec.runtime_env)
|
||||
|
||||
@routes.get("/jobs")
|
||||
@dashboard_utils.aiohttp_cache
|
||||
async def get_all_jobs(self, req) -> aiohttp.web.Response:
|
||||
view = req.query.get("view")
|
||||
if view == "summary":
|
||||
return dashboard_utils.rest_response(
|
||||
success=True,
|
||||
message="All job summary fetched.",
|
||||
summary=list(DataSource.jobs.values()))
|
||||
else:
|
||||
return dashboard_utils.rest_response(
|
||||
success=False, message="Unknown view {}".format(view))
|
||||
resp = JobSubmitResponse(job_id=submit_request.job_id)
|
||||
return dashboard_utils.rest_response(
|
||||
success=True,
|
||||
convert_google_style=False,
|
||||
data=resp.dict(),
|
||||
message=f"Submitted job {submit_request.job_id}")
|
||||
|
||||
@routes.get("/jobs/{job_id}")
|
||||
@dashboard_utils.aiohttp_cache
|
||||
async def get_job(self, req) -> aiohttp.web.Response:
|
||||
job_id = req.match_info.get("job_id")
|
||||
view = req.query.get("view")
|
||||
if view is None:
|
||||
job_detail = {
|
||||
"jobInfo": DataSource.jobs.get(job_id, {}),
|
||||
"jobActors": DataSource.job_actors.get(job_id, {}),
|
||||
"jobWorkers": DataSource.job_workers.get(job_id, []),
|
||||
}
|
||||
await GlobalSignals.job_info_fetched.send(job_detail)
|
||||
return dashboard_utils.rest_response(
|
||||
success=True, message="Job detail fetched.", detail=job_detail)
|
||||
else:
|
||||
return dashboard_utils.rest_response(
|
||||
success=False, message="Unknown view {}".format(view))
|
||||
@routes.get("/status")
|
||||
@_ensure_ray_initialized
|
||||
async def status(self, req) -> aiohttp.web.Response:
|
||||
req_data = dict(await req.json())
|
||||
status_request = JobStatusRequest(**req_data)
|
||||
|
||||
async def _update_jobs(self):
|
||||
# Subscribe job channel.
|
||||
aioredis_client = self._dashboard_head.aioredis_client
|
||||
receiver = Receiver()
|
||||
status: JobStatus = self._job_manager.get_job_status(
|
||||
status_request.job_id)
|
||||
resp = JobStatusResponse(job_status=status)
|
||||
return dashboard_utils.rest_response(
|
||||
success=True,
|
||||
convert_google_style=False,
|
||||
data=resp.dict(),
|
||||
message=f"Queried status for job {status_request.job_id}")
|
||||
|
||||
key = f"{job_consts.JOB_CHANNEL}:*"
|
||||
pattern = receiver.pattern(key)
|
||||
await aioredis_client.psubscribe(pattern)
|
||||
logger.info("Subscribed to %s", key)
|
||||
@routes.get("/logs")
|
||||
@_ensure_ray_initialized
|
||||
async def logs(self, req) -> aiohttp.web.Response:
|
||||
req_data = dict(await req.json())
|
||||
logs_request = JobLogsRequest(**req_data)
|
||||
|
||||
# Get all job info.
|
||||
while True:
|
||||
try:
|
||||
logger.info("Getting all job info from GCS.")
|
||||
request = gcs_service_pb2.GetAllJobInfoRequest()
|
||||
reply = await self._gcs_job_info_stub.GetAllJobInfo(
|
||||
request, timeout=5)
|
||||
if reply.status.code == 0:
|
||||
jobs = {}
|
||||
for job_table_data in reply.job_info_list:
|
||||
data = job_table_data_to_dict(job_table_data)
|
||||
jobs[data["jobId"]] = data
|
||||
# Update jobs.
|
||||
DataSource.jobs.reset(jobs)
|
||||
logger.info("Received %d job info from GCS.", len(jobs))
|
||||
break
|
||||
else:
|
||||
raise Exception(
|
||||
f"Failed to GetAllJobInfo: {reply.status.message}")
|
||||
except Exception:
|
||||
logger.exception("Error Getting all job info from GCS.")
|
||||
await asyncio.sleep(
|
||||
job_consts.RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS)
|
||||
stdout: bytes = self._job_manager.get_job_stdout(logs_request.job_id)
|
||||
stderr: bytes = self._job_manager.get_job_stderr(logs_request.job_id)
|
||||
|
||||
# Receive jobs from channel.
|
||||
async for sender, msg in receiver.iter():
|
||||
try:
|
||||
_, data = msg
|
||||
pubsub_message = gcs_utils.PubSubMessage.FromString(data)
|
||||
message = gcs_utils.JobTableData.FromString(
|
||||
pubsub_message.data)
|
||||
job_table_data = job_table_data_to_dict(message)
|
||||
job_id = job_table_data["jobId"]
|
||||
# Update jobs.
|
||||
DataSource.jobs[job_id] = job_table_data
|
||||
except Exception:
|
||||
logger.exception("Error receiving job info.")
|
||||
# TODO(jiaodong): Support log streaming #19415
|
||||
resp = JobLogsResponse(
|
||||
stdout=stdout.decode("utf-8"), stderr=stderr.decode("utf-8"))
|
||||
|
||||
return dashboard_utils.rest_response(
|
||||
success=True,
|
||||
convert_google_style=False,
|
||||
data=resp.dict(),
|
||||
message=f"Logs returned for job {logs_request.job_id}")
|
||||
|
||||
async def run(self, server):
|
||||
self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
|
||||
self._dashboard_head.aiogrpc_gcs_channel)
|
||||
|
||||
await asyncio.gather(self._update_jobs())
|
||||
if not self._job_manager:
|
||||
self._job_manager = JobManager()
|
||||
|
|
23
dashboard/modules/job/tests/conftest.py
Normal file
23
dashboard/modules/job/tests/conftest.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def local_working_dir():
    """Job case with an empty runtime env and a plain shell entrypoint."""
    case = dict(
        runtime_env={},
        entrypoint="echo hello",
        expected_stdout="hello",
        expected_stderr="",
    )
    yield case
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def s3_working_dir():
    """Job case whose working_dir is a zip archive addressed by an s3 URI."""
    case = dict(
        runtime_env={"working_dir": "s3://runtime-env-test/script.zip"},
        entrypoint="python script.py",
        expected_stdout="Executing main() from script.py !!",
        expected_stderr="",
    )
    yield case
|
64
dashboard/modules/job/tests/test_http_job_server.py
Normal file
64
dashboard/modules/job/tests/test_http_job_server.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import sys
|
||||
|
||||
import logging
|
||||
import requests
|
||||
from uuid import uuid4
|
||||
import pytest
|
||||
from pytest_lazyfixture import lazy_fixture
|
||||
|
||||
from ray.dashboard.tests.conftest import * # noqa
|
||||
from ray._private.test_utils import (format_web_url,
|
||||
wait_until_server_available)
|
||||
from ray._private.job_manager import JobStatus
|
||||
from ray.dashboard.modules.job.data_types import (
|
||||
JobSubmitRequest, JobSubmitResponse, JobStatusRequest, JobStatusResponse,
|
||||
JobLogsRequest, JobLogsResponse, JobSpec)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _setup_webui_url(ray_start_with_dashboard):
    """Wait for the dashboard server and return its formatted web URL."""
    raw_url = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(raw_url)
    return format_web_url(raw_url)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "working_dir",
    [lazy_fixture("local_working_dir"),
     lazy_fixture("s3_working_dir")])
def test_submit_job(disable_aiohttp_cache, enable_test_module,
                    ray_start_with_dashboard, working_dir):
    """Submit a job over HTTP, then check its id, status and logs."""
    base_url = _setup_webui_url(ray_start_with_dashboard)

    def _exchange(http_call, route, request_model):
        # Send the model as JSON and return the payload of the reply.
        reply = http_call(f"{base_url}/{route}", json=request_model.dict())
        reply.raise_for_status()
        return reply.json()["data"]["data"]

    spec = JobSpec(
        runtime_env=working_dir["runtime_env"],
        entrypoint=working_dir["entrypoint"],
        metadata=dict())
    submit_request = JobSubmitRequest(job_spec=spec, job_id=str(uuid4()))
    job_id = submit_request.job_id

    submitted = JobSubmitResponse(
        **_exchange(requests.post, "submit", submit_request))
    assert submitted.job_id == job_id

    status = JobStatusResponse(
        **_exchange(requests.get, "status", JobStatusRequest(job_id=job_id)))
    assert status.job_status == JobStatus.SUCCEEDED

    logs = JobLogsResponse(
        **_exchange(requests.get, "logs", JobLogsRequest(job_id=job_id)))
    assert logs.stdout == working_dir["expected_stdout"]
    assert logs.stderr == working_dir["expected_stderr"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run this file's tests verbosely and exit with pytest's return code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|
|
@ -1,330 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import logging
|
||||
import requests
|
||||
import tempfile
|
||||
import zipfile
|
||||
import shutil
|
||||
import traceback
|
||||
|
||||
import ray
|
||||
from ray._private.utils import hex_to_binary
|
||||
from ray.dashboard.tests.conftest import * # noqa
|
||||
from ray.dashboard.modules.job import job_consts
|
||||
from ray._private.test_utils import (
|
||||
format_web_url,
|
||||
wait_until_server_available,
|
||||
wait_for_condition,
|
||||
wait_until_succeeded_without_exception,
|
||||
)
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Job description template used by the tests; the {web_url} and {path}
# placeholders in working_dir are filled in by _prepare_job_for_test().
TEST_PYTHON_JOB = {
    "language": job_consts.PYTHON,
    "runtime_env": {
        "working_dir": "{web_url}/test/file?path={path}"
    },
    "driver_entry": "simple_job",
}
|
||||
|
||||
TEST_PYTHON_JOB_CODE = """
|
||||
import os
|
||||
import sys
|
||||
import ray
|
||||
import time
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Actor:
|
||||
def __init__(self, index):
|
||||
self._index = index
|
||||
|
||||
def foo(self, x):
|
||||
return f"Actor {self._index}: {x}"
|
||||
|
||||
|
||||
def main():
|
||||
actors = []
|
||||
for x in range(2):
|
||||
actors.append(Actor.remote(x))
|
||||
|
||||
counter = 0
|
||||
while True:
|
||||
for a in actors:
|
||||
r = a.foo.remote(counter)
|
||||
print(ray.get(r))
|
||||
counter += 1
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
main()
|
||||
"""
|
||||
|
||||
|
||||
def _gen_job_zip(job_code, driver_entry):
|
||||
with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as f:
|
||||
with zipfile.ZipFile(f, mode="w") as zip_f:
|
||||
with zip_f.open(f"{driver_entry}.py", "w") as driver:
|
||||
driver.write(job_code.encode())
|
||||
return f.name
|
||||
|
||||
|
||||
def _prepare_job_for_test(web_url):
    """Build a job description whose working_dir points at a fresh archive.

    The module-level template is deep-copied, so it is left untouched.
    """
    archive_path = _gen_job_zip(TEST_PYTHON_JOB_CODE,
                                TEST_PYTHON_JOB["driver_entry"])
    job = copy.deepcopy(TEST_PYTHON_JOB)
    runtime_env = job["runtime_env"]
    url_template = runtime_env["working_dir"]
    runtime_env["working_dir"] = url_template.format(
        web_url=web_url, path=archive_path)
    return job
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "ray_start_with_dashboard",
    [{"job_config": ray.job_config.JobConfig(code_search_path=[""])}],
    indirect=True)
def test_submit_job(disable_aiohttp_cache, enable_test_module,
                    ray_start_with_dashboard):
    """Submit the sample job over HTTP and wait until it shows up as alive.

    The job is POSTed once; afterwards the summary and detail endpoints are
    polled until the submitted job is listed, is not dead, and has at least
    one actor and one worker.
    """
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])

    job = _prepare_job_for_test(webui_url)
    session_parent = os.path.dirname(ray_start_with_dashboard["session_dir"])
    shutil.rmtree(os.path.join(session_parent, "job"), ignore_errors=True)

    # Mutable state shared across retries of ``_check_running``.
    state = {"submitted": False, "job_id": None}

    def _get_ok(url):
        # GET ``url`` and return the decoded body, requiring result == True.
        resp = requests.get(url)
        resp.raise_for_status()
        body = resp.json()
        assert body["result"] is True, resp.text
        return body

    def _check_running():
        # Only submit on the first attempt; later retries just poll.
        if not state["submitted"]:
            resp = requests.post(f"{webui_url}/jobs", json=job)
            resp.raise_for_status()
            body = resp.json()
            assert body["result"] is True, resp.text
            state["submitted"] = True

        summary = _get_ok(f"{webui_url}/jobs?view=summary")["data"]["summary"]
        assert len(summary) == 2

        # TODO(fyrestone): Return a job id when POST /jobs
        # The larger job id is the one we submitted.
        job_id = sorted(s["jobId"] for s in summary)[1]
        state["job_id"] = job_id

        detail = _get_ok(f"{webui_url}/jobs/{job_id}")["data"]["detail"]
        assert detail["jobInfo"]["jobId"] == job_id

        detail = _get_ok(f"{webui_url}/jobs/{job_id}")["data"]["detail"]
        assert detail["jobInfo"]["isDead"] is False
        job_actors = detail["jobActors"]
        job_workers = detail["jobWorkers"]
        assert len(job_actors) > 0
        assert len(job_workers) > 0

    wait_until_succeeded_without_exception(
        _check_running,
        exceptions=(AssertionError, KeyError, IndexError),
        timeout_ms=30 * 1000,
        raise_last_ex=True)
|
||||
|
||||
|
||||
def test_get_job_info(disable_aiohttp_cache, ray_start_with_dashboard):
    """Validate the job summary and job detail endpoints for the driver job.

    One actor is started from this process, then the dashboard is polled
    every five seconds until ``_check`` passes. If ``_check`` is still
    failing once 30 seconds have elapsed, the test fails with the traceback
    of the most recent failure.
    """

    @ray.remote
    class Actor:
        def getpid(self):
            return os.getpid()

    actor = Actor.remote()
    actor_pid = ray.get(actor.getpid.remote())
    actor_id = actor._actor_id.hex()

    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)

    ip = ray.util.get_node_ip_address()

    def _check():
        # Summary view: exactly one job, started by this process.
        resp = requests.get(f"{webui_url}/jobs?view=summary")
        resp.raise_for_status()
        result = resp.json()
        assert result["result"] is True, resp.text
        job_summary = result["data"]["summary"]
        assert len(job_summary) == 1, resp.text
        one_job = job_summary[0]
        assert "jobId" in one_job
        job_id = one_job["jobId"]
        assert ray._raylet.JobID(hex_to_binary(one_job["jobId"]))
        assert "driverIpAddress" in one_job
        assert one_job["driverIpAddress"] == ip
        assert "driverPid" in one_job
        assert one_job["driverPid"] == str(os.getpid())
        assert "config" in one_job
        assert isinstance(one_job["config"], dict)
        assert "isDead" in one_job
        assert one_job["isDead"] is False
        assert "timestamp" in one_job
        one_job_summary_keys = one_job.keys()

        # Detail view: must carry every summary key plus actors and workers.
        resp = requests.get(f"{webui_url}/jobs/{job_id}")
        resp.raise_for_status()
        result = resp.json()
        assert result["result"] is True, resp.text
        job_detail = result["data"]["detail"]
        assert "jobInfo" in job_detail
        assert len(one_job_summary_keys - job_detail["jobInfo"].keys()) == 0
        assert "jobActors" in job_detail
        job_actors = job_detail["jobActors"]
        assert len(job_actors) == 1, resp.text
        one_job_actor = job_actors[actor_id]
        assert "taskSpec" in one_job_actor
        assert isinstance(one_job_actor["taskSpec"], dict)
        assert "functionDescriptor" in one_job_actor["taskSpec"]
        assert isinstance(one_job_actor["taskSpec"]["functionDescriptor"],
                          dict)
        assert "pid" in one_job_actor
        assert one_job_actor["pid"] == actor_pid
        check_actor_keys = [
            "name", "timestamp", "address", "actorId", "jobId", "state"
        ]
        for k in check_actor_keys:
            assert k in one_job_actor
        assert "jobWorkers" in job_detail
        job_workers = job_detail["jobWorkers"]
        assert len(job_workers) == 1, resp.text
        one_job_worker = job_workers[0]
        check_worker_keys = [
            "cmdline", "pid", "cpuTimes", "memoryInfo", "cpuPercent",
            "coreWorkerStats", "language", "jobId"
        ]
        for k in check_worker_keys:
            assert k in one_job_worker

    timeout_seconds = 30
    start_time = time.time()
    while True:
        time.sleep(5)
        try:
            _check()
            break
        except (AssertionError, KeyError, IndexError) as ex:
            # The deadline is only evaluated after a failed attempt. Checking
            # it in a ``finally`` clause would also run after a successful
            # ``break`` and turn a late success into a spurious timeout.
            if time.time() > start_time + timeout_seconds:
                ex_stack = "".join(
                    traceback.format_exception(type(ex), ex,
                                               ex.__traceback__))
                raise Exception(
                    f"Timed out while testing, {ex_stack}") from ex
|
||||
|
||||
|
||||
def test_submit_job_validation(ray_start_with_dashboard):
    """Check that malformed job submissions are rejected with useful messages.

    Each case POSTs a payload to ``/jobs`` and expects ``result`` to be
    ``False`` with specific fragments present in ``msg``.
    """
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    webui_url = format_web_url(ray_start_with_dashboard["webui_url"])

    session_parent = os.path.dirname(ray_start_with_dashboard["session_dir"])
    shutil.rmtree(os.path.join(session_parent, "job"), ignore_errors=True)

    def _ensure_available_nodes():
        # An empty POST is always rejected; what matters is that the
        # rejection is no longer the "no nodes available" one.
        resp = requests.post(f"{webui_url}/jobs")
        resp.raise_for_status()
        result = resp.json()
        assert result["result"] is False
        return "no nodes available" not in result["msg"]

    wait_for_condition(_ensure_available_nodes, timeout=5)

    def _assert_rejected(payload, fragments):
        # POST ``payload`` and require a failure whose message mentions
        # every string in ``fragments``.
        resp = requests.post(f"{webui_url}/jobs", json=payload)
        resp.raise_for_status()
        result = resp.json()
        assert result["result"] is False
        msg = result["msg"]
        assert all(p in msg for p in fragments), resp.text

    # Invalid value.
    _assert_rejected(
        {
            "language": "Unsupported",
            "runtime_env": {
                "working_dir": "http://xxx/yyy.zip"
            },
            "driver_entry": "python_file_name_without_ext",
        }, ["language", "Unsupported"])

    # Missing required field.
    _assert_rejected(
        {
            "language": job_consts.PYTHON,
            "runtime_env": {
                "working_dir": "http://xxx/yyy.zip"
            },
        }, ["missing", "driver_entry"])

    # Incorrect value type.
    _assert_rejected(
        {
            "language": job_consts.PYTHON,
            "runtime_env": {
                "working_dir": ["http://xxx/yyy.zip"]
            },
            "driver_entry": "python_file_name_without_ext",
        }, ["working_dir", "str"])

    # Invalid key.
    _assert_rejected(
        {
            "language": job_consts.PYTHON,
            "runtime_env": {
                "working_dir": "http://xxx/yyy.zip"
            },
            "driver_entry": "python_file_name_without_ext",
            "invalid_key": 1,
        }, ["unexpected", "invalid_key"])
|
||||
|
||||
|
||||
# Running this file directly hands it to pytest in verbose mode and uses
# pytest's return value as the process exit status.
if __name__ == "__main__":
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|
|
@ -231,7 +231,8 @@ class CustomEncoder(json.JSONEncoder):
|
|||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def rest_response(success, message, **kwargs) -> aiohttp.web.Response:
|
||||
def rest_response(success, message, convert_google_style=True,
|
||||
**kwargs) -> aiohttp.web.Response:
|
||||
# In the dev context we allow a dev server running on a
|
||||
# different port to consume the API, meaning we need to allow
|
||||
# cross-origin access
|
||||
|
@ -243,7 +244,7 @@ def rest_response(success, message, **kwargs) -> aiohttp.web.Response:
|
|||
{
|
||||
"result": success,
|
||||
"msg": message,
|
||||
"data": to_google_style(kwargs)
|
||||
"data": to_google_style(kwargs) if convert_google_style else kwargs
|
||||
},
|
||||
dumps=functools.partial(json.dumps, cls=CustomEncoder),
|
||||
headers=headers)
|
||||
|
|
|
@ -2,7 +2,6 @@ import subprocess
|
|||
import pickle
|
||||
import os
|
||||
from typing import Any, Dict, Tuple, Optional
|
||||
from enum import Enum
|
||||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
|
@ -13,14 +12,7 @@ from ray.experimental.internal_kv import (
|
|||
_internal_kv_get,
|
||||
_internal_kv_put,
|
||||
)
|
||||
|
||||
|
||||
class JobStatus(Enum):
|
||||
PENDING = "PENDING"
|
||||
RUNNING = "RUNNING"
|
||||
STOPPED = "STOPPED"
|
||||
SUCCEEDED = "SUCCEEDED"
|
||||
FAILED = "FAILED"
|
||||
from ray.dashboard.modules.job.data_types import JobStatus
|
||||
|
||||
|
||||
class JobLogStorageClient:
|
||||
|
@ -172,6 +164,7 @@ class JobManager:
|
|||
self._status_client = JobStatusStorageClient()
|
||||
self._log_client = JobLogStorageClient()
|
||||
self._supervisor_actor_cls = ray.remote(JobSupervisor)
|
||||
|
||||
assert _internal_kv_initialized()
|
||||
|
||||
def _get_actor_for_job(self, job_id: str) -> Optional[ActorHandle]:
|
||||
|
@ -217,19 +210,20 @@ class JobManager:
|
|||
|
||||
def stop_job(self, job_id) -> bool:
|
||||
"""Request job to exit."""
|
||||
a = self._get_actor_for_job(job_id)
|
||||
if a is not None:
|
||||
job_supervisor_actor = self._get_actor_for_job(job_id)
|
||||
if job_supervisor_actor is not None:
|
||||
# Actor is still alive, signal it to stop the driver.
|
||||
a.stop.remote()
|
||||
job_supervisor_actor.stop.remote()
|
||||
|
||||
def get_job_status(self, job_id: str):
|
||||
a = self._get_actor_for_job(job_id)
|
||||
job_supervisor_actor = self._get_actor_for_job(job_id)
|
||||
# Actor is still alive, try to get status from it.
|
||||
try:
|
||||
return ray.get(a.get_status.remote())
|
||||
except RayActorError:
|
||||
# Actor exited, so we should fall back to internal_kv.
|
||||
pass
|
||||
if job_supervisor_actor is not None:
|
||||
try:
|
||||
return ray.get(job_supervisor_actor.get_status.remote())
|
||||
except RayActorError:
|
||||
# Actor exited, so we should fall back to internal_kv.
|
||||
pass
|
||||
|
||||
# Fall back to storage if the actor is dead.
|
||||
return self._status_client.get_status(job_id)
|
||||
|
|
Loading…
Add table
Reference in a new issue