ray/dashboard/modules/runtime_env/runtime_env_agent.py
Jialing He 4a83bc3dc2
[runtime env] Support setting a timeout for runtime env setup (#23082)
Interface example:
```python
@ray.remote(runtime_env=RuntimeEnv(..., config=RuntimeEnvConfig(setup_timeout_seconds=10)))
def f(): pass

@ray.remote(runtime_env={..., "config": {"setup_timeout_seconds": 10}})
def f(): pass
```

Support setting a timeout, in seconds, for runtime environment creation.

Co-authored-by: 捕牛 <hejialing.hjl@antgroup.com>
2022-03-18 12:52:59 -05:00

354 lines
16 KiB
Python

import asyncio
import traceback
from collections import defaultdict
from dataclasses import dataclass
import json
import logging
import os
import time
from typing import Dict, Set
from ray._private.utils import import_attr
from ray.core.generated import runtime_env_agent_pb2
from ray.core.generated import runtime_env_agent_pb2_grpc
from ray.core.generated import agent_manager_pb2
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.modules.runtime_env.runtime_env_consts as runtime_env_consts
from ray.experimental.internal_kv import (
_internal_kv_initialized,
_initialize_internal_kv,
)
from ray._private.ray_logging import setup_component_logger
from ray._private.async_compat import create_task
from ray._private.runtime_env.pip import PipManager
from ray._private.runtime_env.conda import CondaManager
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.py_modules import PyModulesManager
from ray._private.runtime_env.working_dir import WorkingDirManager
from ray._private.runtime_env.container import ContainerManager
from ray._private.runtime_env.plugin import decode_plugin_uri
from ray._private.runtime_env.uri_cache import URICache
from ray.runtime_env import RuntimeEnv, RuntimeEnvConfig
default_logger = logging.getLogger(__name__)

# TODO(edoakes): this is used for unit tests. We should replace it with a
# better pluggability mechanism once available.
SLEEP_FOR_TESTING_S = os.environ.get("RAY_RUNTIME_ENV_SLEEP_FOR_TESTING_S")

# Number of bytes in one GiB; the cache-size env vars are expressed in GB.
_BYTES_PER_GB = 1024 ** 3


def _cache_size_bytes_from_env(env_var_name):
    """Return the cache size configured in `env_var_name`, converted to bytes.

    The variable holds a (possibly fractional) number of GB; when it is unset
    the size defaults to 10 GB. The result is truncated to an int.
    """
    size_gb = float(os.environ.get(env_var_name, 10))
    return int(_BYTES_PER_GB * size_gb)


# Sizes for the URI cache for each runtime_env field. Defaults to 10 GB.
WORKING_DIR_CACHE_SIZE_BYTES = _cache_size_bytes_from_env(
    "RAY_RUNTIME_ENV_WORKING_DIR_CACHE_SIZE_GB"
)
PY_MODULES_CACHE_SIZE_BYTES = _cache_size_bytes_from_env(
    "RAY_RUNTIME_ENV_PY_MODULES_CACHE_SIZE_GB"
)
CONDA_CACHE_SIZE_BYTES = _cache_size_bytes_from_env(
    "RAY_RUNTIME_ENV_CONDA_CACHE_SIZE_GB"
)
PIP_CACHE_SIZE_BYTES = _cache_size_bytes_from_env(
    "RAY_RUNTIME_ENV_PIP_CACHE_SIZE_GB"
)
@dataclass
class CreatedEnvResult:
    """Outcome of creating one runtime env, as stored in the agent's env cache.

    Attributes:
        success: Whether or not the env was installed correctly.
        result: When `success` is True, the serialized RuntimeEnvContext;
            when `success` is False, the error message for the failure.
    """

    success: bool
    result: str
class RuntimeEnvAgent(
    dashboard_utils.DashboardAgentModule,
    runtime_env_agent_pb2_grpc.RuntimeEnvServiceServicer,
):
    """An RPC server to create and delete runtime envs.

    Attributes:
        dashboard_agent: The DashboardAgent object contains global config.
    """

    def __init__(self, dashboard_agent):
        super().__init__(dashboard_agent)
        self._runtime_env_dir = dashboard_agent.runtime_env_dir
        self._logging_params = dashboard_agent.logging_params
        # Maps a decoded job id to the logger for that job's setup log.
        self._per_job_logger_cache = dict()
        # Cache the results of creating envs to avoid repeatedly calling into
        # conda and other slow calls. Failed results are cached as well.
        self._env_cache: Dict[str, CreatedEnvResult] = dict()
        # Maps a serialized runtime env to a lock that is used
        # to prevent multiple concurrent installs of the same env.
        self._env_locks: Dict[str, asyncio.Lock] = dict()
        # Keeps track of the URIs contained within each env so we can
        # invalidate the env cache when a URI is deleted.
        # This is a temporary mechanism until we have per-URI caching.
        self._uris_to_envs: Dict[str, Set[str]] = defaultdict(set)
        # Initialize internal KV to be used by the working_dir setup code.
        _initialize_internal_kv(self._dashboard_agent.gcs_client)
        assert _internal_kv_initialized()

        self._pip_manager = PipManager(self._runtime_env_dir)
        self._conda_manager = CondaManager(self._runtime_env_dir)
        self._py_modules_manager = PyModulesManager(self._runtime_env_dir)
        self._working_dir_manager = WorkingDirManager(self._runtime_env_dir)
        self._container_manager = ContainerManager(dashboard_agent.temp_dir)

        # One URI cache per URI-based field; each is given the owning
        # manager's delete_uri callback and a byte budget.
        self._working_dir_uri_cache = URICache(
            self._working_dir_manager.delete_uri, WORKING_DIR_CACHE_SIZE_BYTES
        )
        self._py_modules_uri_cache = URICache(
            self._py_modules_manager.delete_uri, PY_MODULES_CACHE_SIZE_BYTES
        )
        self._conda_uri_cache = URICache(
            self._conda_manager.delete_uri, CONDA_CACHE_SIZE_BYTES
        )
        self._pip_uri_cache = URICache(
            self._pip_manager.delete_uri, PIP_CACHE_SIZE_BYTES
        )
        self._logger = default_logger

    def get_or_create_logger(self, job_id: bytes):
        """Return the logger for runtime env setup messages of `job_id`.

        The logger writes to ``runtime_env_setup-<job_id>.log``. It is created
        on first use and cached for subsequent calls.

        Args:
            job_id: The job id as bytes; it is decoded to a str.
        """
        job_id = job_id.decode()
        if job_id not in self._per_job_logger_cache:
            params = self._logging_params.copy()
            params["filename"] = f"runtime_env_setup-{job_id}.log"
            params["logger_name"] = f"runtime_env_{job_id}"
            per_job_logger = setup_component_logger(**params)
            self._per_job_logger_cache[job_id] = per_job_logger
        return self._per_job_logger_cache[job_id]

    async def CreateRuntimeEnv(self, request, context):
        """Create the runtime env in `request`, or answer from the cache.

        Results (success or failure) are cached per serialized env, and a
        per-env lock serializes concurrent requests for the same env. On a
        cache miss, setup is attempted up to RUNTIME_ENV_RETRY_TIMES times.
        Each attempt is bounded by the request's `setup_timeout_seconds`;
        a value of -1 disables the timeout.

        Args:
            request: The CreateRuntimeEnv request. Fields read here are
                serialized_runtime_env, runtime_env_config, job_id and
                serialized_allocated_resource_instances.
            context: The gRPC servicer context (unused).

        Returns:
            A CreateRuntimeEnvReply. On success it has status OK and the
            serialized RuntimeEnvContext. On failure it has status FAILED and
            an error message.
        """

        async def _setup_runtime_env(
            serialized_runtime_env, serialized_allocated_resource_instances
        ):
            # Run every manager and plugin against a fresh RuntimeEnvContext
            # and return it.
            runtime_env = RuntimeEnv.deserialize(serialized_runtime_env)
            allocated_resource: dict = json.loads(
                serialized_allocated_resource_instances or "{}"
            )
            # Use a separate logger for each job.
            per_job_logger = self.get_or_create_logger(request.job_id)
            # TODO(chenk008): Add log about allocated_resource to
            # avoid lint error. That will be moved to cgroup plugin.
            per_job_logger.debug(f"Worker has resource :" f"{allocated_resource}")
            # Named env_context so it is not confused with the gRPC `context`
            # argument of the enclosing method.
            env_context = RuntimeEnvContext(env_vars=runtime_env.env_vars())
            await self._container_manager.setup(
                runtime_env, env_context, logger=per_job_logger
            )

            for (manager, uri_cache) in [
                (self._working_dir_manager, self._working_dir_uri_cache),
                (self._conda_manager, self._conda_uri_cache),
                (self._pip_manager, self._pip_uri_cache),
            ]:
                uri = manager.get_uri(runtime_env)
                if uri is not None:
                    if uri not in uri_cache:
                        per_job_logger.debug(f"Cache miss for URI {uri}.")
                        size_bytes = await manager.create(
                            uri, runtime_env, env_context, logger=per_job_logger
                        )
                        uri_cache.add(uri, size_bytes, logger=per_job_logger)
                    else:
                        per_job_logger.debug(f"Cache hit for URI {uri}.")
                        uri_cache.mark_used(uri, logger=per_job_logger)
                    manager.modify_context(uri, runtime_env, env_context)

            # Set up py_modules. For now, py_modules uses multiple URIs so
            # the logic is slightly different from working_dir, conda, and
            # pip above.
            py_modules_uris = self._py_modules_manager.get_uris(runtime_env)
            if py_modules_uris is not None:
                for uri in py_modules_uris:
                    if uri not in self._py_modules_uri_cache:
                        per_job_logger.debug(f"Cache miss for URI {uri}.")
                        size_bytes = await self._py_modules_manager.create(
                            uri, runtime_env, env_context, logger=per_job_logger
                        )
                        self._py_modules_uri_cache.add(
                            uri, size_bytes, logger=per_job_logger
                        )
                    else:
                        per_job_logger.debug(f"Cache hit for URI {uri}.")
                        self._py_modules_uri_cache.mark_used(
                            uri, logger=per_job_logger
                        )
                self._py_modules_manager.modify_context(
                    py_modules_uris, runtime_env, env_context
                )

            # Add the mapping of URIs -> the serialized environment to be
            # used for cache invalidation.
            if runtime_env.working_dir_uri():
                uri = runtime_env.working_dir_uri()
                self._uris_to_envs[uri].add(serialized_runtime_env)
            if runtime_env.py_modules_uris():
                for uri in runtime_env.py_modules_uris():
                    self._uris_to_envs[uri].add(serialized_runtime_env)
            if runtime_env.conda_uri():
                uri = runtime_env.conda_uri()
                self._uris_to_envs[uri].add(serialized_runtime_env)
            if runtime_env.pip_uri():
                uri = runtime_env.pip_uri()
                self._uris_to_envs[uri].add(serialized_runtime_env)
            if runtime_env.plugin_uris():
                for uri in runtime_env.plugin_uris():
                    self._uris_to_envs[uri].add(serialized_runtime_env)

            def setup_plugins():
                # Run setup function from all the plugins
                for plugin_class_path, config in runtime_env.plugins():
                    per_job_logger.debug(
                        f"Setting up runtime env plugin {plugin_class_path}"
                    )
                    plugin_class = import_attr(plugin_class_path)
                    # TODO(simon): implement uri support
                    # config is decoded separately for each call so that each
                    # call receives its own dict.
                    plugin_class.create(
                        "uri not implemented", json.loads(config), env_context
                    )
                    plugin_class.modify_context(
                        "uri not implemented", json.loads(config), env_context
                    )

            loop = asyncio.get_event_loop()
            # Plugin setup methods are synchronous; run them in the default
            # executor so they do not block the asyncio loop.
            # NOTE: cancelling this await, e.g. on a setup timeout, does not
            # stop a worker thread that has already started.
            await loop.run_in_executor(None, setup_plugins)
            return env_context

        serialized_env = request.serialized_runtime_env

        if serialized_env not in self._env_locks:
            # async lock to prevent the same env being concurrently installed
            self._env_locks[serialized_env] = asyncio.Lock()

        async with self._env_locks[serialized_env]:
            cached_result = self._env_cache.get(serialized_env)
            if cached_result is not None:
                if cached_result.success:
                    serialized_context = cached_result.result
                    self._logger.info(
                        "Runtime env already created "
                        f"successfully. Env: {serialized_env}, "
                        f"context: {serialized_context}"
                    )
                    return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                        status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
                        serialized_runtime_env_context=serialized_context,
                    )
                error_message = cached_result.result
                self._logger.info(
                    "Runtime env already failed. "
                    f"Env: {serialized_env}, "
                    f"err: {error_message}"
                )
                return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                    status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                    error_message=error_message,
                )

            if SLEEP_FOR_TESTING_S:
                self._logger.info(f"Sleeping for {SLEEP_FOR_TESTING_S}s.")
                # A blocking time.sleep here would stall the whole event loop,
                # and with it every other RPC this agent serves.
                await asyncio.sleep(int(SLEEP_FOR_TESTING_S))

            self._logger.info(f"Creating runtime env: {serialized_env}.")
            runtime_env_context: RuntimeEnvContext = None
            error_message = None
            runtime_env_config = RuntimeEnvConfig.from_proto(
                request.runtime_env_config
            )
            # asyncio.wait_for treats a timeout of None as "wait forever";
            # -1 is this config's sentinel for "no timeout".
            setup_timeout_seconds = runtime_env_config["setup_timeout_seconds"]
            if setup_timeout_seconds == -1:
                setup_timeout_seconds = None

            retry_times = runtime_env_consts.RUNTIME_ENV_RETRY_TIMES
            for attempt in range(1, retry_times + 1):
                try:
                    # python 3.6 requires the type of input is `Future`,
                    # python 3.7+ only requires the type of input is `Awaitable`
                    # TODO(Catch-Bull): remove create_task when ray drop python 3.6
                    runtime_env_setup_task = create_task(
                        _setup_runtime_env(
                            serialized_env,
                            request.serialized_allocated_resource_instances,
                        )
                    )
                    runtime_env_context = await asyncio.wait_for(
                        runtime_env_setup_task, timeout=setup_timeout_seconds
                    )
                    error_message = None
                    break
                except Exception as e:
                    err_msg = f"Failed to create runtime env {serialized_env}."
                    self._logger.exception(err_msg)
                    error_message = "".join(
                        traceback.format_exception(type(e), e, e.__traceback__)
                    )
                    if isinstance(e, asyncio.TimeoutError):
                        # str(TimeoutError()) is empty, so the traceback alone
                        # does not say which limit was hit.
                        error_message += (
                            "Runtime env setup did not finish within "
                            f"{setup_timeout_seconds} seconds; consider "
                            "increasing setup_timeout_seconds in the "
                            "runtime env config.\n"
                        )
                    # Only wait between attempts. Sleeping after the final
                    # attempt would just delay the failure reply.
                    if attempt < retry_times:
                        await asyncio.sleep(
                            runtime_env_consts.RUNTIME_ENV_RETRY_INTERVAL_MS / 1000
                        )

            if error_message is not None:
                self._logger.error(
                    "Runtime env creation failed for %d times, "
                    "don't retry any more.",
                    runtime_env_consts.RUNTIME_ENV_RETRY_TIMES,
                )
                self._env_cache[serialized_env] = CreatedEnvResult(
                    False, error_message
                )
                return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                    status=agent_manager_pb2.AGENT_RPC_STATUS_FAILED,
                    error_message=error_message,
                )

            serialized_context = runtime_env_context.serialize()
            self._env_cache[serialized_env] = CreatedEnvResult(
                True, serialized_context
            )
            self._logger.info(
                "Successfully created runtime env: %s, the context: %s",
                serialized_env,
                serialized_context,
            )
            return runtime_env_agent_pb2.CreateRuntimeEnvReply(
                status=agent_manager_pb2.AGENT_RPC_STATUS_OK,
                serialized_runtime_env_context=serialized_context,
            )

    async def DeleteURIs(self, request, context):
        """Mark the given URIs unused and invalidate cached envs that use them.

        Args:
            request: Carries `uris`. Each entry is split by decode_plugin_uri
                into a (plugin name, URI) pair.
            context: The gRPC servicer context (unused).

        Returns:
            A DeleteURIsReply with status OK.

        Raises:
            ValueError: If a URI belongs to a plugin other than working_dir,
                py_modules, conda or pip. URIs before it have already been
                processed at that point.
        """
        self._logger.info(f"Got request to mark URIs unused: {request.uris}.")
        uri_caches = {
            "working_dir": self._working_dir_uri_cache,
            "py_modules": self._py_modules_uri_cache,
            "conda": self._conda_uri_cache,
            "pip": self._pip_uri_cache,
        }
        for plugin_uri in request.uris:
            plugin, uri = decode_plugin_uri(plugin_uri)
            # Invalidate the env cache for any envs that contain this URI.
            # The URI -> env mapping itself is kept. An env served from the
            # cache does not re-run setup, so it would never re-register
            # the mapping.
            for env in self._uris_to_envs.get(uri, []):
                self._env_cache.pop(env, None)
            uri_cache = uri_caches.get(plugin)
            if uri_cache is None:
                raise ValueError(
                    "RuntimeEnvAgent received DeleteURI request "
                    f"for unsupported plugin {plugin}. URI: {uri}"
                )
            uri_cache.mark_unused(uri)
        return runtime_env_agent_pb2.DeleteURIsReply(
            status=agent_manager_pb2.AGENT_RPC_STATUS_OK
        )

    async def run(self, server):
        """Register this servicer on `server`, when a server is provided."""
        if server:
            runtime_env_agent_pb2_grpc.add_RuntimeEnvServiceServicer_to_server(
                self, server
            )

    @staticmethod
    def is_minimal_module():
        """Return True (presumably: this module is usable in a minimal install)."""
        return True