mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
550 lines
23 KiB
Python
550 lines
23 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
import traceback
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Callable, Dict, List, Set, Tuple
|
|
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
|
|
|
|
import ray.dashboard.consts as dashboard_consts
|
|
import ray.dashboard.modules.runtime_env.runtime_env_consts as runtime_env_consts
|
|
import ray.dashboard.utils as dashboard_utils
|
|
from ray._private.async_compat import create_task
|
|
from ray._private.ray_logging import setup_component_logger
|
|
from ray._private.runtime_env.conda import CondaPlugin
|
|
from ray._private.runtime_env.container import ContainerManager
|
|
from ray._private.runtime_env.context import RuntimeEnvContext
|
|
from ray._private.runtime_env.java_jars import JavaJarsPlugin
|
|
from ray._private.runtime_env.pip import PipPlugin
|
|
from ray._private.runtime_env.plugin import (
|
|
RuntimeEnvPlugin,
|
|
create_for_plugin_if_needed,
|
|
)
|
|
from ray._private.runtime_env.plugin import RuntimeEnvPluginManager
|
|
from ray._private.runtime_env.py_modules import PyModulesPlugin
|
|
from ray._private.runtime_env.working_dir import WorkingDirPlugin
|
|
from ray.core.generated import (
|
|
agent_manager_pb2,
|
|
runtime_env_agent_pb2,
|
|
runtime_env_agent_pb2_grpc,
|
|
)
|
|
from ray.core.generated.runtime_env_common_pb2 import (
|
|
RuntimeEnvState as ProtoRuntimeEnvState,
|
|
)
|
|
from ray.runtime_env import RuntimeEnv, RuntimeEnvConfig
|
|
|
|
# Module-level logger, named after this module.
default_logger = logging.getLogger(__name__)

# TODO(edoakes): this is used for unit tests. We should replace it with a
# better pluggability mechanism once available.
# Holds the raw string value of the environment variable, or None when unset.
SLEEP_FOR_TESTING_S = os.getenv("RAY_RUNTIME_ENV_SLEEP_FOR_TESTING_S")
|
|
|
|
|
|
@dataclass
class CreatedEnvResult:
    """Cached outcome of one runtime env creation.

    Attributes:
        success: Whether or not the env was installed correctly.
        result: A serialized RuntimeEnvContext when ``success`` is True,
            otherwise an error message.
        creation_time_ms: The time to create a runtime env in ms.
    """

    success: bool
    result: str
    creation_time_ms: int
|
|
|
|
|
|
# Type tag attached to every URI, e.g., "working_dir".
UriType = str


class ReferenceTable:
    """
    The URI reference table which is used for GC.

    One reference count is kept per serialized runtime env and one per URI
    parsed from those runtime envs. When a reference count is decreased to
    zero, the entry is removed from this table and the matching "unused"
    callback is invoked.
    """

    def __init__(
        self,
        uris_parser: Callable[[RuntimeEnv], List[Tuple[str, UriType]]],
        unused_uris_callback: Callable[[List[Tuple[str, UriType]]], None],
        unused_runtime_env_callback: Callable[[str], None],
    ):
        """
        Args:
            uris_parser: Returns the (uri, uri_type) pairs of a runtime env.
            unused_uris_callback: Called with the (uri, uri_type) pairs whose
                reference count just dropped to zero.
            unused_runtime_env_callback: Called with the serialized runtime
                env whose reference count just dropped to zero.
        """
        # Runtime Environment reference table. The key is serialized runtime env and
        # the value is reference count.
        self._runtime_env_reference: Dict[str, int] = defaultdict(int)
        # URI reference table. The key is URI parsed from runtime env and the value
        # is reference count.
        self._uri_reference: Dict[str, int] = defaultdict(int)
        self._uris_parser = uris_parser
        self._unused_uris_callback = unused_uris_callback
        self._unused_runtime_env_callback = unused_runtime_env_callback
        # Requests coming from these source processes never change any
        # reference count (see `increase_reference` / `decrease_reference`).
        # NOTE(review): presumably `client_server` is listed because it does not
        # send the `DeleteRuntimeEnvIfPossible` RPC when the client exits - confirm.
        # The URI won't be leaked now because the reference count will be reset to
        # zero when the job finished.
        self._reference_exclude_sources: Set[str] = {
            "client_server",
        }

    def _increase_reference_for_uris(self, uris):
        """Add one reference to every URI in `uris` ((uri, uri_type) pairs)."""
        default_logger.debug("Increase reference for uris %s.", uris)
        for uri, _ in uris:
            self._uri_reference[uri] += 1

    def _decrease_reference_for_uris(self, uris):
        """Drop one reference from every URI in `uris`.

        URIs whose count reaches zero are removed from the table and reported
        through `unused_uris_callback` in a single call.

        Returns:
            The list of (uri, uri_type) pairs that became unused.
        """
        default_logger.debug("Decrease reference for uris %s.", uris)
        unused_uris = []
        for uri, uri_type in uris:
            # Read with .get(): indexing the defaultdict would silently insert
            # a zero-count entry for a URI that was never referenced.
            remaining = self._uri_reference.get(uri, 0)
            if remaining <= 0:
                default_logger.warning(f"URI {uri} does not exist.")
                continue
            remaining -= 1
            if remaining == 0:
                unused_uris.append((uri, uri_type))
                del self._uri_reference[uri]
            else:
                self._uri_reference[uri] = remaining
        if unused_uris:
            default_logger.info(f"Unused uris {unused_uris}.")
            self._unused_uris_callback(unused_uris)
        return unused_uris

    def _increase_reference_for_runtime_env(self, serialized_env: str):
        """Add one reference to the serialized runtime env."""
        default_logger.debug(
            "Increase reference for runtime env %s.", serialized_env
        )
        self._runtime_env_reference[serialized_env] += 1

    def _decrease_reference_for_runtime_env(self, serialized_env: str):
        """Drop one reference from the serialized runtime env.

        When the count reaches zero the entry is removed and
        `unused_runtime_env_callback` is invoked.

        Returns:
            True if the runtime env became unused, False otherwise.
        """
        default_logger.debug(
            "Decrease reference for runtime env %s.", serialized_env
        )
        # Read with .get() so that an unknown runtime env is not inserted with
        # a zero count, which would otherwise surface in `runtime_env_refs`.
        remaining = self._runtime_env_reference.get(serialized_env, 0)
        if remaining <= 0:
            default_logger.warning(f"Runtime env {serialized_env} does not exist.")
            return False
        remaining -= 1
        if remaining > 0:
            self._runtime_env_reference[serialized_env] = remaining
            return False
        del self._runtime_env_reference[serialized_env]
        default_logger.info(f"Unused runtime env {serialized_env}.")
        self._unused_runtime_env_callback(serialized_env)
        return True

    def increase_reference(
        self, runtime_env: RuntimeEnv, serialized_env: str, source_process: str
    ) -> None:
        """Add a reference to the runtime env and to each of its URIs.

        Does nothing when `source_process` is an excluded source.
        """
        if source_process in self._reference_exclude_sources:
            return
        self._increase_reference_for_runtime_env(serialized_env)
        uris = self._uris_parser(runtime_env)
        self._increase_reference_for_uris(uris)

    def decrease_reference(
        self, runtime_env: RuntimeEnv, serialized_env: str, source_process: str
    ) -> None:
        """Drop a reference from the runtime env and from each of its URIs.

        Does nothing when `source_process` is an excluded source.
        """
        if source_process in self._reference_exclude_sources:
            # Matches the declared `-> None` and the non-excluded path.
            return
        self._decrease_reference_for_runtime_env(serialized_env)
        uris = self._uris_parser(runtime_env)
        self._decrease_reference_for_uris(uris)

    @property
    def runtime_env_refs(self) -> Dict[str, int]:
        """Return the runtime_env -> ref count mapping.

        Returns:
            The mapping of serialized runtime env -> ref count.
        """
        return self._runtime_env_reference
|
|
|
|
|
|
class RuntimeEnvAgent(
    dashboard_utils.DashboardAgentModule,
    runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer,
):
    """An RPC server to create and delete runtime envs.

    Attributes:
        dashboard_agent: The DashboardAgent object contains global config.
    """

    def __init__(self, dashboard_agent):
        super().__init__(dashboard_agent)
        self._runtime_env_dir = dashboard_agent.runtime_env_dir
        self._logging_params = dashboard_agent.logging_params
        self._per_job_logger_cache = dict()
        # Cache the results of creating envs to avoid repeatedly calling into
        # conda and other slow calls.
        self._env_cache: Dict[str, CreatedEnvResult] = dict()
        # Maps a serialized runtime env to a lock that is used
        # to prevent multiple concurrent installs of the same env.
        self._env_locks: Dict[str, asyncio.Lock] = dict()
        self._gcs_aio_client = self._dashboard_agent.gcs_aio_client

        self._pip_plugin = PipPlugin(self._runtime_env_dir)
        self._conda_plugin = CondaPlugin(self._runtime_env_dir)
        self._py_modules_plugin = PyModulesPlugin(
            self._runtime_env_dir, self._gcs_aio_client
        )
        self._java_jars_plugin = JavaJarsPlugin(
            self._runtime_env_dir, self._gcs_aio_client
        )
        self._working_dir_plugin = WorkingDirPlugin(
            self._runtime_env_dir, self._gcs_aio_client
        )
        self._container_manager = ContainerManager(dashboard_agent.temp_dir)

        # TODO(architkulkarni): "base plugins" and third-party plugins should all go
        # through the same code path. We should never need to refer to
        # self._xxx_plugin, we should just iterate through self._plugins.
        self._base_plugins: List[RuntimeEnvPlugin] = [
            self._working_dir_plugin,
            self._pip_plugin,
            self._conda_plugin,
            self._py_modules_plugin,
            self._java_jars_plugin,
        ]
        self._plugin_manager = RuntimeEnvPluginManager()
        for plugin in self._base_plugins:
            self._plugin_manager.add_plugin(plugin)

        self._reference_table = ReferenceTable(
            self.uris_parser,
            self.unused_uris_processor,
            self.unused_runtime_env_processor,
        )

        self._logger = default_logger

    def uris_parser(self, runtime_env):
        """Collect the URIs of `runtime_env` from every registered plugin.

        Returns:
            A list of (uri, uri_type) pairs, where uri_type is the name the
            plugin is registered under in the plugin manager.
        """
        result = list()
        for name, plugin_setup_context in self._plugin_manager.plugins.items():
            plugin = plugin_setup_context.class_instance
            uris = plugin.get_uris(runtime_env)
            for uri in uris:
                result.append((uri, UriType(name)))
        return result

    def unused_uris_processor(self, unused_uris: List[Tuple[str, UriType]]) -> None:
        """Mark each unreferenced URI as unused in its plugin's URI cache."""
        for uri, uri_type in unused_uris:
            self._plugin_manager.plugins[str(uri_type)].uri_cache.mark_unused(uri)

    def unused_runtime_env_processor(self, unused_runtime_env: str) -> None:
        """Evict an unreferenced runtime env from the creation-result cache.

        A successful result is evicted immediately. A failed result is kept
        for `BAD_RUNTIME_ENV_CACHE_TTL_SECONDS` more seconds before eviction.
        """

        def delete_runtime_env():
            # pop() instead of `del`: several delayed deletions can be pending
            # for the same failed env, and only the first one finds the entry.
            if self._env_cache.pop(unused_runtime_env, None) is not None:
                self._logger.info("Runtime env %s deleted.", unused_runtime_env)

        if unused_runtime_env in self._env_cache:
            if not self._env_cache[unused_runtime_env].success:
                loop = asyncio.get_event_loop()
                # Cache the bad runtime env result by ttl seconds.
                loop.call_later(
                    dashboard_consts.BAD_RUNTIME_ENV_CACHE_TTL_SECONDS,
                    delete_runtime_env,
                )
            else:
                delete_runtime_env()

    def get_or_create_logger(self, job_id: bytes):
        """Return the logger dedicated to `job_id`, creating it on first use.

        The logger is built from a copy of the agent's logging params with a
        per-job file name and logger name, and is cached per decoded job id.
        """
        job_id = job_id.decode()
        if job_id not in self._per_job_logger_cache:
            params = self._logging_params.copy()
            params["filename"] = f"runtime_env_setup-{job_id}.log"
            params["logger_name"] = f"runtime_env_{job_id}"
            per_job_logger = setup_component_logger(**params)
            self._per_job_logger_cache[job_id] = per_job_logger
        return self._per_job_logger_cache[job_id]

    async def GetOrCreateRuntimeEnv(self, request, context):
        """Increase the reference of a runtime env and create it if needed.

        The result of the first creation (success or failure) is cached per
        serialized runtime env and replayed for later requests. Whenever the
        reply is a failure, the reference taken by this request is released.

        Returns:
            A `GetOrCreateRuntimeEnvReply` carrying either the serialized
            runtime env context or an error message.
        """
        self._logger.debug(
            f"Got request from {request.source_process} to increase "
            "reference for runtime env: "
            f"{request.serialized_runtime_env}."
        )

        async def _setup_runtime_env(
            runtime_env: RuntimeEnv,
            serialized_runtime_env,
            serialized_allocated_resource_instances,
        ):
            """Run the container manager and every plugin for `runtime_env`."""
            allocated_resource: dict = json.loads(
                serialized_allocated_resource_instances or "{}"
            )
            # Use a separate logger for each job.
            per_job_logger = self.get_or_create_logger(request.job_id)
            # TODO(chenk008): Add log about allocated_resource to
            # avoid lint error. That will be moved to cgroup plugin.
            per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
            # Named `env_context` so it does not shadow the gRPC `context`.
            env_context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
            await self._container_manager.setup(
                runtime_env, env_context, logger=per_job_logger
            )

            # Warn about unrecognized fields in the runtime env.
            for name, _ in runtime_env.plugins():
                if name not in self._plugin_manager.plugins:
                    per_job_logger.warning(
                        f"runtime_env field {name} is not recognized by "
                        "Ray and will be ignored. In the future, unrecognized "
                        "fields in the runtime_env will raise an exception."
                    )

            # Run setup for each plugin unless it has already been cached.
            for (
                plugin_setup_context
            ) in self._plugin_manager.sorted_plugin_setup_contexts():
                plugin = plugin_setup_context.class_instance
                uri_cache = plugin_setup_context.uri_cache
                await create_for_plugin_if_needed(
                    runtime_env, plugin, uri_cache, env_context, per_job_logger
                )

            return env_context

        async def _create_runtime_env_with_retry(
            runtime_env,
            serialized_runtime_env,
            serialized_allocated_resource_instances,
            setup_timeout_seconds,
        ) -> Tuple[bool, str, str]:
            """
            Create runtime env with retry times. This function won't raise exceptions.

            Args:
                runtime_env: The instance of RuntimeEnv class.
                serialized_runtime_env: The serialized runtime env.
                serialized_allocated_resource_instances: The serialized allocated
                    resource instances.
                setup_timeout_seconds: The timeout of runtime environment creation.

            Returns:
                a tuple which contains result (bool), runtime env context (str), error
                message(str). The context is None on failure and the error
                message is None on success.

            """
            self._logger.info(
                f"Creating runtime env: {serialized_runtime_env} with timeout "
                f"{setup_timeout_seconds} seconds."
            )
            serialized_context = None
            error_message = None
            retry_times = runtime_env_consts.RUNTIME_ENV_RETRY_TIMES
            for attempt in range(retry_times):
                try:
                    # python 3.6 requires the type of input is `Future`,
                    # python 3.7+ only requires the type of input is `Awaitable`
                    # TODO(Catch-Bull): remove create_task when ray drop python 3.6
                    runtime_env_setup_task = create_task(
                        _setup_runtime_env(
                            runtime_env,
                            serialized_runtime_env,
                            serialized_allocated_resource_instances,
                        )
                    )
                    runtime_env_context = await asyncio.wait_for(
                        runtime_env_setup_task, timeout=setup_timeout_seconds
                    )
                    serialized_context = runtime_env_context.serialize()
                    error_message = None
                    break
                except Exception as e:
                    err_msg = f"Failed to create runtime env {serialized_runtime_env}."
                    self._logger.exception(err_msg)
                    error_message = "".join(
                        traceback.format_exception(type(e), e, e.__traceback__)
                    )
                    if isinstance(e, asyncio.TimeoutError):
                        hint = (
                            "Failed due to timeout; check runtime_env setup logs"
                            " and consider increasing `setup_timeout_seconds` beyond "
                            f"the default of {DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS}."
                            "For example: \n"
                            ' runtime_env={"config": {"setup_timeout_seconds":'
                            " 1800}, ...}\n"
                        )
                        error_message = hint + error_message
                    # Only wait when another attempt follows; sleeping after
                    # the last failure would just delay the error reply.
                    if attempt < retry_times - 1:
                        await asyncio.sleep(
                            runtime_env_consts.RUNTIME_ENV_RETRY_INTERVAL_MS / 1000
                        )
            if error_message:
                self._logger.error(
                    "Runtime env creation failed for %d times, "
                    "don't retry any more.",
                    runtime_env_consts.RUNTIME_ENV_RETRY_TIMES,
                )
                return False, None, error_message
            else:
                self._logger.info(
                    "Successfully created runtime env: %s, the context: %s",
                    serialized_runtime_env,
                    serialized_context,
                )
                return True, serialized_context, None

        # Bound outside the `try` so the error path below can always log it.
        serialized_env = request.serialized_runtime_env
        try:
            runtime_env = RuntimeEnv.deserialize(serialized_env)
        except Exception as e:
            self._logger.exception(
                "[Increase] Failed to parse runtime env: " f"{serialized_env}"
            )
            return runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message="".join(
                    traceback.format_exception(type(e), e, e.__traceback__)
                ),
            )

        # Increase reference
        self._reference_table.increase_reference(
            runtime_env, serialized_env, request.source_process
        )

        if serialized_env not in self._env_locks:
            # async lock to prevent the same env being concurrently installed
            self._env_locks[serialized_env] = asyncio.Lock()

        async with self._env_locks[serialized_env]:
            if serialized_env in self._env_cache:
                result = self._env_cache[serialized_env]
                if result.success:
                    context = result.result
                    self._logger.info(
                        "Runtime env already created "
                        f"successfully. Env: {serialized_env}, "
                        f"context: {context}"
                    )
                    return runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply(
                        status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
                        serialized_runtime_env_context=context,
                    )
                else:
                    error_message = result.result
                    self._logger.info(
                        "Runtime env already failed. "
                        f"Env: {serialized_env}, "
                        f"err: {error_message}"
                    )
                    # Recover the reference.
                    self._reference_table.decrease_reference(
                        runtime_env, serialized_env, request.source_process
                    )
                    return runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply(
                        status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                        error_message=error_message,
                    )

            if SLEEP_FOR_TESTING_S:
                self._logger.info(f"Sleeping for {SLEEP_FOR_TESTING_S}s.")
                # NOTE(review): this is a blocking sleep inside a coroutine, so
                # the event loop is stalled for the whole duration. It is only
                # taken when the testing env var is set.
                time.sleep(int(SLEEP_FOR_TESTING_S))

            runtime_env_config = RuntimeEnvConfig.from_proto(request.runtime_env_config)
            # according to the document of `asyncio.wait_for`,
            # None means disable timeout logic
            setup_timeout_seconds = (
                None
                if runtime_env_config["setup_timeout_seconds"] == -1
                else runtime_env_config["setup_timeout_seconds"]
            )

            start = time.perf_counter()
            (
                successful,
                serialized_context,
                error_message,
            ) = await _create_runtime_env_with_retry(
                runtime_env,
                serialized_env,
                request.serialized_allocated_resource_instances,
                setup_timeout_seconds,
            )
            creation_time_ms = int(round((time.perf_counter() - start) * 1000, 0))
            if not successful:
                # Recover the reference.
                self._reference_table.decrease_reference(
                    runtime_env, serialized_env, request.source_process
                )
            # Add the result to env cache.
            self._env_cache[serialized_env] = CreatedEnvResult(
                successful,
                serialized_context if successful else error_message,
                creation_time_ms,
            )
            # Reply the RPC
            return runtime_env_agent_pb2.GetOrCreateRuntimeEnvReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_OK
                if successful
                else agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                serialized_runtime_env_context=serialized_context,
                error_message=error_message,
            )

    async def DeleteRuntimeEnvIfPossible(self, request, context):
        """Decrease the reference of a runtime env.

        Returns:
            A `DeleteRuntimeEnvIfPossibleReply`; its status is FAILED when the
            serialized runtime env cannot be parsed, OK otherwise.
        """
        self._logger.info(
            f"Got request from {request.source_process} to decrease "
            "reference for runtime env: "
            f"{request.serialized_runtime_env}."
        )

        try:
            runtime_env = RuntimeEnv.deserialize(request.serialized_runtime_env)
        except Exception as e:
            self._logger.exception(
                "[Decrease] Failed to parse runtime env: "
                f"{request.serialized_runtime_env}"
            )
            # Reply with this RPC's own reply type, not GetOrCreateRuntimeEnvReply.
            return runtime_env_agent_pb2.DeleteRuntimeEnvIfPossibleReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                error_message="".join(
                    traceback.format_exception(type(e), e, e.__traceback__)
                ),
            )

        self._reference_table.decrease_reference(
            runtime_env, request.serialized_runtime_env, request.source_process
        )

        return runtime_env_agent_pb2.DeleteRuntimeEnvIfPossibleReply(
            status=agent_manager_pb2.AGENT_RPC_STATUS_OK
        )

    async def GetRuntimeEnvsInfo(self, request, context):
        """Return the runtime env information of the node."""
        # TODO(sang): Currently, it only includes runtime_env information.
        # We should include the URI information which includes,
        # URIs
        # Caller
        # Ref counts
        # Cache information
        # Metrics (creation time & success)
        # Deleted URIs
        # -1 means "no limit".
        limit = request.limit if request.HasField("limit") else -1
        runtime_env_states = defaultdict(ProtoRuntimeEnvState)
        runtime_env_refs = self._reference_table.runtime_env_refs
        # Merge the reference counts and the cached creation results into one
        # state per serialized runtime env.
        for runtime_env, ref_cnt in runtime_env_refs.items():
            runtime_env_states[runtime_env].runtime_env = runtime_env
            runtime_env_states[runtime_env].ref_cnt = ref_cnt
        for runtime_env, result in self._env_cache.items():
            runtime_env_states[runtime_env].runtime_env = runtime_env
            runtime_env_states[runtime_env].success = result.success
            if not result.success:
                runtime_env_states[runtime_env].error = result.result
            runtime_env_states[runtime_env].creation_time_ms = result.creation_time_ms

        reply = runtime_env_agent_pb2.GetRuntimeEnvsInfoReply()
        count = 0
        for runtime_env_state in runtime_env_states.values():
            if limit != -1 and count >= limit:
                break
            count += 1
            reply.runtime_env_states.append(runtime_env_state)
        # `total` counts every known state, not only the ones returned.
        reply.total = len(runtime_env_states)
        return reply

    async def run(self, server):
        """Register this servicer on `server` when a server is provided."""
        if server:
            runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(
                self, server
            )

    @staticmethod
    def is_minimal_module():
        return True
|