[runtime env] Add URI support for plugins (#26746)

Archit Kulkarni 2022-07-27 09:28:19 -07:00 committed by GitHub
parent 274ea2a1ba
commit 60f33777a2
11 changed files with 448 additions and 251 deletions
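For orientation: this change lets third-party runtime_env plugins declare URIs so that their environments are cached and garbage-collected the same way as the built-in working_dir/py_modules/conda/pip plugins. A rough usage sketch, assuming the environment variable behind RAY_RUNTIME_ENV_PLUGINS_ENV_VAR is named RAY_RUNTIME_ENV_PLUGINS and that my_pkg.my_plugin.MyUriPlugin is a hypothetical plugin (neither is part of this diff):

# Hypothetical end-to-end usage of a URI-aware third-party plugin.
import json
import os

# Register the plugin class; this must be set before the Ray agent starts.
os.environ["RAY_RUNTIME_ENV_PLUGINS"] = json.dumps(
    [{"class": "my_pkg.my_plugin.MyUriPlugin", "priority": 10}]
)

import ray

ray.init()

@ray.remote
def f():
    return True

# The plugin's get_uris() return value drives per-URI caching in the agent;
# the "uri" field and its value here are made up.
ray.get(
    f.options(runtime_env={"MyUriPlugin": {"uri": "s3://bucket/env.zip"}}).remote()
)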

View file

@@ -6,7 +6,6 @@ import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, List, Set, Tuple
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
@@ -20,10 +19,12 @@ from ray._private.runtime_env.container import ContainerManager
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.java_jars import JavaJarsPlugin
from ray._private.runtime_env.pip import PipPlugin
from ray._private.runtime_env.plugin import PluginCacheManager
from ray._private.runtime_env.plugin import (
RuntimeEnvPlugin,
create_for_plugin_if_needed,
)
from ray._private.runtime_env.plugin import RuntimeEnvPluginManager
from ray._private.runtime_env.py_modules import PyModulesPlugin
from ray._private.runtime_env.uri_cache import URICache
from ray._private.runtime_env.working_dir import WorkingDirPlugin
from ray.core.generated import (
agent_manager_pb2,
@@ -53,12 +54,8 @@ class CreatedEnvResult:
creation_time_ms: int
class UriType(Enum):
WORKING_DIR = "working_dir"
PY_MODULES = "py_modules"
PIP = "pip"
CONDA = "conda"
JAVA_JARS = "java_jars"
# e.g., "working_dir"
UriType = str
class ReferenceTable:
@@ -198,48 +195,37 @@ class RuntimeEnvAgent(
# TODO(architkulkarni): "base plugins" and third-party plugins should all go
# through the same code path. We should never need to refer to
# self._xxx_plugin, we should just iterate through self._plugins.
self._base_plugins = [
self._base_plugins: List[RuntimeEnvPlugin] = [
self._working_dir_plugin,
self._pip_plugin,
self._conda_plugin,
self._py_modules_plugin,
self._java_jars_plugin,
]
self._uri_caches = {}
self._base_plugin_cache_managers = {}
self._plugin_manager = RuntimeEnvPluginManager()
for plugin in self._base_plugins:
# Set the max size for the cache. Defaults to 10 GB.
cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper()
cache_size_bytes = int(
(1024 ** 3) * float(os.environ.get(cache_size_env_var, 10))
)
self._uri_caches[plugin.name] = URICache(
plugin.delete_uri, cache_size_bytes
)
self._base_plugin_cache_managers[plugin.name] = PluginCacheManager(
plugin, self._uri_caches[plugin.name]
)
self._plugin_manager.add_plugin(plugin)
self._reference_table = ReferenceTable(
self.uris_parser,
self.unused_uris_processor,
self.unused_runtime_env_processor,
)
self._runtime_env_plugin_manager = RuntimeEnvPluginManager()
self._logger = default_logger
def uris_parser(self, runtime_env):
result = list()
for plugin in self._base_plugins:
for name, plugin_setup_context in self._plugin_manager.plugins.items():
plugin = plugin_setup_context.class_instance
uris = plugin.get_uris(runtime_env)
for uri in uris:
result.append((uri, UriType(plugin.name)))
result.append((uri, UriType(name)))
return result
def unused_uris_processor(self, unused_uris: List[Tuple[str, UriType]]) -> None:
for uri, uri_type in unused_uris:
self._uri_caches[uri_type.value].mark_unused(uri)
self._plugin_manager.plugins[str(uri_type)].uri_cache.mark_unused(uri)
def unused_runtime_env_processor(self, unused_runtime_env: str) -> None:
def delete_runtime_env():
@@ -275,7 +261,9 @@ class RuntimeEnvAgent(
)
async def _setup_runtime_env(
runtime_env, serialized_runtime_env, serialized_allocated_resource_instances
runtime_env: RuntimeEnv,
serialized_runtime_env,
serialized_allocated_resource_instances,
):
allocated_resource: dict = json.loads(
serialized_allocated_resource_instances or "{}"
@@ -290,39 +278,25 @@ class RuntimeEnvAgent(
runtime_env, context, logger=per_job_logger
)
for manager in self._base_plugin_cache_managers.values():
await manager.create_if_needed(
runtime_env, context, logger=per_job_logger
# Warn about unrecognized fields in the runtime env.
for name, _ in runtime_env.plugins():
if name not in self._plugin_manager.plugins:
per_job_logger.warning(
f"runtime_env field {name} is not recognized by "
"Ray and will be ignored. In the future, unrecognized "
"fields in the runtime_env will raise an exception."
)
"""Run setup for each plugin unless it has already been cached."""
for (
plugin_setup_context
) in self._plugin_manager.sorted_plugin_setup_contexts():
plugin = plugin_setup_context.class_instance
uri_cache = plugin_setup_context.uri_cache
await create_for_plugin_if_needed(
runtime_env, plugin, uri_cache, context, per_job_logger
)
def setup_plugins():
# Run setup function from all the plugins
if runtime_env.plugins():
for (
setup_context
) in self._runtime_env_plugin_manager.sorted_plugin_setup_contexts(
runtime_env.plugins()
):
per_job_logger.debug(
f"Setting up runtime env plugin {setup_context.name}"
)
# TODO(architkulkarni): implement uri support
setup_context.class_instance.validate(runtime_env)
setup_context.class_instance.create(
"uri not implemented", setup_context.config, context
)
setup_context.class_instance.modify_context(
"uri not implemented",
setup_context.config,
context,
per_job_logger,
)
loop = asyncio.get_event_loop()
# Plugins setup method is sync process, running in other threads
# is to avoid blocking asyncio loop
await loop.run_in_executor(None, setup_plugins)
return context
async def _create_runtime_env_with_retry(
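One operational note on the refactor above: the per-plugin URICache sizing that used to be built here now lives in RuntimeEnvPluginManager (see the plugin.py changes below), but the knob is unchanged. It still defaults to 10 GB per plugin and can be overridden per plugin through an environment variable derived from the plugin name. A hedged sketch, using the working_dir plugin and an arbitrary 2 GB limit:

# Sketch: capping the URI cache for a single plugin. The agent builds the
# variable name as f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper().
import os

os.environ["RAY_RUNTIME_ENV_WORKING_DIR_CACHE_SIZE_GB"] = "2"  # set before Ray starts
# Resulting limit: int((1024 ** 3) * float("2")) bytes, i.e. 2 GiB, for the
# working_dir plugin's URICache; other plugins keep the 10 GiB default.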

View file

@@ -12,10 +12,10 @@ def test_reference_table():
def uris_parser(runtime_env) -> Tuple[str, UriType]:
result = list()
result.append((runtime_env.working_dir(), UriType.WORKING_DIR))
result.append((runtime_env.working_dir(), "working_dir"))
py_module_uris = runtime_env.py_modules()
for uri in py_module_uris:
result.append((uri, UriType.PY_MODULES))
result.append((uri, "py_modules"))
return result
def unused_uris_processor(unused_uris: List[Tuple[str, UriType]]) -> None:
@@ -57,8 +57,8 @@ def test_reference_table():
)
# Remove runtime env 1
expected_unused_uris.append(("s3://working_dir_1.zip", UriType.WORKING_DIR))
expected_unused_uris.append(("s3://py_module_B.zip", UriType.PY_MODULES))
expected_unused_uris.append(("s3://working_dir_1.zip", "working_dir"))
expected_unused_uris.append(("s3://py_module_B.zip", "py_modules"))
expected_unused_runtime_env = runtime_env_1.serialize()
reference_table.decrease_reference(
runtime_env_1, runtime_env_1.serialize(), "raylet"
@@ -67,9 +67,9 @@ def test_reference_table():
assert not expected_unused_runtime_env
# Remove runtime env 2
expected_unused_uris.append(("s3://working_dir_2.zip", UriType.WORKING_DIR))
expected_unused_uris.append(("s3://py_module_A.zip", UriType.PY_MODULES))
expected_unused_uris.append(("s3://py_module_C.zip", UriType.PY_MODULES))
expected_unused_uris.append(("s3://working_dir_2.zip", "working_dir"))
expected_unused_uris.append(("s3://py_module_A.zip", "py_modules"))
expected_unused_uris.append(("s3://py_module_C.zip", "py_modules"))
expected_unused_runtime_env = runtime_env_2.serialize()
reference_table.decrease_reference(
runtime_env_2, runtime_env_2.serialize(), "raylet"

View file

@@ -315,6 +315,13 @@ class CondaPlugin(RuntimeEnvPlugin):
context: RuntimeEnvContext,
logger: logging.Logger = default_logger,
) -> int:
if uri is None:
# The "conda" field is the name of an existing conda env, so no
# need to create one.
# TODO(architkulkarni): Try "conda activate" here to see if the
# env exists, and raise an exception if it doesn't.
return 0
# The create method is currently still a sync process; to avoid blocking
# the event loop, we need to run this function in another thread.
# TODO(Catch-Bull): Refactor method create into an async process, and
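The early return above covers the case where the "conda" field names a pre-existing conda environment rather than an environment spec, so there is no URI to create. A sketch of the two shapes of the field (values are illustrative, not from this diff):

# Case 1: "conda" is the name of an existing environment. No conda URI is
# produced, so create() is invoked with uri=None and returns 0 immediately.
runtime_env_existing = {"conda": "my-preinstalled-env"}  # hypothetical env name

# Case 2: "conda" is an environment spec. Ray hashes the spec into a
# conda:// URI, and create() builds and caches the environment for that URI.
runtime_env_spec = {"conda": {"dependencies": ["pip", {"pip": ["requests"]}]}}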

View file

@@ -457,7 +457,7 @@ class PipPlugin(RuntimeEnvPlugin):
uris: List[str],
runtime_env: "RuntimeEnv", # noqa: F821
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
logger: logging.Logger = default_logger,
):
if not runtime_env.has_pip():
return

View file

@@ -2,7 +2,7 @@ import logging
import os
import json
from abc import ABC
from typing import List, Dict, Tuple, Any
from typing import List, Dict, Optional, Any, Type
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.uri_cache import URICache
@@ -28,28 +28,25 @@ class RuntimeEnvPlugin(ABC):
priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY
@staticmethod
def validate(runtime_env_dict: dict) -> str:
"""Validate user entry and returns a URI uniquely describing resource.
This method will be called at ``f.options(runtime_env=...)`` or
``ray.init(runtime_env=...)`` time and it should check the runtime env
dictionary for any errors. For example, it can raise "TypeError:
expected string for "conda" field".
def validate(runtime_env_dict: dict) -> None:
"""Validate user entry for this plugin.
Args:
runtime_env_dict(dict): the entire dictionary passed in by user.
Returns:
uri(str): a URI uniquely describing this resource (e.g., a hash of
the conda spec).
runtime_env_dict: the user-supplied runtime environment dict.
Raises:
ValueError: if the validation fails.
"""
raise NotImplementedError()
pass
def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821
return None
return []
def create(
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
async def create(
self,
uri: Optional[str],
runtime_env: "RuntimeEnv", # noqa: F821
context: RuntimeEnvContext,
logger: logging.Logger,
) -> float:
"""Create and install the runtime environment.
@@ -57,9 +54,10 @@ class RuntimeEnvPlugin(ABC):
used as a caching mechanism.
Args:
uri(str): a URI uniquely describing this resource.
runtime_env(RuntimeEnv): the runtime env protobuf.
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
uri: A URI uniquely describing this resource.
runtime_env: The RuntimeEnv object.
context: auxiliary information supplied by Ray.
logger: A logger to log messages during the context modification.
Returns:
the disk space taken up by this plugin installation for this
@@ -81,9 +79,10 @@ class RuntimeEnvPlugin(ABC):
startup, or add new environment variables.
Args:
uris(List[str]): a URIs used by this resource.
runtime_env(RuntimeEnv): the runtime env protobuf.
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
uris: The URIs used by this resource.
runtime_env: The RuntimeEnv object.
context: Auxiliary information supplied by Ray.
logger: A logger to log messages during the context modification.
"""
return
@@ -91,8 +90,7 @@ class RuntimeEnvPlugin(ABC):
"""Delete the the runtime environment given uri.
Args:
uri(str): a URI uniquely describing this resource.
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
uri: a URI uniquely describing this resource.
Returns:
the amount of space reclaimed by the deletion.
@@ -101,29 +99,59 @@
class PluginSetupContext:
def __init__(self, name: str, config: Any, class_instance: object):
def __init__(
self,
name: str,
class_instance: RuntimeEnvPlugin,
priority: int,
uri_cache: URICache,
):
self.name = name
self.config = config
self.class_instance = class_instance
self.priority = priority
self.uri_cache = uri_cache
class RuntimeEnvPluginManager:
"""This manager is used to load plugins in runtime env agent."""
class Context:
def __init__(self, class_instance, priority):
self.class_instance = class_instance
self.priority = priority
def __init__(self):
self.plugins: Dict[str, RuntimeEnvPluginManager.Context] = {}
self.plugins: Dict[str, PluginSetupContext] = {}
plugin_config_str = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
if plugin_config_str:
plugin_configs = json.loads(plugin_config_str)
self.load_plugins(plugin_configs)
def load_plugins(self, plugin_configs: List[Dict]):
"""Load runtime env plugins"""
def validate_plugin_class(self, plugin_class: Type[RuntimeEnvPlugin]) -> None:
if not issubclass(plugin_class, RuntimeEnvPlugin):
raise RuntimeError(
f"Invalid runtime env plugin class {plugin_class}. "
"The plugin class must inherit "
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
)
if not plugin_class.name:
raise RuntimeError(f"No valid name in runtime env plugin {plugin_class}.")
if plugin_class.name in self.plugins:
raise RuntimeError(
f"The name of runtime env plugin {plugin_class} conflicts "
f"with {self.plugins[plugin_class.name]}.",
)
def validate_priority(self, priority: Any) -> None:
if (
not isinstance(priority, int)
or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY
or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY
):
raise RuntimeError(
f"Invalid runtime env priority {priority}, "
"it should be an integer between "
f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} "
f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}."
)
def load_plugins(self, plugin_configs: List[Dict]) -> None:
"""Load runtime env plugins and create URI caches for them."""
for plugin_config in plugin_configs:
if (
not isinstance(plugin_config, dict)
@@ -135,21 +163,7 @@ class RuntimeEnvPluginManager:
f"{RAY_RUNTIME_ENV_CLASS_FIELD_NAME} field."
)
plugin_class = import_attr(plugin_config[RAY_RUNTIME_ENV_CLASS_FIELD_NAME])
if not issubclass(plugin_class, RuntimeEnvPlugin):
raise RuntimeError(
f"Invalid runtime env plugin class {plugin_class}. "
"The plugin class must inherit "
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
)
if not plugin_class.name:
raise RuntimeError(
f"No valid name in runtime env plugin {plugin_class}."
)
if plugin_class.name in self.plugins:
raise RuntimeError(
f"The name of runtime env plugin {plugin_class} conflicts "
f"with {self.plugins[plugin_class.name]}.",
)
self.validate_plugin_class(plugin_class)
# The priority should be an integer between 0 and 100.
# The default priority is 10. A smaller number indicates a
@@ -158,73 +172,86 @@ class RuntimeEnvPluginManager:
priority = plugin_config[RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME]
else:
priority = plugin_class.priority
if (
not isinstance(priority, int)
or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY
or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY
):
raise RuntimeError(
f"Invalid runtime env priority {priority}, "
"it should be an integer between "
f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} "
f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}."
)
self.validate_priority(priority)
self.plugins[plugin_class.name] = RuntimeEnvPluginManager.Context(
plugin_class(), priority
class_instance = plugin_class()
self.plugins[plugin_class.name] = PluginSetupContext(
plugin_class.name,
class_instance,
priority,
self.create_uri_cache_for_plugin(class_instance),
)
def sorted_plugin_setup_contexts(
self, inputs: List[Tuple[str, Any]]
) -> List[PluginSetupContext]:
used_plugins = []
for name, config in inputs:
if name not in self.plugins:
default_logger.error(
f"runtime_env field {name} is not recognized by "
"Ray and will be ignored. In the future, unrecognized "
"fields in the runtime_env will raise an exception."
)
continue
used_plugins.append(
(
name,
config,
self.plugins[name].class_instance,
self.plugins[name].priority,
)
)
sort_used_plugins = sorted(used_plugins, key=lambda x: x[3], reverse=False)
return [
PluginSetupContext(name, config, class_instance)
for name, config, class_instance, _ in sort_used_plugins
]
def add_plugin(self, plugin: RuntimeEnvPlugin) -> None:
"""Add a plugin to the manager and create a URI cache for it.
Args:
plugin: The class instance of the plugin.
"""
plugin_class = type(plugin)
self.validate_plugin_class(plugin_class)
self.validate_priority(plugin_class.priority)
self.plugins[plugin_class.name] = PluginSetupContext(
plugin_class.name,
plugin,
plugin_class.priority,
self.create_uri_cache_for_plugin(plugin),
)
def create_uri_cache_for_plugin(self, plugin: RuntimeEnvPlugin) -> URICache:
"""Create a URI cache for a plugin.
Args:
plugin: The plugin instance to create a URI cache for.
Returns:
The created URI cache for the plugin.
"""
# Set the max size for the cache. Defaults to 10 GB.
cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper()
cache_size_bytes = int(
(1024 ** 3) * float(os.environ.get(cache_size_env_var, 10))
)
return URICache(plugin.delete_uri, cache_size_bytes)
def sorted_plugin_setup_contexts(self) -> List[PluginSetupContext]:
"""Get the sorted plugin setup contexts, sorted by increasing priority.
Returns:
The sorted plugin setup contexts.
"""
return sorted(self.plugins.values(), key=lambda x: x.priority)
@DeveloperAPI
class PluginCacheManager:
"""Manages a plugin and a cache for its local resources."""
async def create_for_plugin_if_needed(
runtime_env,
plugin: RuntimeEnvPlugin,
uri_cache: URICache,
context: RuntimeEnvContext,
logger: logging.Logger = default_logger,
):
"""Set up the environment using the plugin if not already set up and cached."""
if plugin.name not in runtime_env or runtime_env[plugin.name] is None:
return
def __init__(self, plugin: RuntimeEnvPlugin, uri_cache: URICache):
self._plugin = plugin
self._uri_cache = uri_cache
plugin.validate(runtime_env)
async def create_if_needed(
self,
runtime_env: "RuntimeEnv", # noqa: F821
context: RuntimeEnvContext,
logger: logging.Logger = default_logger,
):
uris = self._plugin.get_uris(runtime_env)
for uri in uris:
if uri not in self._uri_cache:
logger.debug(f"Cache miss for URI {uri}.")
size_bytes = await self._plugin.create(
uri, runtime_env, context, logger=logger
)
self._uri_cache.add(uri, size_bytes, logger=logger)
else:
logger.debug(f"Cache hit for URI {uri}.")
self._uri_cache.mark_used(uri, logger=logger)
uris = plugin.get_uris(runtime_env)
self._plugin.modify_context(uris, runtime_env, context)
if not uris:
logger.debug(
f"No URIs for runtime env plugin {plugin.name}; "
"create always without checking the cache."
)
await plugin.create(None, runtime_env, context, logger=logger)
for uri in uris:
if uri not in uri_cache:
logger.debug(f"Cache miss for URI {uri}.")
size_bytes = await plugin.create(uri, runtime_env, context, logger=logger)
uri_cache.add(uri, size_bytes, logger=logger)
else:
logger.debug(f"Cache hit for URI {uri}.")
uri_cache.mark_used(uri, logger=logger)
plugin.modify_context(uris, runtime_env, context, logger)
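Putting the reworked interface together: a plugin declares its URIs via get_uris, create_for_plugin_if_needed consults the plugin's URICache before awaiting create, and modify_context runs on every setup even on a cache hit. A minimal sketch of a plugin written against this interface (the class, its name, and the URIs are hypothetical and not part of this diff):

# Hypothetical plugin exercising the async, URI-aware interface defined above.
import logging
from typing import List, Optional

from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.plugin import RuntimeEnvPlugin


class MyUriPlugin(RuntimeEnvPlugin):
    name = "MyUriPlugin"  # the key looked up in the runtime_env dict
    priority = 10         # must lie within the min/max priority constants

    def get_uris(self, runtime_env) -> List[str]:
        # One cacheable URI per environment. Returning [] means
        # create_for_plugin_if_needed calls create(None, ...) every time.
        return [runtime_env[self.name]["uri"]]

    async def create(
        self,
        uri: Optional[str],
        runtime_env,
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> float:
        # Download or build the resource for `uri` and return its size in
        # bytes so the URICache can evict by total size.
        return 0

    def modify_context(
        self,
        uris: List[str],
        runtime_env,
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> None:
        # Runs on cache hits too, e.g. to export env vars or prepend
        # activation commands for the worker.
        context.env_vars["MY_PLUGIN_URI"] = uris[0] if uris else ""

    def delete_uri(self, uri: str, logger: logging.Logger) -> int:
        # Called by the URICache on eviction; return the reclaimed bytes.
        return 0

Returning an empty list from get_uris opts a plugin out of caching entirely, which matches the "No URIs" branch of create_for_plugin_if_needed above.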

View file

@@ -134,10 +134,10 @@ class WorkingDirPlugin(RuntimeEnvPlugin):
async def create(
self,
uri: str,
uri: Optional[str],
runtime_env: dict,
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
logger: logging.Logger = default_logger,
) -> int:
local_dir = await download_and_unpack_package(
uri, self._resources_dir, self._gcs_aio_client, logger=logger
@@ -145,7 +145,11 @@ class WorkingDirPlugin(RuntimeEnvPlugin):
return get_directory_size_bytes(local_dir)
def modify_context(
self, uris: List[str], runtime_env_dict: Dict, context: RuntimeEnvContext
self,
uris: List[str],
runtime_env_dict: Dict,
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
):
if not uris:
return

View file

@@ -1445,12 +1445,10 @@ def get_runtime_env_info(
proto_runtime_env_info = ProtoRuntimeEnvInfo()
if runtime_env.get_working_dir_uri():
proto_runtime_env_info.uris.working_dir_uri = runtime_env.get_working_dir_uri()
if len(runtime_env.get_py_modules_uris()) > 0:
proto_runtime_env_info.uris.py_modules_uris[
:
] = runtime_env.get_py_modules_uris()
if runtime_env.working_dir_uri():
proto_runtime_env_info.uris.working_dir_uri = runtime_env.working_dir_uri()
if len(runtime_env.py_modules_uris()) > 0:
proto_runtime_env_info.uris.py_modules_uris[:] = runtime_env.py_modules_uris()
# TODO(Catch-Bull): overload `__setitem__` for `RuntimeEnv`, change the
# runtime_env of all internal code from dict to RuntimeEnv.

View file

@@ -2044,7 +2044,7 @@ def connect(
)
# In client mode, if we use runtime envs with "working_dir", then
# it'll be handled automatically. Otherwise, add the current dir.
if not job_config.client_job and not job_config.runtime_env_has_uris():
if not job_config.client_job and not job_config.runtime_env_has_working_dir():
current_directory = os.path.abspath(os.path.curdir)
worker.run_function_on_all_workers(
lambda worker_info: sys.path.insert(1, current_directory)

View file

@@ -126,9 +126,8 @@ class JobConfig:
return self._cached_pb
def runtime_env_has_uris(self):
"""Whether there are uris in runtime env or not"""
return self._validate_runtime_env().has_uris()
def runtime_env_has_working_dir(self):
return self._validate_runtime_env().has_working_dir()
def get_serialized_runtime_env(self) -> str:
"""Return the JSON-serialized parsed runtime env dict"""

View file

@@ -345,30 +345,6 @@ class RuntimeEnv(dict):
if all(val is None for val in self.values()):
self.clear()
def get_uris(self) -> List[str]:
# TODO(architkulkarni): this should programmatically be extended with
# URIs from all plugins.
plugin_uris = []
if "working_dir" in self:
plugin_uris.append(self["working_dir"])
if "py_modules" in self:
for uri in self["py_modules"]:
plugin_uris.append(uri)
if "conda" in self:
uri = get_conda_uri(self)
if uri is not None:
plugin_uris.append(uri)
if "pip" in self:
uri = get_pip_uri(self)
if uri is not None:
plugin_uris.append(uri)
def get_working_dir_uri(self) -> str:
return self.get("working_dir", None)
def get_py_modules_uris(self) -> List[str]:
return self.get("py_modules", [])
def __setitem__(self, key: str, value: Any) -> None:
if is_dataclass(value):
jsonable_type = asdict(value)
@@ -416,16 +392,8 @@ class RuntimeEnv(dict):
return runtime_env_dict
def has_uris(self) -> bool:
if (
self.working_dir_uri()
or self.py_modules_uris()
or self.conda_uri()
or self.pip_uri()
or self.plugin_uris()
):
return True
return False
def has_working_dir(self) -> bool:
return self.get("working_dir") is not None
def working_dir_uri(self) -> Optional[str]:
return self.get("working_dir")
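With get_uris and has_uris removed, URI discovery lives entirely in the plugins, and RuntimeEnv keeps only thin per-field accessors. A quick, hedged illustration of the accessors that remain (URLs are made up):

# Illustrative use of the remaining accessors; the URLs are hypothetical.
from ray.runtime_env import RuntimeEnv

env = RuntimeEnv(
    working_dir="https://example.com/wd.zip",
    py_modules=["https://example.com/mod.zip"],
)
assert env.has_working_dir()
assert env.working_dir_uri() == "https://example.com/wd.zip"
assert env.py_modules_uris() == ["https://example.com/mod.zip"]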

View file

@@ -1,9 +1,12 @@
import asyncio
import logging
import os
from pathlib import Path
import time
from unittest import mock
import tempfile
import json
from time import sleep
from typing import List
import pytest
@@ -13,6 +16,7 @@ from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.test_utils import enable_external_redis, wait_for_condition
from ray.exceptions import RuntimeEnvSetupError
from ray.runtime_env.runtime_env import RuntimeEnv
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin"
MY_PLUGIN_NAME = "MyPlugin"
@@ -23,8 +27,8 @@ class MyPlugin(RuntimeEnvPlugin):
env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY"
@staticmethod
def validate(runtime_env_dict: dict) -> str:
value = runtime_env_dict[MY_PLUGIN_NAME]
def validate(runtime_env: RuntimeEnv) -> str:
value = runtime_env[MY_PLUGIN_NAME]
if value == "fail":
raise ValueError("not allowed")
return value
@@ -32,10 +36,11 @@ class MyPlugin(RuntimeEnvPlugin):
def modify_context(
self,
uris: List[str],
plugin_config_dict: dict,
runtime_env: RuntimeEnv,
ctx: RuntimeEnvContext,
logger: logging.Logger,
) -> None:
plugin_config_dict = runtime_env[MY_PLUGIN_NAME]
ctx.env_vars[MyPlugin.env_key] = str(plugin_config_dict["env_value"])
ctx.command_prefix.append(
f"echo {plugin_config_dict['tmp_content']} > "
@@ -103,14 +108,20 @@ class MyPluginForHang(RuntimeEnvPlugin):
def validate(runtime_env_dict: dict) -> str:
return "True"
def create(self, uri: str, runtime_env: dict, ctx: RuntimeEnvContext) -> float:
async def create(
self,
uri: str,
runtime_env: dict,
ctx: RuntimeEnvContext,
logger: logging.Logger,
) -> float:
global my_plugin_setup_times
my_plugin_setup_times += 1
# first setup
if my_plugin_setup_times == 1:
# sleep forever
sleep(3600)
await asyncio.sleep(3600)
def modify_context(
self,
@@ -164,10 +175,6 @@ DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin"
DUMMY_PLUGIN_NAME = "DummyPlugin"
HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin"
HANG_PLUGIN_NAME = "HangPlugin"
DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = (
"ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin"
)
DISABLE_TIMEOUT_PLUGIN_NAME = "test_plugin_timeout"
class DummyPlugin(RuntimeEnvPlugin):
@@ -181,27 +188,21 @@ class DummyPlugin(RuntimeEnvPlugin):
class HangPlugin(DummyPlugin):
name = HANG_PLUGIN_NAME
def create(
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
async def create(
self,
uri: str,
runtime_env: "RuntimeEnv",
ctx: RuntimeEnvContext,
logger: logging.Logger, # noqa: F821
) -> float:
sleep(3600)
class DiasbleTimeoutPlugin(DummyPlugin):
name = DISABLE_TIMEOUT_PLUGIN_NAME
def create(
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
) -> float:
sleep(10)
await asyncio.sleep(3600)
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
'[{"class":"' + DUMMY_PLUGIN_CLASS_PATH + '"},'
'{"class":"' + HANG_PLUGIN_CLASS_PATH + '"},'
'{"class":"' + DISABLE_TIMEOUT_PLUGIN_CLASS_PATH + '"}]',
'{"class":"' + HANG_PLUGIN_CLASS_PATH + '"}]',
],
indirect=True,
)
@@ -215,7 +216,7 @@ def test_plugin_timeout(set_runtime_env_plugins, start_cluster):
f.options(
runtime_env={
HANG_PLUGIN_NAME: {"name": "f1"},
"config": {"setup_timeout_seconds": 10},
"config": {"setup_timeout_seconds": 1},
}
).remote(),
f.options(runtime_env={DUMMY_PLUGIN_NAME: {"name": "f2"}}).remote(),
@@ -397,6 +398,225 @@ def test_unexpected_field_warning(shutdown_only):
wait_for_condition(lambda: "unexpected_field is not recognized" in f.read())
URI_CACHING_TEST_PLUGIN_CLASS_PATH = (
"ray.tests.test_runtime_env_plugin.UriCachingTestPlugin"
)
URI_CACHING_TEST_PLUGIN_NAME = "UriCachingTestPlugin"
URI_CACHING_TEST_DIR = Path(tempfile.gettempdir()) / "runtime_env_uri_caching_test"
uri_caching_test_file_path = URI_CACHING_TEST_DIR / "uri_caching_test_file.json"
URI_CACHING_TEST_DIR.mkdir(parents=True, exist_ok=True)
uri_caching_test_file_path.write_text("{}")
def get_plugin_usage_data():
with open(uri_caching_test_file_path, "r") as f:
data = json.loads(f.read())
return data
class UriCachingTestPlugin(RuntimeEnvPlugin):
"""A plugin that fakes taking up local disk space when creating its environment.
This plugin is used to test that the URI caching is working correctly.
Example:
runtime_env = {"UriCachingTestPlugin": {"uri": "file:///a", "size_bytes": 10}}
"""
name = URI_CACHING_TEST_PLUGIN_NAME
def __init__(self):
# Keeps track of the "disk space" each URI takes up for the
# UriCachingTestPlugin.
self.uris_to_sizes = {}
self.modify_context_call_count = 0
self.create_call_count = 0
def write_plugin_usage_data(self) -> None:
with open(uri_caching_test_file_path, "w") as f:
data = {
"uris_to_sizes": self.uris_to_sizes,
"modify_context_call_count": self.modify_context_call_count,
"create_call_count": self.create_call_count,
}
f.write(json.dumps(data))
def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F811
return [runtime_env[self.name]["uri"]]
async def create(
self,
uri,
runtime_env: "RuntimeEnv", # noqa: F821
context: RuntimeEnvContext,
logger: logging.Logger,
) -> float:
self.create_call_count += 1
created_size_bytes = runtime_env[self.name]["size_bytes"]
self.uris_to_sizes[uri] = created_size_bytes
self.write_plugin_usage_data()
return created_size_bytes
def modify_context(
self,
uris: List[str],
runtime_env: "RuntimeEnv",
context: RuntimeEnvContext,
logger: logging.Logger,
) -> None:
self.modify_context_call_count += 1
self.write_plugin_usage_data()
def delete_uri(self, uri: str, logger: logging.Logger) -> int:
size = self.uris_to_sizes.pop(uri)
self.write_plugin_usage_data()
return size
# Set scope to "class" to force this to run before start_cluster, whose scope
# is "function". We need these env vars to be set before Ray is started.
@pytest.fixture(scope="class")
def uri_cache_size_100_gb():
var = f"RAY_RUNTIME_ENV_{URI_CACHING_TEST_PLUGIN_NAME}_CACHE_SIZE_GB".upper()
with mock.patch.dict(
os.environ,
{
var: "100",
},
):
print("Set URI cache size for UriCachingTestPlugin to 100 GB")
yield
def gb_to_bytes(size_gb: int) -> int:
return size_gb * 1024 * 1024 * 1024
class TestGC:
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
json.dumps([{"class": URI_CACHING_TEST_PLUGIN_CLASS_PATH}]),
],
indirect=True,
)
def test_uri_caching(
self, set_runtime_env_plugins, start_cluster, uri_cache_size_100_gb
):
cluster, address = start_cluster
ray.init(address=address)
def reinit():
ray.shutdown()
# TODO(architkulkarni): Currently, reinitializing the driver generates the
# same job id, and if we reinit immediately after shutdown, the raylet may
# in some cases process the new job's start before the old job has finished.
# This inconsistency can disorder the URI reference counts and delete a valid
# runtime env. We sleep here to work around this issue.
time.sleep(5)
ray.init(address=address)
@ray.remote
def f():
return True
# Run a task to trigger runtime_env creation.
ref1 = f.options(
runtime_env={
URI_CACHING_TEST_PLUGIN_NAME: {
"uri": "file:///tmp/test_uri_1",
"size_bytes": gb_to_bytes(50),
}
}
).remote()
ray.get(ref1)
# Check that the URI was "created on disk".
print(get_plugin_usage_data())
wait_for_condition(
lambda: get_plugin_usage_data()
== {
"uris_to_sizes": {"file:///tmp/test_uri_1": gb_to_bytes(50)},
"modify_context_call_count": 1,
"create_call_count": 1,
}
)
# Shutdown ray to stop the worker and remove the runtime_env reference.
reinit()
# Run a task with a different runtime env.
ref2 = f.options(
runtime_env={
URI_CACHING_TEST_PLUGIN_NAME: {
"uri": "file:///tmp/test_uri_2",
"size_bytes": gb_to_bytes(51),
}
}
).remote()
ray.get(ref2)
# This should delete the old URI and create a new one, because 50 + 51 > 100
# and the cache size limit is 100.
wait_for_condition(
lambda: get_plugin_usage_data()
== {
"uris_to_sizes": {"file:///tmp/test_uri_2": gb_to_bytes(51)},
"modify_context_call_count": 2,
"create_call_count": 2,
}
)
reinit()
# Run a task with the cached runtime env, to check that the runtime env is not
# created anew.
ref3 = f.options(
runtime_env={
URI_CACHING_TEST_PLUGIN_NAME: {
"uri": "file:///tmp/test_uri_2",
"size_bytes": gb_to_bytes(51),
}
}
).remote()
ray.get(ref3)
# modify_context should still be called even if create() is not called.
# Example: for a "conda" plugin, even if the conda env is already created
# and cached, we still need to call modify_context to add "conda activate" to
# the RuntimeEnvContext.command_prefix for the worker.
wait_for_condition(
lambda: get_plugin_usage_data()
== {
"uris_to_sizes": {"file:///tmp/test_uri_2": gb_to_bytes(51)},
"modify_context_call_count": 3,
"create_call_count": 2,
}
)
reinit()
# Run a task with a new runtime env
ref4 = f.options(
runtime_env={
URI_CACHING_TEST_PLUGIN_NAME: {
"uri": "file:///tmp/test_uri_3",
"size_bytes": gb_to_bytes(10),
}
}
).remote()
ray.get(ref4)
# The last two URIs should still be present in the cache, because 51 + 10 < 100.
wait_for_condition(
lambda: get_plugin_usage_data()
== {
"uris_to_sizes": {
"file:///tmp/test_uri_2": gb_to_bytes(51),
"file:///tmp/test_uri_3": gb_to_bytes(10),
},
"modify_context_call_count": 4,
"create_call_count": 3,
}
)
if __name__ == "__main__":
import sys