mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[runtime env] Add URI support for plugins (#26746)
This commit is contained in:
parent
274ea2a1ba
commit
60f33777a2
11 changed files with 448 additions and 251 deletions
|
@ -6,7 +6,6 @@ import time
|
|||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Set, Tuple
|
||||
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
|
||||
|
||||
|
@ -20,10 +19,12 @@ from ray._private.runtime_env.container import ContainerManager
|
|||
from ray._private.runtime_env.context import RuntimeEnvContext
|
||||
from ray._private.runtime_env.java_jars import JavaJarsPlugin
|
||||
from ray._private.runtime_env.pip import PipPlugin
|
||||
from ray._private.runtime_env.plugin import PluginCacheManager
|
||||
from ray._private.runtime_env.plugin import (
|
||||
RuntimeEnvPlugin,
|
||||
create_for_plugin_if_needed,
|
||||
)
|
||||
from ray._private.runtime_env.plugin import RuntimeEnvPluginManager
|
||||
from ray._private.runtime_env.py_modules import PyModulesPlugin
|
||||
from ray._private.runtime_env.uri_cache import URICache
|
||||
from ray._private.runtime_env.working_dir import WorkingDirPlugin
|
||||
from ray.core.generated import (
|
||||
agent_manager_pb2,
|
||||
|
@ -53,12 +54,8 @@ class CreatedEnvResult:
|
|||
creation_time_ms: int
|
||||
|
||||
|
||||
class UriType(Enum):
|
||||
WORKING_DIR = "working_dir"
|
||||
PY_MODULES = "py_modules"
|
||||
PIP = "pip"
|
||||
CONDA = "conda"
|
||||
JAVA_JARS = "java_jars"
|
||||
# e.g., "working_dir"
|
||||
UriType = str
|
||||
|
||||
|
||||
class ReferenceTable:
|
||||
|
@ -198,48 +195,37 @@ class RuntimeEnvAgent(
|
|||
# TODO(architkulkarni): "base plugins" and third-party plugins should all go
|
||||
# through the same code path. We should never need to refer to
|
||||
# self._xxx_plugin, we should just iterate through self._plugins.
|
||||
self._base_plugins = [
|
||||
self._base_plugins: List[RuntimeEnvPlugin] = [
|
||||
self._working_dir_plugin,
|
||||
self._pip_plugin,
|
||||
self._conda_plugin,
|
||||
self._py_modules_plugin,
|
||||
self._java_jars_plugin,
|
||||
]
|
||||
self._uri_caches = {}
|
||||
self._base_plugin_cache_managers = {}
|
||||
self._plugin_manager = RuntimeEnvPluginManager()
|
||||
for plugin in self._base_plugins:
|
||||
# Set the max size for the cache. Defaults to 10 GB.
|
||||
cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper()
|
||||
cache_size_bytes = int(
|
||||
(1024 ** 3) * float(os.environ.get(cache_size_env_var, 10))
|
||||
)
|
||||
self._uri_caches[plugin.name] = URICache(
|
||||
plugin.delete_uri, cache_size_bytes
|
||||
)
|
||||
self._base_plugin_cache_managers[plugin.name] = PluginCacheManager(
|
||||
plugin, self._uri_caches[plugin.name]
|
||||
)
|
||||
self._plugin_manager.add_plugin(plugin)
|
||||
|
||||
self._reference_table = ReferenceTable(
|
||||
self.uris_parser,
|
||||
self.unused_uris_processor,
|
||||
self.unused_runtime_env_processor,
|
||||
)
|
||||
self._runtime_env_plugin_manager = RuntimeEnvPluginManager()
|
||||
|
||||
self._logger = default_logger
|
||||
|
||||
def uris_parser(self, runtime_env):
|
||||
result = list()
|
||||
for plugin in self._base_plugins:
|
||||
for name, plugin_setup_context in self._plugin_manager.plugins.items():
|
||||
plugin = plugin_setup_context.class_instance
|
||||
uris = plugin.get_uris(runtime_env)
|
||||
for uri in uris:
|
||||
result.append((uri, UriType(plugin.name)))
|
||||
result.append((uri, UriType(name)))
|
||||
return result
|
||||
|
||||
def unused_uris_processor(self, unused_uris: List[Tuple[str, UriType]]) -> None:
|
||||
for uri, uri_type in unused_uris:
|
||||
self._uri_caches[uri_type.value].mark_unused(uri)
|
||||
self._plugin_manager.plugins[str(uri_type)].uri_cache.mark_unused(uri)
|
||||
|
||||
def unused_runtime_env_processor(self, unused_runtime_env: str) -> None:
|
||||
def delete_runtime_env():
|
||||
|
@ -275,7 +261,9 @@ class RuntimeEnvAgent(
|
|||
)
|
||||
|
||||
async def _setup_runtime_env(
|
||||
runtime_env, serialized_runtime_env, serialized_allocated_resource_instances
|
||||
runtime_env: RuntimeEnv,
|
||||
serialized_runtime_env,
|
||||
serialized_allocated_resource_instances,
|
||||
):
|
||||
allocated_resource: dict = json.loads(
|
||||
serialized_allocated_resource_instances or "{}"
|
||||
|
@ -290,39 +278,25 @@ class RuntimeEnvAgent(
|
|||
runtime_env, context, logger=per_job_logger
|
||||
)
|
||||
|
||||
for manager in self._base_plugin_cache_managers.values():
|
||||
await manager.create_if_needed(
|
||||
runtime_env, context, logger=per_job_logger
|
||||
# Warn about unrecognized fields in the runtime env.
|
||||
for name, _ in runtime_env.plugins():
|
||||
if name not in self._plugin_manager.plugins:
|
||||
per_job_logger.warning(
|
||||
f"runtime_env field {name} is not recognized by "
|
||||
"Ray and will be ignored. In the future, unrecognized "
|
||||
"fields in the runtime_env will raise an exception."
|
||||
)
|
||||
|
||||
"""Run setup for each plugin unless it has already been cached."""
|
||||
for (
|
||||
plugin_setup_context
|
||||
) in self._plugin_manager.sorted_plugin_setup_contexts():
|
||||
plugin = plugin_setup_context.class_instance
|
||||
uri_cache = plugin_setup_context.uri_cache
|
||||
await create_for_plugin_if_needed(
|
||||
runtime_env, plugin, uri_cache, context, per_job_logger
|
||||
)
|
||||
|
||||
def setup_plugins():
|
||||
# Run setup function from all the plugins
|
||||
if runtime_env.plugins():
|
||||
for (
|
||||
setup_context
|
||||
) in self._runtime_env_plugin_manager.sorted_plugin_setup_contexts(
|
||||
runtime_env.plugins()
|
||||
):
|
||||
per_job_logger.debug(
|
||||
f"Setting up runtime env plugin {setup_context.name}"
|
||||
)
|
||||
# TODO(architkulkarni): implement uri support
|
||||
setup_context.class_instance.validate(runtime_env)
|
||||
setup_context.class_instance.create(
|
||||
"uri not implemented", setup_context.config, context
|
||||
)
|
||||
setup_context.class_instance.modify_context(
|
||||
"uri not implemented",
|
||||
setup_context.config,
|
||||
context,
|
||||
per_job_logger,
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
# Plugins setup method is sync process, running in other threads
|
||||
# is to avoid blocking asyncio loop
|
||||
await loop.run_in_executor(None, setup_plugins)
|
||||
|
||||
return context
|
||||
|
||||
async def _create_runtime_env_with_retry(
|
||||
|
|
|
@ -12,10 +12,10 @@ def test_reference_table():
|
|||
|
||||
def uris_parser(runtime_env) -> Tuple[str, UriType]:
|
||||
result = list()
|
||||
result.append((runtime_env.working_dir(), UriType.WORKING_DIR))
|
||||
result.append((runtime_env.working_dir(), "working_dir"))
|
||||
py_module_uris = runtime_env.py_modules()
|
||||
for uri in py_module_uris:
|
||||
result.append((uri, UriType.PY_MODULES))
|
||||
result.append((uri, "py_modules"))
|
||||
return result
|
||||
|
||||
def unused_uris_processor(unused_uris: List[Tuple[str, UriType]]) -> None:
|
||||
|
@ -57,8 +57,8 @@ def test_reference_table():
|
|||
)
|
||||
|
||||
# Remove runtime env 1
|
||||
expected_unused_uris.append(("s3://working_dir_1.zip", UriType.WORKING_DIR))
|
||||
expected_unused_uris.append(("s3://py_module_B.zip", UriType.PY_MODULES))
|
||||
expected_unused_uris.append(("s3://working_dir_1.zip", "working_dir"))
|
||||
expected_unused_uris.append(("s3://py_module_B.zip", "py_modules"))
|
||||
expected_unused_runtime_env = runtime_env_1.serialize()
|
||||
reference_table.decrease_reference(
|
||||
runtime_env_1, runtime_env_1.serialize(), "raylet"
|
||||
|
@ -67,9 +67,9 @@ def test_reference_table():
|
|||
assert not expected_unused_runtime_env
|
||||
|
||||
# Remove runtime env 2
|
||||
expected_unused_uris.append(("s3://working_dir_2.zip", UriType.WORKING_DIR))
|
||||
expected_unused_uris.append(("s3://py_module_A.zip", UriType.PY_MODULES))
|
||||
expected_unused_uris.append(("s3://py_module_C.zip", UriType.PY_MODULES))
|
||||
expected_unused_uris.append(("s3://working_dir_2.zip", "working_dir"))
|
||||
expected_unused_uris.append(("s3://py_module_A.zip", "py_modules"))
|
||||
expected_unused_uris.append(("s3://py_module_C.zip", "py_modules"))
|
||||
expected_unused_runtime_env = runtime_env_2.serialize()
|
||||
reference_table.decrease_reference(
|
||||
runtime_env_2, runtime_env_2.serialize(), "raylet"
|
||||
|
|
|
@ -315,6 +315,13 @@ class CondaPlugin(RuntimeEnvPlugin):
|
|||
context: RuntimeEnvContext,
|
||||
logger: logging.Logger = default_logger,
|
||||
) -> int:
|
||||
if uri is None:
|
||||
# The "conda" field is the name of an existing conda env, so no
|
||||
# need to create one.
|
||||
# TODO(architkulkarni): Try "conda activate" here to see if the
|
||||
# env exists, and raise an exception if it doesn't.
|
||||
return 0
|
||||
|
||||
# Currently create method is still a sync process, to avoid blocking
|
||||
# the loop, need to run this function in another thread.
|
||||
# TODO(Catch-Bull): Refactor method create into an async process, and
|
||||
|
|
|
@ -457,7 +457,7 @@ class PipPlugin(RuntimeEnvPlugin):
|
|||
uris: List[str],
|
||||
runtime_env: "RuntimeEnv", # noqa: F821
|
||||
context: RuntimeEnvContext,
|
||||
logger: Optional[logging.Logger] = default_logger,
|
||||
logger: logging.Logger = default_logger,
|
||||
):
|
||||
if not runtime_env.has_pip():
|
||||
return
|
||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import os
|
||||
import json
|
||||
from abc import ABC
|
||||
from typing import List, Dict, Tuple, Any
|
||||
from typing import List, Dict, Optional, Any, Type
|
||||
|
||||
from ray._private.runtime_env.context import RuntimeEnvContext
|
||||
from ray._private.runtime_env.uri_cache import URICache
|
||||
|
@ -28,28 +28,25 @@ class RuntimeEnvPlugin(ABC):
|
|||
priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY
|
||||
|
||||
@staticmethod
|
||||
def validate(runtime_env_dict: dict) -> str:
|
||||
"""Validate user entry and returns a URI uniquely describing resource.
|
||||
|
||||
This method will be called at ``f.options(runtime_env=...)`` or
|
||||
``ray.init(runtime_env=...)`` time and it should check the runtime env
|
||||
dictionary for any errors. For example, it can raise "TypeError:
|
||||
expected string for "conda" field".
|
||||
def validate(runtime_env_dict: dict) -> None:
|
||||
"""Validate user entry for this plugin.
|
||||
|
||||
Args:
|
||||
runtime_env_dict(dict): the entire dictionary passed in by user.
|
||||
|
||||
Returns:
|
||||
uri(str): a URI uniquely describing this resource (e.g., a hash of
|
||||
the conda spec).
|
||||
runtime_env_dict: the user-supplied runtime environment dict.
|
||||
Raises:
|
||||
ValueError: if the validation fails.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
|
||||
def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821
|
||||
return None
|
||||
return []
|
||||
|
||||
def create(
|
||||
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
||||
async def create(
|
||||
self,
|
||||
uri: Optional[str],
|
||||
runtime_env: "RuntimeEnv", # noqa: F821
|
||||
context: RuntimeEnvContext,
|
||||
logger: logging.Logger,
|
||||
) -> float:
|
||||
"""Create and install the runtime environment.
|
||||
|
||||
|
@ -57,9 +54,10 @@ class RuntimeEnvPlugin(ABC):
|
|||
used as a caching mechanism.
|
||||
|
||||
Args:
|
||||
uri(str): a URI uniquely describing this resource.
|
||||
runtime_env(RuntimeEnv): the runtime env protobuf.
|
||||
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
|
||||
uri: A URI uniquely describing this resource.
|
||||
runtime_env: The RuntimeEnv object.
|
||||
context: auxiliary information supplied by Ray.
|
||||
logger: A logger to log messages during the context modification.
|
||||
|
||||
Returns:
|
||||
the disk space taken up by this plugin installation for this
|
||||
|
@ -81,9 +79,10 @@ class RuntimeEnvPlugin(ABC):
|
|||
startup, or add new environment variables.
|
||||
|
||||
Args:
|
||||
uris(List[str]): a URIs used by this resource.
|
||||
runtime_env(RuntimeEnv): the runtime env protobuf.
|
||||
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
|
||||
uris: The URIs used by this resource.
|
||||
runtime_env: The RuntimeEnv object.
|
||||
context: Auxiliary information supplied by Ray.
|
||||
logger: A logger to log messages during the context modification.
|
||||
"""
|
||||
return
|
||||
|
||||
|
@ -91,8 +90,7 @@ class RuntimeEnvPlugin(ABC):
|
|||
"""Delete the the runtime environment given uri.
|
||||
|
||||
Args:
|
||||
uri(str): a URI uniquely describing this resource.
|
||||
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
|
||||
uri: a URI uniquely describing this resource.
|
||||
|
||||
Returns:
|
||||
the amount of space reclaimed by the deletion.
|
||||
|
@ -101,29 +99,59 @@ class RuntimeEnvPlugin(ABC):
|
|||
|
||||
|
||||
class PluginSetupContext:
|
||||
def __init__(self, name: str, config: Any, class_instance: object):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
class_instance: RuntimeEnvPlugin,
|
||||
priority: int,
|
||||
uri_cache: URICache,
|
||||
):
|
||||
self.name = name
|
||||
self.config = config
|
||||
self.class_instance = class_instance
|
||||
self.priority = priority
|
||||
self.uri_cache = uri_cache
|
||||
|
||||
|
||||
class RuntimeEnvPluginManager:
|
||||
"""This manager is used to load plugins in runtime env agent."""
|
||||
|
||||
class Context:
|
||||
def __init__(self, class_instance, priority):
|
||||
self.class_instance = class_instance
|
||||
self.priority = priority
|
||||
|
||||
def __init__(self):
|
||||
self.plugins: Dict[str, RuntimeEnvPluginManager.Context] = {}
|
||||
self.plugins: Dict[str, PluginSetupContext] = {}
|
||||
plugin_config_str = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
|
||||
if plugin_config_str:
|
||||
plugin_configs = json.loads(plugin_config_str)
|
||||
self.load_plugins(plugin_configs)
|
||||
|
||||
def load_plugins(self, plugin_configs: List[Dict]):
|
||||
"""Load runtime env plugins"""
|
||||
def validate_plugin_class(self, plugin_class: Type[RuntimeEnvPlugin]) -> None:
|
||||
if not issubclass(plugin_class, RuntimeEnvPlugin):
|
||||
raise RuntimeError(
|
||||
f"Invalid runtime env plugin class {plugin_class}. "
|
||||
"The plugin class must inherit "
|
||||
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
|
||||
)
|
||||
if not plugin_class.name:
|
||||
raise RuntimeError(f"No valid name in runtime env plugin {plugin_class}.")
|
||||
if plugin_class.name in self.plugins:
|
||||
raise RuntimeError(
|
||||
f"The name of runtime env plugin {plugin_class} conflicts "
|
||||
f"with {self.plugins[plugin_class.name]}.",
|
||||
)
|
||||
|
||||
def validate_priority(self, priority: Any) -> None:
|
||||
if (
|
||||
not isinstance(priority, int)
|
||||
or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY
|
||||
or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Invalid runtime env priority {priority}, "
|
||||
"it should be an integer between "
|
||||
f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} "
|
||||
f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}."
|
||||
)
|
||||
|
||||
def load_plugins(self, plugin_configs: List[Dict]) -> None:
|
||||
"""Load runtime env plugins and create URI caches for them."""
|
||||
for plugin_config in plugin_configs:
|
||||
if (
|
||||
not isinstance(plugin_config, dict)
|
||||
|
@ -135,21 +163,7 @@ class RuntimeEnvPluginManager:
|
|||
f"{RAY_RUNTIME_ENV_CLASS_FIELD_NAME} field."
|
||||
)
|
||||
plugin_class = import_attr(plugin_config[RAY_RUNTIME_ENV_CLASS_FIELD_NAME])
|
||||
if not issubclass(plugin_class, RuntimeEnvPlugin):
|
||||
raise RuntimeError(
|
||||
f"Invalid runtime env plugin class {plugin_class}. "
|
||||
"The plugin class must inherit "
|
||||
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
|
||||
)
|
||||
if not plugin_class.name:
|
||||
raise RuntimeError(
|
||||
f"No valid name in runtime env plugin {plugin_class}."
|
||||
)
|
||||
if plugin_class.name in self.plugins:
|
||||
raise RuntimeError(
|
||||
f"The name of runtime env plugin {plugin_class} conflicts "
|
||||
f"with {self.plugins[plugin_class.name]}.",
|
||||
)
|
||||
self.validate_plugin_class(plugin_class)
|
||||
|
||||
# The priority should be an integer between 0 and 100.
|
||||
# The default priority is 10. A smaller number indicates a
|
||||
|
@ -158,73 +172,86 @@ class RuntimeEnvPluginManager:
|
|||
priority = plugin_config[RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME]
|
||||
else:
|
||||
priority = plugin_class.priority
|
||||
if (
|
||||
not isinstance(priority, int)
|
||||
or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY
|
||||
or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Invalid runtime env priority {priority}, "
|
||||
"it should be an integer between "
|
||||
f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} "
|
||||
f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}."
|
||||
)
|
||||
self.validate_priority(priority)
|
||||
|
||||
self.plugins[plugin_class.name] = RuntimeEnvPluginManager.Context(
|
||||
plugin_class(), priority
|
||||
class_instance = plugin_class()
|
||||
self.plugins[plugin_class.name] = PluginSetupContext(
|
||||
plugin_class.name,
|
||||
class_instance,
|
||||
priority,
|
||||
self.create_uri_cache_for_plugin(class_instance),
|
||||
)
|
||||
|
||||
def sorted_plugin_setup_contexts(
|
||||
self, inputs: List[Tuple[str, Any]]
|
||||
) -> List[PluginSetupContext]:
|
||||
used_plugins = []
|
||||
for name, config in inputs:
|
||||
if name not in self.plugins:
|
||||
default_logger.error(
|
||||
f"runtime_env field {name} is not recognized by "
|
||||
"Ray and will be ignored. In the future, unrecognized "
|
||||
"fields in the runtime_env will raise an exception."
|
||||
)
|
||||
continue
|
||||
used_plugins.append(
|
||||
(
|
||||
name,
|
||||
config,
|
||||
self.plugins[name].class_instance,
|
||||
self.plugins[name].priority,
|
||||
)
|
||||
)
|
||||
sort_used_plugins = sorted(used_plugins, key=lambda x: x[3], reverse=False)
|
||||
return [
|
||||
PluginSetupContext(name, config, class_instance)
|
||||
for name, config, class_instance, _ in sort_used_plugins
|
||||
]
|
||||
def add_plugin(self, plugin: RuntimeEnvPlugin) -> None:
|
||||
"""Add a plugin to the manager and create a URI cache for it.
|
||||
|
||||
Args:
|
||||
plugin: The class instance of the plugin.
|
||||
"""
|
||||
plugin_class = type(plugin)
|
||||
self.validate_plugin_class(plugin_class)
|
||||
self.validate_priority(plugin_class.priority)
|
||||
self.plugins[plugin_class.name] = PluginSetupContext(
|
||||
plugin_class.name,
|
||||
plugin,
|
||||
plugin_class.priority,
|
||||
self.create_uri_cache_for_plugin(plugin),
|
||||
)
|
||||
|
||||
def create_uri_cache_for_plugin(self, plugin: RuntimeEnvPlugin) -> URICache:
|
||||
"""Create a URI cache for a plugin.
|
||||
|
||||
Args:
|
||||
plugin_name: The name of the plugin.
|
||||
|
||||
Returns:
|
||||
The created URI cache for the plugin.
|
||||
"""
|
||||
# Set the max size for the cache. Defaults to 10 GB.
|
||||
cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper()
|
||||
cache_size_bytes = int(
|
||||
(1024 ** 3) * float(os.environ.get(cache_size_env_var, 10))
|
||||
)
|
||||
return URICache(plugin.delete_uri, cache_size_bytes)
|
||||
|
||||
def sorted_plugin_setup_contexts(self) -> List[PluginSetupContext]:
|
||||
"""Get the sorted plugin setup contexts, sorted by increasing priority.
|
||||
|
||||
Returns:
|
||||
The sorted plugin setup contexts.
|
||||
"""
|
||||
return sorted(self.plugins.values(), key=lambda x: x.priority)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PluginCacheManager:
|
||||
"""Manages a plugin and a cache for its local resources."""
|
||||
async def create_for_plugin_if_needed(
|
||||
runtime_env,
|
||||
plugin: RuntimeEnvPlugin,
|
||||
uri_cache: URICache,
|
||||
context: RuntimeEnvContext,
|
||||
logger: logging.Logger = default_logger,
|
||||
):
|
||||
"""Set up the environment using the plugin if not already set up and cached."""
|
||||
if plugin.name not in runtime_env or runtime_env[plugin.name] is None:
|
||||
return
|
||||
|
||||
def __init__(self, plugin: RuntimeEnvPlugin, uri_cache: URICache):
|
||||
self._plugin = plugin
|
||||
self._uri_cache = uri_cache
|
||||
plugin.validate(runtime_env)
|
||||
|
||||
async def create_if_needed(
|
||||
self,
|
||||
runtime_env: "RuntimeEnv", # noqa: F821
|
||||
context: RuntimeEnvContext,
|
||||
logger: logging.Logger = default_logger,
|
||||
):
|
||||
uris = self._plugin.get_uris(runtime_env)
|
||||
for uri in uris:
|
||||
if uri not in self._uri_cache:
|
||||
logger.debug(f"Cache miss for URI {uri}.")
|
||||
size_bytes = await self._plugin.create(
|
||||
uri, runtime_env, context, logger=logger
|
||||
)
|
||||
self._uri_cache.add(uri, size_bytes, logger=logger)
|
||||
else:
|
||||
logger.debug(f"Cache hit for URI {uri}.")
|
||||
self._uri_cache.mark_used(uri, logger=logger)
|
||||
uris = plugin.get_uris(runtime_env)
|
||||
|
||||
self._plugin.modify_context(uris, runtime_env, context)
|
||||
if not uris:
|
||||
logger.debug(
|
||||
f"No URIs for runtime env plugin {plugin.name}; "
|
||||
"create always without checking the cache."
|
||||
)
|
||||
await plugin.create(None, runtime_env, context, logger=logger)
|
||||
|
||||
for uri in uris:
|
||||
if uri not in uri_cache:
|
||||
logger.debug(f"Cache miss for URI {uri}.")
|
||||
size_bytes = await plugin.create(uri, runtime_env, context, logger=logger)
|
||||
uri_cache.add(uri, size_bytes, logger=logger)
|
||||
else:
|
||||
logger.debug(f"Cache hit for URI {uri}.")
|
||||
uri_cache.mark_used(uri, logger=logger)
|
||||
|
||||
plugin.modify_context(uris, runtime_env, context, logger)
|
||||
|
|
|
@ -134,10 +134,10 @@ class WorkingDirPlugin(RuntimeEnvPlugin):
|
|||
|
||||
async def create(
|
||||
self,
|
||||
uri: str,
|
||||
uri: Optional[str],
|
||||
runtime_env: dict,
|
||||
context: RuntimeEnvContext,
|
||||
logger: Optional[logging.Logger] = default_logger,
|
||||
logger: logging.Logger = default_logger,
|
||||
) -> int:
|
||||
local_dir = await download_and_unpack_package(
|
||||
uri, self._resources_dir, self._gcs_aio_client, logger=logger
|
||||
|
@ -145,7 +145,11 @@ class WorkingDirPlugin(RuntimeEnvPlugin):
|
|||
return get_directory_size_bytes(local_dir)
|
||||
|
||||
def modify_context(
|
||||
self, uris: List[str], runtime_env_dict: Dict, context: RuntimeEnvContext
|
||||
self,
|
||||
uris: List[str],
|
||||
runtime_env_dict: Dict,
|
||||
context: RuntimeEnvContext,
|
||||
logger: Optional[logging.Logger] = default_logger,
|
||||
):
|
||||
if not uris:
|
||||
return
|
||||
|
|
|
@ -1445,12 +1445,10 @@ def get_runtime_env_info(
|
|||
|
||||
proto_runtime_env_info = ProtoRuntimeEnvInfo()
|
||||
|
||||
if runtime_env.get_working_dir_uri():
|
||||
proto_runtime_env_info.uris.working_dir_uri = runtime_env.get_working_dir_uri()
|
||||
if len(runtime_env.get_py_modules_uris()) > 0:
|
||||
proto_runtime_env_info.uris.py_modules_uris[
|
||||
:
|
||||
] = runtime_env.get_py_modules_uris()
|
||||
if runtime_env.working_dir_uri():
|
||||
proto_runtime_env_info.uris.working_dir_uri = runtime_env.working_dir_uri()
|
||||
if len(runtime_env.py_modules_uris()) > 0:
|
||||
proto_runtime_env_info.uris.py_modules_uris[:] = runtime_env.py_modules_uris()
|
||||
|
||||
# TODO(Catch-Bull): overload `__setitem__` for `RuntimeEnv`, change the
|
||||
# runtime_env of all internal code from dict to RuntimeEnv.
|
||||
|
|
|
@ -2044,7 +2044,7 @@ def connect(
|
|||
)
|
||||
# In client mode, if we use runtime envs with "working_dir", then
|
||||
# it'll be handled automatically. Otherwise, add the current dir.
|
||||
if not job_config.client_job and not job_config.runtime_env_has_uris():
|
||||
if not job_config.client_job and not job_config.runtime_env_has_working_dir():
|
||||
current_directory = os.path.abspath(os.path.curdir)
|
||||
worker.run_function_on_all_workers(
|
||||
lambda worker_info: sys.path.insert(1, current_directory)
|
||||
|
|
|
@ -126,9 +126,8 @@ class JobConfig:
|
|||
|
||||
return self._cached_pb
|
||||
|
||||
def runtime_env_has_uris(self):
|
||||
"""Whether there are uris in runtime env or not"""
|
||||
return self._validate_runtime_env().has_uris()
|
||||
def runtime_env_has_working_dir(self):
|
||||
return self._validate_runtime_env().has_working_dir()
|
||||
|
||||
def get_serialized_runtime_env(self) -> str:
|
||||
"""Return the JSON-serialized parsed runtime env dict"""
|
||||
|
|
|
@ -345,30 +345,6 @@ class RuntimeEnv(dict):
|
|||
if all(val is None for val in self.values()):
|
||||
self.clear()
|
||||
|
||||
def get_uris(self) -> List[str]:
|
||||
# TODO(architkulkarni): this should programmatically be extended with
|
||||
# URIs from all plugins.
|
||||
plugin_uris = []
|
||||
if "working_dir" in self:
|
||||
plugin_uris.append(self["working_dir"])
|
||||
if "py_modules" in self:
|
||||
for uri in self["py_modules"]:
|
||||
plugin_uris.append(uri)
|
||||
if "conda" in self:
|
||||
uri = get_conda_uri(self)
|
||||
if uri is not None:
|
||||
plugin_uris.append(uri)
|
||||
if "pip" in self:
|
||||
uri = get_pip_uri(self)
|
||||
if uri is not None:
|
||||
plugin_uris.append(uri)
|
||||
|
||||
def get_working_dir_uri(self) -> str:
|
||||
return self.get("working_dir", None)
|
||||
|
||||
def get_py_modules_uris(self) -> List[str]:
|
||||
return self.get("py_modules", [])
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
if is_dataclass(value):
|
||||
jsonable_type = asdict(value)
|
||||
|
@ -416,16 +392,8 @@ class RuntimeEnv(dict):
|
|||
|
||||
return runtime_env_dict
|
||||
|
||||
def has_uris(self) -> bool:
|
||||
if (
|
||||
self.working_dir_uri()
|
||||
or self.py_modules_uris()
|
||||
or self.conda_uri()
|
||||
or self.pip_uri()
|
||||
or self.plugin_uris()
|
||||
):
|
||||
return True
|
||||
return False
|
||||
def has_working_dir(self) -> bool:
|
||||
return self.get("working_dir") is not None
|
||||
|
||||
def working_dir_uri(self) -> Optional[str]:
|
||||
return self.get("working_dir")
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
import tempfile
|
||||
import json
|
||||
from time import sleep
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
@ -13,6 +16,7 @@ from ray._private.runtime_env.context import RuntimeEnvContext
|
|||
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
|
||||
from ray._private.test_utils import enable_external_redis, wait_for_condition
|
||||
from ray.exceptions import RuntimeEnvSetupError
|
||||
from ray.runtime_env.runtime_env import RuntimeEnv
|
||||
|
||||
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin"
|
||||
MY_PLUGIN_NAME = "MyPlugin"
|
||||
|
@ -23,8 +27,8 @@ class MyPlugin(RuntimeEnvPlugin):
|
|||
env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY"
|
||||
|
||||
@staticmethod
|
||||
def validate(runtime_env_dict: dict) -> str:
|
||||
value = runtime_env_dict[MY_PLUGIN_NAME]
|
||||
def validate(runtime_env: RuntimeEnv) -> str:
|
||||
value = runtime_env[MY_PLUGIN_NAME]
|
||||
if value == "fail":
|
||||
raise ValueError("not allowed")
|
||||
return value
|
||||
|
@ -32,10 +36,11 @@ class MyPlugin(RuntimeEnvPlugin):
|
|||
def modify_context(
|
||||
self,
|
||||
uris: List[str],
|
||||
plugin_config_dict: dict,
|
||||
runtime_env: RuntimeEnv,
|
||||
ctx: RuntimeEnvContext,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
plugin_config_dict = runtime_env[MY_PLUGIN_NAME]
|
||||
ctx.env_vars[MyPlugin.env_key] = str(plugin_config_dict["env_value"])
|
||||
ctx.command_prefix.append(
|
||||
f"echo {plugin_config_dict['tmp_content']} > "
|
||||
|
@ -103,14 +108,20 @@ class MyPluginForHang(RuntimeEnvPlugin):
|
|||
def validate(runtime_env_dict: dict) -> str:
|
||||
return "True"
|
||||
|
||||
def create(self, uri: str, runtime_env: dict, ctx: RuntimeEnvContext) -> float:
|
||||
async def create(
|
||||
self,
|
||||
uri: str,
|
||||
runtime_env: dict,
|
||||
ctx: RuntimeEnvContext,
|
||||
logger: logging.Logger,
|
||||
) -> float:
|
||||
global my_plugin_setup_times
|
||||
my_plugin_setup_times += 1
|
||||
|
||||
# first setup
|
||||
if my_plugin_setup_times == 1:
|
||||
# sleep forever
|
||||
sleep(3600)
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
def modify_context(
|
||||
self,
|
||||
|
@ -164,10 +175,6 @@ DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin"
|
|||
DUMMY_PLUGIN_NAME = "DummyPlugin"
|
||||
HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin"
|
||||
HANG_PLUGIN_NAME = "HangPlugin"
|
||||
DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = (
|
||||
"ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin"
|
||||
)
|
||||
DISABLE_TIMEOUT_PLUGIN_NAME = "test_plugin_timeout"
|
||||
|
||||
|
||||
class DummyPlugin(RuntimeEnvPlugin):
|
||||
|
@ -181,27 +188,21 @@ class DummyPlugin(RuntimeEnvPlugin):
|
|||
class HangPlugin(DummyPlugin):
|
||||
name = HANG_PLUGIN_NAME
|
||||
|
||||
def create(
|
||||
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
||||
async def create(
|
||||
self,
|
||||
uri: str,
|
||||
runtime_env: "RuntimeEnv",
|
||||
ctx: RuntimeEnvContext,
|
||||
logger: logging.Logger, # noqa: F821
|
||||
) -> float:
|
||||
sleep(3600)
|
||||
|
||||
|
||||
class DiasbleTimeoutPlugin(DummyPlugin):
|
||||
name = DISABLE_TIMEOUT_PLUGIN_NAME
|
||||
|
||||
def create(
|
||||
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
||||
) -> float:
|
||||
sleep(10)
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"set_runtime_env_plugins",
|
||||
[
|
||||
'[{"class":"' + DUMMY_PLUGIN_CLASS_PATH + '"},'
|
||||
'{"class":"' + HANG_PLUGIN_CLASS_PATH + '"},'
|
||||
'{"class":"' + DISABLE_TIMEOUT_PLUGIN_CLASS_PATH + '"}]',
|
||||
'{"class":"' + HANG_PLUGIN_CLASS_PATH + '"}]',
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
|
@ -215,7 +216,7 @@ def test_plugin_timeout(set_runtime_env_plugins, start_cluster):
|
|||
f.options(
|
||||
runtime_env={
|
||||
HANG_PLUGIN_NAME: {"name": "f1"},
|
||||
"config": {"setup_timeout_seconds": 10},
|
||||
"config": {"setup_timeout_seconds": 1},
|
||||
}
|
||||
).remote(),
|
||||
f.options(runtime_env={DUMMY_PLUGIN_NAME: {"name": "f2"}}).remote(),
|
||||
|
@ -397,6 +398,225 @@ def test_unexpected_field_warning(shutdown_only):
|
|||
wait_for_condition(lambda: "unexpected_field is not recognized" in f.read())
|
||||
|
||||
|
||||
URI_CACHING_TEST_PLUGIN_CLASS_PATH = (
|
||||
"ray.tests.test_runtime_env_plugin.UriCachingTestPlugin"
|
||||
)
|
||||
URI_CACHING_TEST_PLUGIN_NAME = "UriCachingTestPlugin"
|
||||
URI_CACHING_TEST_DIR = Path(tempfile.gettempdir()) / "runtime_env_uri_caching_test"
|
||||
uri_caching_test_file_path = URI_CACHING_TEST_DIR / "uri_caching_test_file.json"
|
||||
URI_CACHING_TEST_DIR.mkdir(parents=True, exist_ok=True)
|
||||
uri_caching_test_file_path.write_text("{}")
|
||||
|
||||
|
||||
def get_plugin_usage_data():
|
||||
with open(uri_caching_test_file_path, "r") as f:
|
||||
data = json.loads(f.read())
|
||||
return data
|
||||
|
||||
|
||||
class UriCachingTestPlugin(RuntimeEnvPlugin):
    """A plugin that fakes taking up local disk space when creating its environment.

    This plugin is used to test that the URI caching is working correctly.
    Example:
        runtime_env = {"UriCachingTestPlugin": {"uri": "file:///a", "size_bytes": 10}}
    """

    name = URI_CACHING_TEST_PLUGIN_NAME

    def __init__(self):
        # Fake per-URI "disk usage" for this plugin, plus counters for how
        # often the runtime env agent invoked us.  Everything is mirrored to a
        # JSON scratch file so the test driver (a separate process) can
        # inspect it.
        self.uris_to_sizes = {}
        self.modify_context_call_count = 0
        self.create_call_count = 0

    def write_plugin_usage_data(self) -> None:
        """Serialize the current counters to the shared scratch file."""
        payload = {
            "uris_to_sizes": self.uris_to_sizes,
            "modify_context_call_count": self.modify_context_call_count,
            "create_call_count": self.create_call_count,
        }
        with open(uri_caching_test_file_path, "w") as f:
            f.write(json.dumps(payload))

    def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]:  # noqa: F811
        """Report the single URI configured for this plugin in `runtime_env`."""
        return [runtime_env[self.name]["uri"]]

    async def create(
        self,
        uri,
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> float:
        """Pretend to materialize `uri` on disk; return its fake size in bytes."""
        self.create_call_count += 1
        size = runtime_env[self.name]["size_bytes"]
        self.uris_to_sizes[uri] = size
        self.write_plugin_usage_data()
        return size

    def modify_context(
        self,
        uris: List[str],
        runtime_env: "RuntimeEnv",
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> None:
        # Invoked on every env setup, cached or not; only bumps a counter so
        # the test can distinguish cache hits (modify_context without create).
        self.modify_context_call_count += 1
        self.write_plugin_usage_data()

    def delete_uri(self, uri: str, logger: logging.Logger) -> int:
        """Drop `uri` from the fake disk and return the bytes reclaimed."""
        reclaimed = self.uris_to_sizes.pop(uri)
        self.write_plugin_usage_data()
        return reclaimed
|
||||
|
||||
|
||||
# Set scope to "class" to force this to run before start_cluster, whose scope
# is "function". We need these env vars to be set before Ray is started.
@pytest.fixture(scope="class")
def uri_cache_size_100_gb():
    """Cap the UriCachingTestPlugin URI cache at 100 GB via its env var."""
    env_var_name = (
        f"RAY_RUNTIME_ENV_{URI_CACHING_TEST_PLUGIN_NAME}_CACHE_SIZE_GB".upper()
    )
    with mock.patch.dict(os.environ, {env_var_name: "100"}):
        print("Set URI cache size for UriCachingTestPlugin to 100 GB")
        yield
|
||||
|
||||
|
||||
def gb_to_bytes(size_gb: int) -> int:
    """Convert gigabytes to bytes (binary convention: 1 GB == 1024**3 bytes)."""
    return size_gb * 1024 ** 3
|
||||
|
||||
|
||||
class TestGC:
    """End-to-end tests for URI caching / garbage collection of plugin URIs."""

    @pytest.mark.parametrize(
        "set_runtime_env_plugins",
        [
            json.dumps([{"class": URI_CACHING_TEST_PLUGIN_CLASS_PATH}]),
        ],
        indirect=True,
    )
    def test_uri_caching(
        self, set_runtime_env_plugins, start_cluster, uri_cache_size_100_gb
    ):
        """Check URIs are cached up to the 100 GB limit and evicted beyond it.

        Drives four tasks with different fake URI sizes and asserts, via the
        plugin's scratch file, exactly when create()/modify_context()/
        delete_uri() fire on the agent side.
        """
        cluster, address = start_cluster

        ray.init(address=address)

        def reinit():
            # Drop the driver's runtime_env references so cached URIs become
            # eligible for eviction, then reconnect.
            ray.shutdown()
            # TODO(architkulkarni): Currently, reinit the driver will generate the same
            # job id. And if we reinit immediately after shutdown, raylet may
            # process new job started before old job finished in some cases. This
            # inconsistency could disorder the URI reference and delete a valid
            # runtime env. We sleep here to walk around this issue.
            time.sleep(5)
            ray.init(address=address)

        @ray.remote
        def f():
            return True

        # Run a task to trigger runtime_env creation.
        ref1 = f.options(
            runtime_env={
                URI_CACHING_TEST_PLUGIN_NAME: {
                    "uri": "file:///tmp/test_uri_1",
                    "size_bytes": gb_to_bytes(50),
                }
            }
        ).remote()
        ray.get(ref1)
        # Check that the URI was "created on disk".
        print(get_plugin_usage_data())
        wait_for_condition(
            lambda: get_plugin_usage_data()
            == {
                "uris_to_sizes": {"file:///tmp/test_uri_1": gb_to_bytes(50)},
                "modify_context_call_count": 1,
                "create_call_count": 1,
            }
        )

        # Shutdown ray to stop the worker and remove the runtime_env reference.
        reinit()

        # Run a task with a different runtime env.
        ref2 = f.options(
            runtime_env={
                URI_CACHING_TEST_PLUGIN_NAME: {
                    "uri": "file:///tmp/test_uri_2",
                    "size_bytes": gb_to_bytes(51),
                }
            }
        ).remote()
        ray.get(ref2)
        # This should delete the old URI and create a new one, because 50 + 51 > 100
        # and the cache size limit is 100.
        wait_for_condition(
            lambda: get_plugin_usage_data()
            == {
                "uris_to_sizes": {"file:///tmp/test_uri_2": gb_to_bytes(51)},
                "modify_context_call_count": 2,
                "create_call_count": 2,
            }
        )

        reinit()

        # Run a task with the cached runtime env, to check that the runtime env is not
        # created anew.
        ref3 = f.options(
            runtime_env={
                URI_CACHING_TEST_PLUGIN_NAME: {
                    "uri": "file:///tmp/test_uri_2",
                    "size_bytes": gb_to_bytes(51),
                }
            }
        ).remote()
        ray.get(ref3)
        # modify_context should still be called even if create() is not called.
        # Example: for a "conda" plugin, even if the conda env is already created
        # and cached, we still need to call modify_context to add "conda activate" to
        # the RuntimeEnvContext.command_prefix for the worker.
        wait_for_condition(
            lambda: get_plugin_usage_data()
            == {
                "uris_to_sizes": {"file:///tmp/test_uri_2": gb_to_bytes(51)},
                "modify_context_call_count": 3,
                "create_call_count": 2,
            }
        )

        reinit()

        # Run a task with a new runtime env
        ref4 = f.options(
            runtime_env={
                URI_CACHING_TEST_PLUGIN_NAME: {
                    "uri": "file:///tmp/test_uri_3",
                    "size_bytes": gb_to_bytes(10),
                }
            }
        ).remote()
        ray.get(ref4)
        # The last two URIs should still be present in the cache, because 51 + 10 < 100.
        wait_for_condition(
            lambda: get_plugin_usage_data()
            == {
                "uris_to_sizes": {
                    "file:///tmp/test_uri_2": gb_to_bytes(51),
                    "file:///tmp/test_uri_3": gb_to_bytes(10),
                },
                "modify_context_call_count": 4,
                "create_call_count": 3,
            }
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue