mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
315 lines
12 KiB
Python
315 lines
12 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import os.path
|
|
import itertools
|
|
import subprocess
|
|
import sys
|
|
import secrets
|
|
import uuid
|
|
import traceback
|
|
from abc import abstractmethod
|
|
from typing import Union, Any
|
|
|
|
import ray.new_dashboard.utils as dashboard_utils
|
|
from ray.new_dashboard.utils import create_task
|
|
from ray.new_dashboard.modules.job import job_consts
|
|
from ray.new_dashboard.modules.job.job_description import JobDescription
|
|
from ray.core.generated import job_agent_pb2
|
|
from ray.core.generated import job_agent_pb2_grpc
|
|
from ray.core.generated import agent_manager_pb2
|
|
|
|
# Module-level logger named after this module's dotted import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class JobInfo(JobDescription):
    """Job description extended with per-job runtime state.

    The extra fields below are declared as class annotations and populated
    through the ``JobDescription`` base initializer (presumably a
    pydantic-style model -- TODO confirm against job_description.py).
    """
    # TODO(fyrestone): We should use job id instead of unique id.
    unique_id: str
    # The temp directory.
    temp_dir: str
    # The log directory.
    log_dir: str
    # The driver process instance (None when no driver is set).
    driver: Union[None, asyncio.subprocess.Process]

    def __init__(self, **data: Any):
        super().__init__(**data)
        # Support json values for env: the env is later merged into a child
        # process environment, which only accepts strings, so every non-str
        # value is JSON-encoded while str values are kept as-is.
        encoded_env = {}
        for name, value in self.env.items():
            if not isinstance(value, str):
                value = json.dumps(value)
            encoded_env[name] = value
        self.env = encoded_env
|
|
|
|
|
|
class JobProcessor:
    """Wraps the job info and provides common utils to download packages,
    start drivers, etc.

    Args:
        job_info (JobInfo): The job info.
    """
    # One counter shared by every processor instance (class attribute), so
    # the indices tagged onto download/command log lines are unique within
    # this process and can be used to correlate "start" and "finish" lines.
    _cmd_index_gen = itertools.count(1)

    def __init__(self, job_info):
        assert isinstance(job_info, JobInfo)
        self._job_info = job_info

    async def _download_package(self, http_session, url, filename):
        """Stream the content at ``url`` into the local file ``filename``.

        Args:
            http_session: HTTP client session used for the request (an
                ``aiohttp.ClientSession`` when called from DownloadPackage).
            url (str): The package URL.
            filename (str): Destination path; truncated if it already exists.

        Raises:
            aiohttp.ClientResponseError: if the server answers with an HTTP
                error status (4xx/5xx).
        """
        unique_id = self._job_info.unique_id
        cmd_index = next(self._cmd_index_gen)
        logger.info("[%s] Start download[%s] %s to %s", unique_id, cmd_index,
                    url, filename)
        # NOTE(review): ssl=False disables TLS certificate verification for
        # the package download; confirm this is intended.
        async with http_session.get(url, ssl=False) as response:
            # Fail fast on HTTP errors. Otherwise the server's error page
            # would be saved as the "package" and only fail later, with an
            # obscure message, when it is unpacked.
            response.raise_for_status()
            with open(filename, "wb") as f:
                while True:
                    chunk = await response.content.read(
                        job_consts.DOWNLOAD_BUFFER_SIZE)
                    if not chunk:
                        break
                    f.write(chunk)
        logger.info("[%s] Finished download[%s] %s to %s", unique_id,
                    cmd_index, url, filename)

    async def _unpack_package(self, filename, path):
        """Extract the archive ``filename`` into the directory ``path``.

        The extraction runs ``shutil.unpack_archive`` in a child Python
        process (presumably to keep the blocking work off the event loop).

        Raises:
            subprocess.CalledProcessError: if the extraction process fails.
        """
        # repr() turns both paths into valid Python string literals.
        code = f"import shutil; " \
               f"shutil.unpack_archive({repr(filename)}, {repr(path)})"
        unzip_cmd = [self._get_current_python(), "-c", code]
        await self._check_output_cmd(unzip_cmd)

    async def _check_output_cmd(self, cmd):
        """Run ``cmd`` to completion and return its stdout as text.

        Args:
            cmd (list): Program and arguments, executed without a shell.

        Returns:
            str: The command's stdout decoded as UTF-8.

        Raises:
            subprocess.CalledProcessError: if the command exits non-zero;
                ``output`` and ``stderr`` carry the decoded streams.
        """
        unique_id = self._job_info.unique_id
        cmd_index = next(self._cmd_index_gen)
        # Log before spawning so a command that fails to start (e.g. a
        # missing executable) still shows up in the log.
        logger.info("[%s] Run cmd[%s] %s", unique_id, cmd_index, repr(cmd))
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE)
        proc.cmd_index = cmd_index
        stdout, stderr = await proc.communicate()
        # errors="replace": non-UTF-8 output must not raise
        # UnicodeDecodeError and hide the command's real result/exit code.
        stdout = stdout.decode("utf-8", errors="replace")
        logger.info("[%s] Output of cmd[%s]: %s", unique_id, cmd_index, stdout)
        if proc.returncode != 0:
            stderr = stderr.decode("utf-8", errors="replace")
            logger.error("[%s] Error of cmd[%s]: %s", unique_id, cmd_index,
                         stderr)
            raise subprocess.CalledProcessError(
                proc.returncode, cmd, output=stdout, stderr=stderr)
        return stdout

    async def _start_driver(self, cmd, stdout, stderr, env):
        """Start ``cmd`` as the job driver without waiting for it to finish.

        The process runs in the job's unpack directory. ``env`` is layered
        over this process's environment, so its entries win on conflicts.

        Returns:
            asyncio.subprocess.Process: The started driver process.
        """
        unique_id = self._job_info.unique_id
        job_package_dir = job_consts.JOB_UNPACK_DIR.format(
            temp_dir=self._job_info.temp_dir, unique_id=unique_id)
        cmd_str = subprocess.list2cmdline(cmd)
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=stdout,
            stderr=stderr,
            env={
                **os.environ,
                **env,
            },
            cwd=job_package_dir,
        )
        logger.info("[%s] Start driver cmd %s with pid %s", unique_id,
                    repr(cmd_str), proc.pid)
        return proc

    @staticmethod
    def _get_current_python():
        """Return the interpreter running this process (``sys.executable``)."""
        return sys.executable

    @staticmethod
    def _new_log_files(log_dir, filename):
        """Open ``<log_dir>/<filename>.out`` and ``.err`` for appending.

        Both files are opened in line-buffered text mode. The caller owns the
        returned file objects and is responsible for closing them.

        Returns:
            tuple: ``(stdout_file, stderr_file)``, or ``(None, None)`` when
            ``log_dir`` is None.
        """
        if log_dir is None:
            return None, None
        stdout = open(
            os.path.join(log_dir, filename + ".out"), "a", buffering=1)
        stderr = open(
            os.path.join(log_dir, filename + ".err"), "a", buffering=1)
        return stdout, stderr

    @abstractmethod
    async def run(self):
        """Execute this processor's step; implemented by subclasses.

        NOTE: JobProcessor does not use ``abc.ABCMeta``, so
        ``@abstractmethod`` is not enforced here -- it only marks intent.
        """
        pass
|
|
|
|
|
|
class DownloadPackage(JobProcessor):
    """Fetch a job's package and extract it into the job's unpack directory.

    The package URL is taken from ``job_info.runtime_env.working_dir``.

    Args:
        job_info (JobInfo): The job info.
        http_session (aiohttp.ClientSession): The client session.
    """

    def __init__(self, job_info, http_session):
        super().__init__(job_info)
        self._http_session = http_session

    async def run(self):
        info = self._job_info
        # Both job_consts path templates are keyed by the same two values.
        path_args = dict(temp_dir=info.temp_dir, unique_id=info.unique_id)
        archive_file = job_consts.DOWNLOAD_PACKAGE_FILE.format(**path_args)
        target_dir = job_consts.JOB_UNPACK_DIR.format(**path_args)
        await self._download_package(self._http_session,
                                     info.runtime_env.working_dir,
                                     archive_file)
        await self._unpack_package(archive_file, target_dir)
|
|
|
|
|
|
class StartPythonDriver(JobProcessor):
    """ Start the driver for Python job.

    Args:
        job_info (JobInfo): The job info.
        redis_address (tuple): The (ip, port) of redis.
        redis_password (str): The password of redis.
    """

    # Source of the generated driver bootstrap script. _gen_driver_code
    # renders every placeholder value with repr(), except driver_entry,
    # which is spliced in verbatim (and therefore validated first).
    _template = """import sys
sys.path.append({import_path})
import ray
from ray._private.utils import hex_to_binary
ray.init(ignore_reinit_error=True,
address={redis_address},
_redis_password={redis_password},
job_config=ray.job_config.JobConfig({job_config_args}),
)
import {driver_entry}
{driver_entry}.main({driver_args})
# If the driver exits normally, we invoke Ray.shutdown() again
# here, in case the user code forgot to invoke it.
ray.shutdown()
"""

    def __init__(self, job_info, redis_address, redis_password):
        super().__init__(job_info)
        self._redis_address = redis_address
        self._redis_password = redis_password

    def _gen_driver_code(self):
        """Render the driver bootstrap script and write it to a new file.

        Returns:
            str: Path of the generated driver entry file. Its name includes
            a fresh uuid4.

        Raises:
            ValueError: if the job's ``driver_entry`` is not a dotted Python
                module path.
        """
        driver_entry = self._job_info.driver_entry
        # driver_entry is inserted verbatim into generated source
        # ("import {driver_entry}"), so anything other than a dotted
        # identifier could inject arbitrary code into the driver.
        module_parts = str(driver_entry).split(".")
        if not all(part.isidentifier() for part in module_parts):
            raise ValueError(
                f"Invalid driver entry {driver_entry!r}: expected a dotted "
                f"Python module path such as 'pkg.module'.")

        temp_dir = self._job_info.temp_dir
        unique_id = self._job_info.unique_id
        job_package_dir = job_consts.JOB_UNPACK_DIR.format(
            temp_dir=temp_dir, unique_id=unique_id)
        driver_entry_file = job_consts.JOB_DRIVER_ENTRY_FILE.format(
            temp_dir=temp_dir, unique_id=unique_id, uuid=uuid.uuid4())
        ip, port = self._redis_address

        # Per job config
        job_config_items = {
            "worker_env": self._job_info.env,
            "code_search_path": [job_package_dir],
        }

        job_config_args = ", ".join(f"{key}={repr(value)}"
                                    for key, value in job_config_items.items()
                                    if value is not None)
        driver_args = ", ".join([repr(x) for x in self._job_info.driver_args])
        driver_code = self._template.format(
            job_config_args=job_config_args,
            import_path=repr(job_package_dir),
            redis_address=repr(ip + ":" + str(port)),
            redis_password=repr(self._redis_password),
            driver_entry=driver_entry,
            driver_args=driver_args)
        # Python reads source files as UTF-8 by default; write the script in
        # that encoding regardless of the locale so non-ASCII paths or env
        # values embedded via repr() stay readable by the interpreter.
        with open(driver_entry_file, "w", encoding="utf-8") as fp:
            fp.write(driver_code)
        return driver_entry_file

    async def run(self):
        """Generate the driver script and start it as a child process.

        Returns:
            asyncio.subprocess.Process: The started driver process.
        """
        python = self._get_current_python()
        driver_file = self._gen_driver_code()
        # "-u": unbuffered stdout/stderr so driver output reaches the log
        # files promptly.
        driver_cmd = [python, "-u", driver_file]
        stdout_file, stderr_file = self._new_log_files(
            self._job_info.log_dir, f"driver-{self._job_info.unique_id}")
        try:
            return await self._start_driver(driver_cmd, stdout_file,
                                            stderr_file, self._job_info.env)
        finally:
            # The child receives its own copies of these descriptors when it
            # is spawned, so the parent's handles are no longer needed.
            # Closing them avoids leaking one file pair per started driver.
            for log_file in (stdout_file, stderr_file):
                if log_file is not None:
                    log_file.close()
|
|
|
|
|
|
class JobAgent(dashboard_utils.DashboardAgentModule,
               job_agent_pb2_grpc.JobAgentServiceServicer):
    """ The JobAgentService defined in job_agent.proto for initializing /
    cleaning job environments.
    """

    async def InitializeJobEnv(self, request, context):
        """Prepare a job environment and launch its driver.

        ``request.job_description`` (JSON) is parsed into a JobInfo. The job
        directory is then created, the job package is downloaded and
        unpacked, and the driver for the job's language is started.

        Returns:
            job_agent_pb2.InitializeJobEnvReply: ``AGENT_RPC_STATUS_OK`` with
            the driver pid on success. Otherwise ``AGENT_RPC_STATUS_FAILED``
            with an error message.
        """
        # TODO(fyrestone): Handle duplicated InitializeJobEnv requests
        # when initializing job environment.
        # TODO(fyrestone): Support reinitialize job environment.

        # TODO(fyrestone): Use job id instead of unique id.
        unique_id = secrets.token_hex(6)

        def failed(message):
            # The failure reply shared by every error path below.
            return job_agent_pb2.InitializeJobEnvReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message=message)

        # Parse the job description from the request.
        try:
            description = json.loads(request.job_description)
            job_info = JobInfo(
                unique_id=unique_id,
                temp_dir=self._dashboard_agent.temp_dir,
                log_dir=self._dashboard_agent.log_dir,
                **description)
        except json.JSONDecodeError as ex:
            reason = f"{ex}, job_payload:\n{request.job_description}"
            logger.error("[%s] Initialize job environment failed, %s.",
                         unique_id, reason)
            return failed(reason)
        except Exception as ex:
            logger.exception(ex)
            return failed(traceback.format_exc())

        async def prepare():
            os.makedirs(
                job_consts.JOB_DIR.format(
                    temp_dir=job_info.temp_dir, unique_id=unique_id),
                exist_ok=True)
            # Download the job package.
            downloader = DownloadPackage(job_info,
                                         self._dashboard_agent.http_session)
            await downloader.run()
            # Start the driver.
            logger.info("[%s] Starting driver.", unique_id)
            language = job_info.language
            if language != job_consts.PYTHON:
                raise Exception(f"Unsupported language type: {language}")
            starter = StartPythonDriver(
                job_info, self._dashboard_agent.redis_address,
                self._dashboard_agent.redis_password)
            job_info.driver = await starter.run()

        pending = create_task(prepare())

        try:
            await pending
        except asyncio.CancelledError:
            logger.error("[%s] Initialize job environment has been cancelled.",
                         unique_id)
            return failed("InitializeJobEnv has been cancelled, "
                          "did you call CleanJobEnv?")
        except Exception as ex:
            logger.exception(ex)
            return failed(traceback.format_exc())

        driver_pid = job_info.driver.pid if job_info.driver else 0

        logger.info(
            "[%s] Job environment initialized, "
            "the driver (pid=%s) started.", unique_id, driver_pid)
        return job_agent_pb2.InitializeJobEnvReply(
            status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
            driver_pid=driver_pid)

    async def run(self, server):
        """Register this servicer with the given gRPC ``server``."""
        job_agent_pb2_grpc.add_JobAgentServiceServicer_to_server(self, server)
|