From 60f33777a25675fef291c8ff1a90b29a7dfc0928 Mon Sep 17 00:00:00 2001 From: Archit Kulkarni Date: Wed, 27 Jul 2022 09:28:19 -0700 Subject: [PATCH] [runtime env] Add URI support for plugins (#26746) --- .../modules/runtime_env/runtime_env_agent.py | 92 +++--- .../tests/test_runtime_env_agent.py | 14 +- python/ray/_private/runtime_env/conda.py | 7 + python/ray/_private/runtime_env/pip.py | 2 +- python/ray/_private/runtime_env/plugin.py | 251 ++++++++-------- .../ray/_private/runtime_env/working_dir.py | 10 +- python/ray/_private/utils.py | 10 +- python/ray/_private/worker.py | 2 +- python/ray/job_config.py | 5 +- python/ray/runtime_env/runtime_env.py | 36 +-- python/ray/tests/test_runtime_env_plugin.py | 270 ++++++++++++++++-- 11 files changed, 448 insertions(+), 251 deletions(-) diff --git a/dashboard/modules/runtime_env/runtime_env_agent.py b/dashboard/modules/runtime_env/runtime_env_agent.py index 4241776aa..e1bec7926 100644 --- a/dashboard/modules/runtime_env/runtime_env_agent.py +++ b/dashboard/modules/runtime_env/runtime_env_agent.py @@ -6,7 +6,6 @@ import time import traceback from collections import defaultdict from dataclasses import dataclass -from enum import Enum from typing import Callable, Dict, List, Set, Tuple from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS @@ -20,10 +19,12 @@ from ray._private.runtime_env.container import ContainerManager from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.java_jars import JavaJarsPlugin from ray._private.runtime_env.pip import PipPlugin -from ray._private.runtime_env.plugin import PluginCacheManager +from ray._private.runtime_env.plugin import ( + RuntimeEnvPlugin, + create_for_plugin_if_needed, +) from ray._private.runtime_env.plugin import RuntimeEnvPluginManager from ray._private.runtime_env.py_modules import PyModulesPlugin -from ray._private.runtime_env.uri_cache import URICache from ray._private.runtime_env.working_dir import WorkingDirPlugin from ray.core.generated import ( agent_manager_pb2, @@ -53,12 +54,8 @@ class CreatedEnvResult: creation_time_ms: int -class UriType(Enum): - WORKING_DIR = "working_dir" - PY_MODULES = "py_modules" - PIP = "pip" - CONDA = "conda" - JAVA_JARS = "java_jars" +# e.g., "working_dir" +UriType = str class ReferenceTable: @@ -198,48 +195,37 @@ class RuntimeEnvAgent( # TODO(architkulkarni): "base plugins" and third-party plugins should all go # through the same code path. We should never need to refer to # self._xxx_plugin, we should just iterate through self._plugins. - self._base_plugins = [ + self._base_plugins: List[RuntimeEnvPlugin] = [ self._working_dir_plugin, self._pip_plugin, self._conda_plugin, self._py_modules_plugin, self._java_jars_plugin, ] - self._uri_caches = {} - self._base_plugin_cache_managers = {} + self._plugin_manager = RuntimeEnvPluginManager() for plugin in self._base_plugins: - # Set the max size for the cache. Defaults to 10 GB. - cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper() - cache_size_bytes = int( - (1024 ** 3) * float(os.environ.get(cache_size_env_var, 10)) - ) - self._uri_caches[plugin.name] = URICache( - plugin.delete_uri, cache_size_bytes - ) - self._base_plugin_cache_managers[plugin.name] = PluginCacheManager( - plugin, self._uri_caches[plugin.name] - ) + self._plugin_manager.add_plugin(plugin) self._reference_table = ReferenceTable( self.uris_parser, self.unused_uris_processor, self.unused_runtime_env_processor, ) - self._runtime_env_plugin_manager = RuntimeEnvPluginManager() self._logger = default_logger def uris_parser(self, runtime_env): result = list() - for plugin in self._base_plugins: + for name, plugin_setup_context in self._plugin_manager.plugins.items(): + plugin = plugin_setup_context.class_instance uris = plugin.get_uris(runtime_env) for uri in uris: - result.append((uri, UriType(plugin.name))) + result.append((uri, UriType(name))) return result def unused_uris_processor(self, unused_uris: List[Tuple[str, UriType]]) -> None: for uri, uri_type in unused_uris: - self._uri_caches[uri_type.value].mark_unused(uri) + self._plugin_manager.plugins[str(uri_type)].uri_cache.mark_unused(uri) def unused_runtime_env_processor(self, unused_runtime_env: str) -> None: def delete_runtime_env(): @@ -275,7 +261,9 @@ class RuntimeEnvAgent( ) async def _setup_runtime_env( - runtime_env, serialized_runtime_env, serialized_allocated_resource_instances + runtime_env: RuntimeEnv, + serialized_runtime_env, + serialized_allocated_resource_instances, ): allocated_resource: dict = json.loads( serialized_allocated_resource_instances or "{}" @@ -290,39 +278,25 @@ class RuntimeEnvAgent( runtime_env, context, logger=per_job_logger ) - for manager in self._base_plugin_cache_managers.values(): - await manager.create_if_needed( - runtime_env, context, logger=per_job_logger + # Warn about unrecognized fields in the runtime env. + for name, _ in runtime_env.plugins(): + if name not in self._plugin_manager.plugins: + per_job_logger.warning( + f"runtime_env field {name} is not recognized by " + "Ray and will be ignored. In the future, unrecognized " + "fields in the runtime_env will raise an exception." + ) + + """Run setup for each plugin unless it has already been cached.""" + for ( + plugin_setup_context + ) in self._plugin_manager.sorted_plugin_setup_contexts(): + plugin = plugin_setup_context.class_instance + uri_cache = plugin_setup_context.uri_cache + await create_for_plugin_if_needed( + runtime_env, plugin, uri_cache, context, per_job_logger ) - def setup_plugins(): - # Run setup function from all the plugins - if runtime_env.plugins(): - for ( - setup_context - ) in self._runtime_env_plugin_manager.sorted_plugin_setup_contexts( - runtime_env.plugins() - ): - per_job_logger.debug( - f"Setting up runtime env plugin {setup_context.name}" - ) - # TODO(architkulkarni): implement uri support - setup_context.class_instance.validate(runtime_env) - setup_context.class_instance.create( - "uri not implemented", setup_context.config, context - ) - setup_context.class_instance.modify_context( - "uri not implemented", - setup_context.config, - context, - per_job_logger, - ) - - loop = asyncio.get_event_loop() - # Plugins setup method is sync process, running in other threads - # is to avoid blocking asyncio loop - await loop.run_in_executor(None, setup_plugins) - return context async def _create_runtime_env_with_retry( diff --git a/dashboard/modules/runtime_env/tests/test_runtime_env_agent.py b/dashboard/modules/runtime_env/tests/test_runtime_env_agent.py index 15acdbb62..73de17774 100644 --- a/dashboard/modules/runtime_env/tests/test_runtime_env_agent.py +++ b/dashboard/modules/runtime_env/tests/test_runtime_env_agent.py @@ -12,10 +12,10 @@ def test_reference_table(): def uris_parser(runtime_env) -> Tuple[str, UriType]: result = list() - result.append((runtime_env.working_dir(), UriType.WORKING_DIR)) + result.append((runtime_env.working_dir(), "working_dir")) py_module_uris = runtime_env.py_modules() for uri in py_module_uris: - result.append((uri, UriType.PY_MODULES)) + result.append((uri, "py_modules")) return result def unused_uris_processor(unused_uris: List[Tuple[str, UriType]]) -> None: @@ -57,8 +57,8 @@ def test_reference_table(): ) # Remove runtime env 1 - expected_unused_uris.append(("s3://working_dir_1.zip", UriType.WORKING_DIR)) - expected_unused_uris.append(("s3://py_module_B.zip", UriType.PY_MODULES)) + expected_unused_uris.append(("s3://working_dir_1.zip", "working_dir")) + expected_unused_uris.append(("s3://py_module_B.zip", "py_modules")) expected_unused_runtime_env = runtime_env_1.serialize() reference_table.decrease_reference( runtime_env_1, runtime_env_1.serialize(), "raylet" @@ -67,9 +67,9 @@ def test_reference_table(): assert not expected_unused_runtime_env # Remove runtime env 2 - expected_unused_uris.append(("s3://working_dir_2.zip", UriType.WORKING_DIR)) - expected_unused_uris.append(("s3://py_module_A.zip", UriType.PY_MODULES)) - expected_unused_uris.append(("s3://py_module_C.zip", UriType.PY_MODULES)) + expected_unused_uris.append(("s3://working_dir_2.zip", "working_dir")) + expected_unused_uris.append(("s3://py_module_A.zip", "py_modules")) + expected_unused_uris.append(("s3://py_module_C.zip", "py_modules")) expected_unused_runtime_env = runtime_env_2.serialize() reference_table.decrease_reference( runtime_env_2, runtime_env_2.serialize(), "raylet" diff --git a/python/ray/_private/runtime_env/conda.py b/python/ray/_private/runtime_env/conda.py index 2b13bf311..379157df9 100644 --- a/python/ray/_private/runtime_env/conda.py +++ b/python/ray/_private/runtime_env/conda.py @@ -315,6 +315,13 @@ class CondaPlugin(RuntimeEnvPlugin): context: RuntimeEnvContext, logger: logging.Logger = default_logger, ) -> int: + if uri is None: + # The "conda" field is the name of an existing conda env, so no + # need to create one. + # TODO(architkulkarni): Try "conda activate" here to see if the + # env exists, and raise an exception if it doesn't. + return 0 + # Currently create method is still a sync process, to avoid blocking # the loop, need to run this function in another thread. # TODO(Catch-Bull): Refactor method create into an async process, and diff --git a/python/ray/_private/runtime_env/pip.py b/python/ray/_private/runtime_env/pip.py index 19e6ba3bc..12c2ee7f5 100644 --- a/python/ray/_private/runtime_env/pip.py +++ b/python/ray/_private/runtime_env/pip.py @@ -457,7 +457,7 @@ class PipPlugin(RuntimeEnvPlugin): uris: List[str], runtime_env: "RuntimeEnv", # noqa: F821 context: RuntimeEnvContext, - logger: Optional[logging.Logger] = default_logger, + logger: logging.Logger = default_logger, ): if not runtime_env.has_pip(): return diff --git a/python/ray/_private/runtime_env/plugin.py b/python/ray/_private/runtime_env/plugin.py index 562bd2b8e..4326a53b4 100644 --- a/python/ray/_private/runtime_env/plugin.py +++ b/python/ray/_private/runtime_env/plugin.py @@ -2,7 +2,7 @@ import logging import os import json from abc import ABC -from typing import List, Dict, Tuple, Any +from typing import List, Dict, Optional, Any, Type from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.uri_cache import URICache @@ -28,28 +28,25 @@ class RuntimeEnvPlugin(ABC): priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY @staticmethod - def validate(runtime_env_dict: dict) -> str: - """Validate user entry and returns a URI uniquely describing resource. - - This method will be called at ``f.options(runtime_env=...)`` or - ``ray.init(runtime_env=...)`` time and it should check the runtime env - dictionary for any errors. For example, it can raise "TypeError: - expected string for "conda" field". + def validate(runtime_env_dict: dict) -> None: + """Validate user entry for this plugin. Args: - runtime_env_dict(dict): the entire dictionary passed in by user. - - Returns: - uri(str): a URI uniquely describing this resource (e.g., a hash of - the conda spec). + runtime_env_dict: the user-supplied runtime environment dict. + Raises: + ValueError: if the validation fails. """ - raise NotImplementedError() + pass def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821 - return None + return [] - def create( - self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821 + async def create( + self, + uri: Optional[str], + runtime_env: "RuntimeEnv", # noqa: F821 + context: RuntimeEnvContext, + logger: logging.Logger, ) -> float: """Create and install the runtime environment. @@ -57,9 +54,10 @@ class RuntimeEnvPlugin(ABC): used as a caching mechanism. Args: - uri(str): a URI uniquely describing this resource. - runtime_env(RuntimeEnv): the runtime env protobuf. - ctx(RuntimeEnvContext): auxiliary information supplied by Ray. + uri: A URI uniquely describing this resource. + runtime_env: The RuntimeEnv object. + context: auxiliary information supplied by Ray. + logger: A logger to log messages during the context modification. Returns: the disk space taken up by this plugin installation for this @@ -81,9 +79,10 @@ class RuntimeEnvPlugin(ABC): startup, or add new environment variables. Args: - uris(List[str]): a URIs used by this resource. - runtime_env(RuntimeEnv): the runtime env protobuf. - ctx(RuntimeEnvContext): auxiliary information supplied by Ray. + uris: The URIs used by this resource. + runtime_env: The RuntimeEnv object. + context: Auxiliary information supplied by Ray. + logger: A logger to log messages during the context modification. """ return @@ -91,8 +90,7 @@ class RuntimeEnvPlugin(ABC): """Delete the the runtime environment given uri. Args: - uri(str): a URI uniquely describing this resource. - ctx(RuntimeEnvContext): auxiliary information supplied by Ray. + uri: a URI uniquely describing this resource. Returns: the amount of space reclaimed by the deletion. @@ -101,29 +99,59 @@ class RuntimeEnvPlugin(ABC): class PluginSetupContext: - def __init__(self, name: str, config: Any, class_instance: object): + def __init__( + self, + name: str, + class_instance: RuntimeEnvPlugin, + priority: int, + uri_cache: URICache, + ): self.name = name - self.config = config self.class_instance = class_instance + self.priority = priority + self.uri_cache = uri_cache class RuntimeEnvPluginManager: """This manager is used to load plugins in runtime env agent.""" - class Context: - def __init__(self, class_instance, priority): - self.class_instance = class_instance - self.priority = priority - def __init__(self): - self.plugins: Dict[str, RuntimeEnvPluginManager.Context] = {} + self.plugins: Dict[str, PluginSetupContext] = {} plugin_config_str = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR) if plugin_config_str: plugin_configs = json.loads(plugin_config_str) self.load_plugins(plugin_configs) - def load_plugins(self, plugin_configs: List[Dict]): - """Load runtime env plugins""" + def validate_plugin_class(self, plugin_class: Type[RuntimeEnvPlugin]) -> None: + if not issubclass(plugin_class, RuntimeEnvPlugin): + raise RuntimeError( + f"Invalid runtime env plugin class {plugin_class}. " + "The plugin class must inherit " + "ray._private.runtime_env.plugin.RuntimeEnvPlugin." + ) + if not plugin_class.name: + raise RuntimeError(f"No valid name in runtime env plugin {plugin_class}.") + if plugin_class.name in self.plugins: + raise RuntimeError( + f"The name of runtime env plugin {plugin_class} conflicts " + f"with {self.plugins[plugin_class.name]}.", + ) + + def validate_priority(self, priority: Any) -> None: + if ( + not isinstance(priority, int) + or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY + or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY + ): + raise RuntimeError( + f"Invalid runtime env priority {priority}, " + "it should be an integer between " + f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} " + f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}." + ) + + def load_plugins(self, plugin_configs: List[Dict]) -> None: + """Load runtime env plugins and create URI caches for them.""" for plugin_config in plugin_configs: if ( not isinstance(plugin_config, dict) @@ -135,21 +163,7 @@ class RuntimeEnvPluginManager: f"{RAY_RUNTIME_ENV_CLASS_FIELD_NAME} field." ) plugin_class = import_attr(plugin_config[RAY_RUNTIME_ENV_CLASS_FIELD_NAME]) - if not issubclass(plugin_class, RuntimeEnvPlugin): - raise RuntimeError( - f"Invalid runtime env plugin class {plugin_class}. " - "The plugin class must inherit " - "ray._private.runtime_env.plugin.RuntimeEnvPlugin." - ) - if not plugin_class.name: - raise RuntimeError( - f"No valid name in runtime env plugin {plugin_class}." - ) - if plugin_class.name in self.plugins: - raise RuntimeError( - f"The name of runtime env plugin {plugin_class} conflicts " - f"with {self.plugins[plugin_class.name]}.", - ) + self.validate_plugin_class(plugin_class) # The priority should be an integer between 0 and 100. # The default priority is 10. A smaller number indicates a @@ -158,73 +172,86 @@ class RuntimeEnvPluginManager: priority = plugin_config[RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME] else: priority = plugin_class.priority - if ( - not isinstance(priority, int) - or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY - or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY - ): - raise RuntimeError( - f"Invalid runtime env priority {priority}, " - "it should be an integer between " - f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} " - f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}." - ) + self.validate_priority(priority) - self.plugins[plugin_class.name] = RuntimeEnvPluginManager.Context( - plugin_class(), priority + class_instance = plugin_class() + self.plugins[plugin_class.name] = PluginSetupContext( + plugin_class.name, + class_instance, + priority, + self.create_uri_cache_for_plugin(class_instance), ) - def sorted_plugin_setup_contexts( - self, inputs: List[Tuple[str, Any]] - ) -> List[PluginSetupContext]: - used_plugins = [] - for name, config in inputs: - if name not in self.plugins: - default_logger.error( - f"runtime_env field {name} is not recognized by " - "Ray and will be ignored. In the future, unrecognized " - "fields in the runtime_env will raise an exception." - ) - continue - used_plugins.append( - ( - name, - config, - self.plugins[name].class_instance, - self.plugins[name].priority, - ) - ) - sort_used_plugins = sorted(used_plugins, key=lambda x: x[3], reverse=False) - return [ - PluginSetupContext(name, config, class_instance) - for name, config, class_instance, _ in sort_used_plugins - ] + def add_plugin(self, plugin: RuntimeEnvPlugin) -> None: + """Add a plugin to the manager and create a URI cache for it. + + Args: + plugin: The class instance of the plugin. + """ + plugin_class = type(plugin) + self.validate_plugin_class(plugin_class) + self.validate_priority(plugin_class.priority) + self.plugins[plugin_class.name] = PluginSetupContext( + plugin_class.name, + plugin, + plugin_class.priority, + self.create_uri_cache_for_plugin(plugin), + ) + + def create_uri_cache_for_plugin(self, plugin: RuntimeEnvPlugin) -> URICache: + """Create a URI cache for a plugin. + + Args: + plugin_name: The name of the plugin. + + Returns: + The created URI cache for the plugin. + """ + # Set the max size for the cache. Defaults to 10 GB. + cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper() + cache_size_bytes = int( + (1024 ** 3) * float(os.environ.get(cache_size_env_var, 10)) + ) + return URICache(plugin.delete_uri, cache_size_bytes) + + def sorted_plugin_setup_contexts(self) -> List[PluginSetupContext]: + """Get the sorted plugin setup contexts, sorted by increasing priority. + + Returns: + The sorted plugin setup contexts. + """ + return sorted(self.plugins.values(), key=lambda x: x.priority) -@DeveloperAPI -class PluginCacheManager: - """Manages a plugin and a cache for its local resources.""" +async def create_for_plugin_if_needed( + runtime_env, + plugin: RuntimeEnvPlugin, + uri_cache: URICache, + context: RuntimeEnvContext, + logger: logging.Logger = default_logger, +): + """Set up the environment using the plugin if not already set up and cached.""" + if plugin.name not in runtime_env or runtime_env[plugin.name] is None: + return - def __init__(self, plugin: RuntimeEnvPlugin, uri_cache: URICache): - self._plugin = plugin - self._uri_cache = uri_cache + plugin.validate(runtime_env) - async def create_if_needed( - self, - runtime_env: "RuntimeEnv", # noqa: F821 - context: RuntimeEnvContext, - logger: logging.Logger = default_logger, - ): - uris = self._plugin.get_uris(runtime_env) - for uri in uris: - if uri not in self._uri_cache: - logger.debug(f"Cache miss for URI {uri}.") - size_bytes = await self._plugin.create( - uri, runtime_env, context, logger=logger - ) - self._uri_cache.add(uri, size_bytes, logger=logger) - else: - logger.debug(f"Cache hit for URI {uri}.") - self._uri_cache.mark_used(uri, logger=logger) + uris = plugin.get_uris(runtime_env) - self._plugin.modify_context(uris, runtime_env, context) + if not uris: + logger.debug( + f"No URIs for runtime env plugin {plugin.name}; " + "create always without checking the cache." + ) + await plugin.create(None, runtime_env, context, logger=logger) + + for uri in uris: + if uri not in uri_cache: + logger.debug(f"Cache miss for URI {uri}.") + size_bytes = await plugin.create(uri, runtime_env, context, logger=logger) + uri_cache.add(uri, size_bytes, logger=logger) + else: + logger.debug(f"Cache hit for URI {uri}.") + uri_cache.mark_used(uri, logger=logger) + + plugin.modify_context(uris, runtime_env, context, logger) diff --git a/python/ray/_private/runtime_env/working_dir.py b/python/ray/_private/runtime_env/working_dir.py index 12c9ea90f..cdf0a0067 100644 --- a/python/ray/_private/runtime_env/working_dir.py +++ b/python/ray/_private/runtime_env/working_dir.py @@ -134,10 +134,10 @@ class WorkingDirPlugin(RuntimeEnvPlugin): async def create( self, - uri: str, + uri: Optional[str], runtime_env: dict, context: RuntimeEnvContext, - logger: Optional[logging.Logger] = default_logger, + logger: logging.Logger = default_logger, ) -> int: local_dir = await download_and_unpack_package( uri, self._resources_dir, self._gcs_aio_client, logger=logger @@ -145,7 +145,11 @@ class WorkingDirPlugin(RuntimeEnvPlugin): return get_directory_size_bytes(local_dir) def modify_context( - self, uris: List[str], runtime_env_dict: Dict, context: RuntimeEnvContext + self, + uris: List[str], + runtime_env_dict: Dict, + context: RuntimeEnvContext, + logger: Optional[logging.Logger] = default_logger, ): if not uris: return diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 59671c3b9..cf026af74 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1445,12 +1445,10 @@ def get_runtime_env_info( proto_runtime_env_info = ProtoRuntimeEnvInfo() - if runtime_env.get_working_dir_uri(): - proto_runtime_env_info.uris.working_dir_uri = runtime_env.get_working_dir_uri() - if len(runtime_env.get_py_modules_uris()) > 0: - proto_runtime_env_info.uris.py_modules_uris[ - : - ] = runtime_env.get_py_modules_uris() + if runtime_env.working_dir_uri(): + proto_runtime_env_info.uris.working_dir_uri = runtime_env.working_dir_uri() + if len(runtime_env.py_modules_uris()) > 0: + proto_runtime_env_info.uris.py_modules_uris[:] = runtime_env.py_modules_uris() # TODO(Catch-Bull): overload `__setitem__` for `RuntimeEnv`, change the # runtime_env of all internal code from dict to RuntimeEnv. diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index a0e94c043..e00cd9311 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2044,7 +2044,7 @@ def connect( ) # In client mode, if we use runtime envs with "working_dir", then # it'll be handled automatically. Otherwise, add the current dir. - if not job_config.client_job and not job_config.runtime_env_has_uris(): + if not job_config.client_job and not job_config.runtime_env_has_working_dir(): current_directory = os.path.abspath(os.path.curdir) worker.run_function_on_all_workers( lambda worker_info: sys.path.insert(1, current_directory) diff --git a/python/ray/job_config.py b/python/ray/job_config.py index a412a2e97..2413c5ffa 100644 --- a/python/ray/job_config.py +++ b/python/ray/job_config.py @@ -126,9 +126,8 @@ class JobConfig: return self._cached_pb - def runtime_env_has_uris(self): - """Whether there are uris in runtime env or not""" - return self._validate_runtime_env().has_uris() + def runtime_env_has_working_dir(self): + return self._validate_runtime_env().has_working_dir() def get_serialized_runtime_env(self) -> str: """Return the JSON-serialized parsed runtime env dict""" diff --git a/python/ray/runtime_env/runtime_env.py b/python/ray/runtime_env/runtime_env.py index c09353fe6..d802b50a2 100644 --- a/python/ray/runtime_env/runtime_env.py +++ b/python/ray/runtime_env/runtime_env.py @@ -345,30 +345,6 @@ class RuntimeEnv(dict): if all(val is None for val in self.values()): self.clear() - def get_uris(self) -> List[str]: - # TODO(architkulkarni): this should programmatically be extended with - # URIs from all plugins. - plugin_uris = [] - if "working_dir" in self: - plugin_uris.append(self["working_dir"]) - if "py_modules" in self: - for uri in self["py_modules"]: - plugin_uris.append(uri) - if "conda" in self: - uri = get_conda_uri(self) - if uri is not None: - plugin_uris.append(uri) - if "pip" in self: - uri = get_pip_uri(self) - if uri is not None: - plugin_uris.append(uri) - - def get_working_dir_uri(self) -> str: - return self.get("working_dir", None) - - def get_py_modules_uris(self) -> List[str]: - return self.get("py_modules", []) - def __setitem__(self, key: str, value: Any) -> None: if is_dataclass(value): jsonable_type = asdict(value) @@ -416,16 +392,8 @@ class RuntimeEnv(dict): return runtime_env_dict - def has_uris(self) -> bool: - if ( - self.working_dir_uri() - or self.py_modules_uris() - or self.conda_uri() - or self.pip_uri() - or self.plugin_uris() - ): - return True - return False + def has_working_dir(self) -> bool: + return self.get("working_dir") is not None def working_dir_uri(self) -> Optional[str]: return self.get("working_dir") diff --git a/python/ray/tests/test_runtime_env_plugin.py b/python/ray/tests/test_runtime_env_plugin.py index f59634b47..045ed1a3c 100644 --- a/python/ray/tests/test_runtime_env_plugin.py +++ b/python/ray/tests/test_runtime_env_plugin.py @@ -1,9 +1,12 @@ +import asyncio import logging import os from pathlib import Path +import time +from unittest import mock + import tempfile import json -from time import sleep from typing import List import pytest @@ -13,6 +16,7 @@ from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.plugin import RuntimeEnvPlugin from ray._private.test_utils import enable_external_redis, wait_for_condition from ray.exceptions import RuntimeEnvSetupError +from ray.runtime_env.runtime_env import RuntimeEnv MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin" MY_PLUGIN_NAME = "MyPlugin" @@ -23,8 +27,8 @@ class MyPlugin(RuntimeEnvPlugin): env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY" @staticmethod - def validate(runtime_env_dict: dict) -> str: - value = runtime_env_dict[MY_PLUGIN_NAME] + def validate(runtime_env: RuntimeEnv) -> str: + value = runtime_env[MY_PLUGIN_NAME] if value == "fail": raise ValueError("not allowed") return value @@ -32,10 +36,11 @@ class MyPlugin(RuntimeEnvPlugin): def modify_context( self, uris: List[str], - plugin_config_dict: dict, + runtime_env: RuntimeEnv, ctx: RuntimeEnvContext, logger: logging.Logger, ) -> None: + plugin_config_dict = runtime_env[MY_PLUGIN_NAME] ctx.env_vars[MyPlugin.env_key] = str(plugin_config_dict["env_value"]) ctx.command_prefix.append( f"echo {plugin_config_dict['tmp_content']} > " @@ -103,14 +108,20 @@ class MyPluginForHang(RuntimeEnvPlugin): def validate(runtime_env_dict: dict) -> str: return "True" - def create(self, uri: str, runtime_env: dict, ctx: RuntimeEnvContext) -> float: + async def create( + self, + uri: str, + runtime_env: dict, + ctx: RuntimeEnvContext, + logger: logging.Logger, + ) -> float: global my_plugin_setup_times my_plugin_setup_times += 1 # first setup if my_plugin_setup_times == 1: # sleep forever - sleep(3600) + await asyncio.sleep(3600) def modify_context( self, @@ -164,10 +175,6 @@ DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin" DUMMY_PLUGIN_NAME = "DummyPlugin" HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin" HANG_PLUGIN_NAME = "HangPlugin" -DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = ( - "ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin" -) -DISABLE_TIMEOUT_PLUGIN_NAME = "test_plugin_timeout" class DummyPlugin(RuntimeEnvPlugin): @@ -181,27 +188,21 @@ class DummyPlugin(RuntimeEnvPlugin): class HangPlugin(DummyPlugin): name = HANG_PLUGIN_NAME - def create( - self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821 + async def create( + self, + uri: str, + runtime_env: "RuntimeEnv", + ctx: RuntimeEnvContext, + logger: logging.Logger, # noqa: F821 ) -> float: - sleep(3600) - - -class DiasbleTimeoutPlugin(DummyPlugin): - name = DISABLE_TIMEOUT_PLUGIN_NAME - - def create( - self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821 - ) -> float: - sleep(10) + await asyncio.sleep(3600) @pytest.mark.parametrize( "set_runtime_env_plugins", [ '[{"class":"' + DUMMY_PLUGIN_CLASS_PATH + '"},' - '{"class":"' + HANG_PLUGIN_CLASS_PATH + '"},' - '{"class":"' + DISABLE_TIMEOUT_PLUGIN_CLASS_PATH + '"}]', + '{"class":"' + HANG_PLUGIN_CLASS_PATH + '"}]', ], indirect=True, ) @@ -215,7 +216,7 @@ def test_plugin_timeout(set_runtime_env_plugins, start_cluster): f.options( runtime_env={ HANG_PLUGIN_NAME: {"name": "f1"}, - "config": {"setup_timeout_seconds": 10}, + "config": {"setup_timeout_seconds": 1}, } ).remote(), f.options(runtime_env={DUMMY_PLUGIN_NAME: {"name": "f2"}}).remote(), @@ -397,6 +398,225 @@ def test_unexpected_field_warning(shutdown_only): wait_for_condition(lambda: "unexpected_field is not recognized" in f.read()) +URI_CACHING_TEST_PLUGIN_CLASS_PATH = ( + "ray.tests.test_runtime_env_plugin.UriCachingTestPlugin" +) +URI_CACHING_TEST_PLUGIN_NAME = "UriCachingTestPlugin" +URI_CACHING_TEST_DIR = Path(tempfile.gettempdir()) / "runtime_env_uri_caching_test" +uri_caching_test_file_path = URI_CACHING_TEST_DIR / "uri_caching_test_file.json" +URI_CACHING_TEST_DIR.mkdir(parents=True, exist_ok=True) +uri_caching_test_file_path.write_text("{}") + + +def get_plugin_usage_data(): + with open(uri_caching_test_file_path, "r") as f: + data = json.loads(f.read()) + return data + + +class UriCachingTestPlugin(RuntimeEnvPlugin): + """A plugin that fakes taking up local disk space when creating its environment. + + This plugin is used to test that the URI caching is working correctly. + Example: + runtime_env = {"UriCachingTestPlugin": {"uri": "file:///a", "size_bytes": 10}} + """ + + name = URI_CACHING_TEST_PLUGIN_NAME + + def __init__(self): + # Keeps track of the "disk space" each URI takes up for the + # UriCachingTestPlugin. + self.uris_to_sizes = {} + self.modify_context_call_count = 0 + self.create_call_count = 0 + + def write_plugin_usage_data(self) -> None: + with open(uri_caching_test_file_path, "w") as f: + data = { + "uris_to_sizes": self.uris_to_sizes, + "modify_context_call_count": self.modify_context_call_count, + "create_call_count": self.create_call_count, + } + f.write(json.dumps(data)) + + def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F811 + return [runtime_env[self.name]["uri"]] + + async def create( + self, + uri, + runtime_env: "RuntimeEnv", # noqa: F821 + context: RuntimeEnvContext, + logger: logging.Logger, + ) -> float: + self.create_call_count += 1 + created_size_bytes = runtime_env[self.name]["size_bytes"] + self.uris_to_sizes[uri] = created_size_bytes + self.write_plugin_usage_data() + return created_size_bytes + + def modify_context( + self, + uris: List[str], + runtime_env: "RuntimeEnv", + context: RuntimeEnvContext, + logger: logging.Logger, + ) -> None: + self.modify_context_call_count += 1 + self.write_plugin_usage_data() + + def delete_uri(self, uri: str, logger: logging.Logger) -> int: + size = self.uris_to_sizes.pop(uri) + self.write_plugin_usage_data() + return size + + +# Set scope to "class" to force this to run before start_cluster, whose scope +# is "function". We need these env vars to be set before Ray is started. +@pytest.fixture(scope="class") +def uri_cache_size_100_gb(): + var = f"RAY_RUNTIME_ENV_{URI_CACHING_TEST_PLUGIN_NAME}_CACHE_SIZE_GB".upper() + with mock.patch.dict( + os.environ, + { + var: "100", + }, + ): + print("Set URI cache size for UriCachingTestPlugin to 100 GB") + yield + + +def gb_to_bytes(size_gb: int) -> int: + return size_gb * 1024 * 1024 * 1024 + + +class TestGC: + @pytest.mark.parametrize( + "set_runtime_env_plugins", + [ + json.dumps([{"class": URI_CACHING_TEST_PLUGIN_CLASS_PATH}]), + ], + indirect=True, + ) + def test_uri_caching( + self, set_runtime_env_plugins, start_cluster, uri_cache_size_100_gb + ): + cluster, address = start_cluster + + ray.init(address=address) + + def reinit(): + ray.shutdown() + # TODO(architkulkarni): Currently, reinit the driver will generate the same + # job id. And if we reinit immediately after shutdown, raylet may + # process new job started before old job finished in some cases. This + # inconsistency could disorder the URI reference and delete a valid + # runtime env. We sleep here to walk around this issue. + time.sleep(5) + ray.init(address=address) + + @ray.remote + def f(): + return True + + # Run a task to trigger runtime_env creation. + ref1 = f.options( + runtime_env={ + URI_CACHING_TEST_PLUGIN_NAME: { + "uri": "file:///tmp/test_uri_1", + "size_bytes": gb_to_bytes(50), + } + } + ).remote() + ray.get(ref1) + # Check that the URI was "created on disk". + print(get_plugin_usage_data()) + wait_for_condition( + lambda: get_plugin_usage_data() + == { + "uris_to_sizes": {"file:///tmp/test_uri_1": gb_to_bytes(50)}, + "modify_context_call_count": 1, + "create_call_count": 1, + } + ) + + # Shutdown ray to stop the worker and remove the runtime_env reference. + reinit() + + # Run a task with a different runtime env. + ref2 = f.options( + runtime_env={ + URI_CACHING_TEST_PLUGIN_NAME: { + "uri": "file:///tmp/test_uri_2", + "size_bytes": gb_to_bytes(51), + } + } + ).remote() + ray.get(ref2) + # This should delete the old URI and create a new one, because 50 + 51 > 100 + # and the cache size limit is 100. + wait_for_condition( + lambda: get_plugin_usage_data() + == { + "uris_to_sizes": {"file:///tmp/test_uri_2": gb_to_bytes(51)}, + "modify_context_call_count": 2, + "create_call_count": 2, + } + ) + + reinit() + + # Run a task with the cached runtime env, to check that the runtime env is not + # created anew. + ref3 = f.options( + runtime_env={ + URI_CACHING_TEST_PLUGIN_NAME: { + "uri": "file:///tmp/test_uri_2", + "size_bytes": gb_to_bytes(51), + } + } + ).remote() + ray.get(ref3) + # modify_context should still be called even if create() is not called. + # Example: for a "conda" plugin, even if the conda env is already created + # and cached, we still need to call modify_context to add "conda activate" to + # the RuntimeEnvContext.command_prefix for the worker. + wait_for_condition( + lambda: get_plugin_usage_data() + == { + "uris_to_sizes": {"file:///tmp/test_uri_2": gb_to_bytes(51)}, + "modify_context_call_count": 3, + "create_call_count": 2, + } + ) + + reinit() + + # Run a task with a new runtime env + ref4 = f.options( + runtime_env={ + URI_CACHING_TEST_PLUGIN_NAME: { + "uri": "file:///tmp/test_uri_3", + "size_bytes": gb_to_bytes(10), + } + } + ).remote() + ray.get(ref4) + # The last two URIs should still be present in the cache, because 51 + 10 < 100. + wait_for_condition( + lambda: get_plugin_usage_data() + == { + "uris_to_sizes": { + "file:///tmp/test_uri_2": gb_to_bytes(51), + "file:///tmp/test_uri_3": gb_to_bytes(10), + }, + "modify_context_call_count": 4, + "create_call_count": 3, + } + ) + + if __name__ == "__main__": import sys