mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[runtime env] plugin refactor[1/n] (#26077)
This commit is contained in:
parent
876fef0fcf
commit
58bfad84d3
8 changed files with 178 additions and 110 deletions
|
@ -20,10 +20,10 @@ from ray._private.runtime_env.context import RuntimeEnvContext
|
||||||
from ray._private.runtime_env.java_jars import JavaJarsPlugin
|
from ray._private.runtime_env.java_jars import JavaJarsPlugin
|
||||||
from ray._private.runtime_env.pip import PipPlugin
|
from ray._private.runtime_env.pip import PipPlugin
|
||||||
from ray._private.runtime_env.plugin import PluginCacheManager
|
from ray._private.runtime_env.plugin import PluginCacheManager
|
||||||
|
from ray._private.runtime_env.plugin import RuntimeEnvPluginManager
|
||||||
from ray._private.runtime_env.py_modules import PyModulesPlugin
|
from ray._private.runtime_env.py_modules import PyModulesPlugin
|
||||||
from ray._private.runtime_env.uri_cache import URICache
|
from ray._private.runtime_env.uri_cache import URICache
|
||||||
from ray._private.runtime_env.working_dir import WorkingDirPlugin
|
from ray._private.runtime_env.working_dir import WorkingDirPlugin
|
||||||
from ray._private.utils import import_attr
|
|
||||||
from ray.core.generated import (
|
from ray.core.generated import (
|
||||||
agent_manager_pb2,
|
agent_manager_pb2,
|
||||||
runtime_env_agent_pb2,
|
runtime_env_agent_pb2,
|
||||||
|
@ -224,6 +224,7 @@ class RuntimeEnvAgent(
|
||||||
self.unused_uris_processor,
|
self.unused_uris_processor,
|
||||||
self.unused_runtime_env_processor,
|
self.unused_runtime_env_processor,
|
||||||
)
|
)
|
||||||
|
self._runtime_env_plugin_manager = RuntimeEnvPluginManager()
|
||||||
|
|
||||||
self._logger = default_logger
|
self._logger = default_logger
|
||||||
|
|
||||||
|
@ -295,13 +296,13 @@ class RuntimeEnvAgent(
|
||||||
|
|
||||||
def setup_plugins():
|
def setup_plugins():
|
||||||
# Run setup function from all the plugins
|
# Run setup function from all the plugins
|
||||||
for plugin_class_path, config in runtime_env.plugins():
|
for name, config in runtime_env.plugins():
|
||||||
per_job_logger.debug(
|
per_job_logger.debug(f"Setting up runtime env plugin {name}")
|
||||||
f"Setting up runtime env plugin {plugin_class_path}"
|
plugin = self._runtime_env_plugin_manager.get_plugin(name)
|
||||||
)
|
if plugin is None:
|
||||||
plugin_class = import_attr(plugin_class_path)
|
raise RuntimeError(f"runtime env plugin {name} not found.")
|
||||||
plugin = plugin_class()
|
|
||||||
# TODO(architkulkarni): implement uri support
|
# TODO(architkulkarni): implement uri support
|
||||||
|
plugin.validate(runtime_env)
|
||||||
plugin.create("uri not implemented", json.loads(config), context)
|
plugin.create("uri not implemented", json.loads(config), context)
|
||||||
plugin.modify_context(
|
plugin.modify_context(
|
||||||
"uri not implemented",
|
"uri not implemented",
|
||||||
|
|
|
@ -1,2 +1,5 @@
|
||||||
# Env var set by job manager to pass runtime env and metadata to subprocess
|
# Env var set by job manager to pass runtime env and metadata to subprocess
|
||||||
RAY_JOB_CONFIG_JSON_ENV_VAR = "RAY_JOB_CONFIG_JSON_ENV_VAR"
|
RAY_JOB_CONFIG_JSON_ENV_VAR = "RAY_JOB_CONFIG_JSON_ENV_VAR"
|
||||||
|
|
||||||
|
# The plugins which should be loaded when ray cluster starts.
|
||||||
|
RAY_RUNTIME_ENV_PLUGINS_ENV_VAR = "RAY_RUNTIME_ENV_PLUGINS"
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from ray._private.runtime_env.context import RuntimeEnvContext
|
from ray._private.runtime_env.context import RuntimeEnvContext
|
||||||
from ray._private.runtime_env.uri_cache import URICache
|
from ray._private.runtime_env.uri_cache import URICache
|
||||||
|
from ray._private.runtime_env.constants import RAY_RUNTIME_ENV_PLUGINS_ENV_VAR
|
||||||
from ray.util.annotations import DeveloperAPI
|
from ray.util.annotations import DeveloperAPI
|
||||||
|
from ray._private.utils import import_attr
|
||||||
|
|
||||||
default_logger = logging.getLogger(__name__)
|
default_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -13,7 +16,7 @@ default_logger = logging.getLogger(__name__)
|
||||||
class RuntimeEnvPlugin(ABC):
|
class RuntimeEnvPlugin(ABC):
|
||||||
"""Abstract base class for runtime environment plugins."""
|
"""Abstract base class for runtime environment plugins."""
|
||||||
|
|
||||||
name: str
|
name: str = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate(runtime_env_dict: dict) -> str:
|
def validate(runtime_env_dict: dict) -> str:
|
||||||
|
@ -88,6 +91,45 @@ class RuntimeEnvPlugin(ABC):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeEnvPluginManager:
|
||||||
|
"""This mananger is used to load plugins in runtime env agent."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.plugins = {}
|
||||||
|
plugins_config = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
|
||||||
|
if plugins_config:
|
||||||
|
self.load_plugins(plugins_config.split(","))
|
||||||
|
|
||||||
|
def load_plugins(self, plugin_classes: List[str]):
|
||||||
|
"""Load runtime env plugins"""
|
||||||
|
for plugin_class_path in plugin_classes:
|
||||||
|
plugin_class = import_attr(plugin_class_path)
|
||||||
|
if not issubclass(plugin_class, RuntimeEnvPlugin):
|
||||||
|
default_logger.warning(
|
||||||
|
"Invalid runtime env plugin class %s. "
|
||||||
|
"The plugin class must inherit "
|
||||||
|
"ray._private.runtime_env.plugin.RuntimeEnvPlugin.",
|
||||||
|
plugin_class,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if not plugin_class.name:
|
||||||
|
default_logger.warning(
|
||||||
|
"No valid name in runtime env plugin %s", plugin_class
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if plugin_class.name in self.plugins:
|
||||||
|
default_logger.warning(
|
||||||
|
"The name of runtime env plugin %s conflicts with %s",
|
||||||
|
plugin_class,
|
||||||
|
self.plugins[plugin_class.name],
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
self.plugins[plugin_class.name] = plugin_class()
|
||||||
|
|
||||||
|
def get_plugin(self, name: str):
|
||||||
|
return self.plugins.get(name)
|
||||||
|
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
class PluginCacheManager:
|
class PluginCacheManager:
|
||||||
"""Manages a plugin and a cache for its local resources."""
|
"""Manages a plugin and a cache for its local resources."""
|
||||||
|
|
|
@ -10,9 +10,7 @@ import ray
|
||||||
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
|
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
|
||||||
from ray._private.runtime_env.conda import get_uri as get_conda_uri
|
from ray._private.runtime_env.conda import get_uri as get_conda_uri
|
||||||
from ray._private.runtime_env.pip import get_uri as get_pip_uri
|
from ray._private.runtime_env.pip import get_uri as get_pip_uri
|
||||||
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
|
|
||||||
from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
|
from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
|
||||||
from ray._private.utils import import_attr
|
|
||||||
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
|
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
|
||||||
from ray.core.generated.runtime_env_common_pb2 import (
|
from ray.core.generated.runtime_env_common_pb2 import (
|
||||||
RuntimeEnvConfig as ProtoRuntimeEnvConfig,
|
RuntimeEnvConfig as ProtoRuntimeEnvConfig,
|
||||||
|
@ -85,9 +83,8 @@ def _parse_proto_plugin_runtime_env(
|
||||||
):
|
):
|
||||||
"""Parse plugin runtime env protobuf to runtime env dict."""
|
"""Parse plugin runtime env protobuf to runtime env dict."""
|
||||||
if runtime_env.python_runtime_env.HasField("plugin_runtime_env"):
|
if runtime_env.python_runtime_env.HasField("plugin_runtime_env"):
|
||||||
runtime_env_dict["plugins"] = dict()
|
|
||||||
for plugin in runtime_env.python_runtime_env.plugin_runtime_env.plugins:
|
for plugin in runtime_env.python_runtime_env.plugin_runtime_env.plugins:
|
||||||
runtime_env_dict["plugins"][plugin.class_path] = json.loads(plugin.config)
|
runtime_env_dict[plugin.class_path] = json.loads(plugin.config)
|
||||||
|
|
||||||
|
|
||||||
@PublicAPI(stability="beta")
|
@PublicAPI(stability="beta")
|
||||||
|
@ -319,8 +316,12 @@ class RuntimeEnv(dict):
|
||||||
"_ray_release",
|
"_ray_release",
|
||||||
"_ray_commit",
|
"_ray_commit",
|
||||||
"_inject_current_ray",
|
"_inject_current_ray",
|
||||||
"plugins",
|
|
||||||
"config",
|
"config",
|
||||||
|
# TODO(SongGuyang): We add this because the test
|
||||||
|
# `test_experimental_package_github` set a `docker`
|
||||||
|
# field which is not supported. We should remove it
|
||||||
|
# with the test.
|
||||||
|
"docker",
|
||||||
}
|
}
|
||||||
|
|
||||||
extensions_fields: Set[str] = {
|
extensions_fields: Set[str] = {
|
||||||
|
@ -363,19 +364,20 @@ class RuntimeEnv(dict):
|
||||||
if runtime_env.get("java_jars"):
|
if runtime_env.get("java_jars"):
|
||||||
runtime_env["java_jars"] = runtime_env.get("java_jars")
|
runtime_env["java_jars"] = runtime_env.get("java_jars")
|
||||||
|
|
||||||
|
self.update(runtime_env)
|
||||||
|
|
||||||
# Blindly trust that the runtime_env has already been validated.
|
# Blindly trust that the runtime_env has already been validated.
|
||||||
# This is dangerous and should only be used internally (e.g., on the
|
# This is dangerous and should only be used internally (e.g., on the
|
||||||
# deserialization codepath.
|
# deserialization codepath.
|
||||||
if not _validate:
|
if not _validate:
|
||||||
self.update(runtime_env)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if runtime_env.get("conda") and runtime_env.get("pip"):
|
if self.get("conda") and self.get("pip"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The 'pip' field and 'conda' field of "
|
"The 'pip' field and 'conda' field of "
|
||||||
"runtime_env cannot both be specified.\n"
|
"runtime_env cannot both be specified.\n"
|
||||||
f"specified pip field: {runtime_env['pip']}\n"
|
f"specified pip field: {self['pip']}\n"
|
||||||
f"specified conda field: {runtime_env['conda']}\n"
|
f"specified conda field: {self['conda']}\n"
|
||||||
"To use pip with conda, please only set the 'conda' "
|
"To use pip with conda, please only set the 'conda' "
|
||||||
"field, and specify your pip dependencies "
|
"field, and specify your pip dependencies "
|
||||||
"within the conda YAML config dict: see "
|
"within the conda YAML config dict: see "
|
||||||
|
@ -385,48 +387,21 @@ class RuntimeEnv(dict):
|
||||||
)
|
)
|
||||||
|
|
||||||
for option, validate_fn in OPTION_TO_VALIDATION_FN.items():
|
for option, validate_fn in OPTION_TO_VALIDATION_FN.items():
|
||||||
option_val = runtime_env.get(option)
|
option_val = self.get(option)
|
||||||
if option_val is not None:
|
if option_val is not None:
|
||||||
|
del self[option]
|
||||||
self[option] = option_val
|
self[option] = option_val
|
||||||
|
|
||||||
if "_ray_release" in runtime_env:
|
if "_ray_commit" not in self:
|
||||||
self["_ray_release"] = runtime_env["_ray_release"]
|
|
||||||
|
|
||||||
if "_ray_commit" in runtime_env:
|
|
||||||
self["_ray_commit"] = runtime_env["_ray_commit"]
|
|
||||||
else:
|
|
||||||
if self.get("pip") or self.get("conda"):
|
if self.get("pip") or self.get("conda"):
|
||||||
self["_ray_commit"] = ray.__commit__
|
self["_ray_commit"] = ray.__commit__
|
||||||
|
|
||||||
# Used for testing wheels that have not yet been merged into master.
|
# Used for testing wheels that have not yet been merged into master.
|
||||||
# If this is set to True, then we do not inject Ray into the conda
|
# If this is set to True, then we do not inject Ray into the conda
|
||||||
# or pip dependencies.
|
# or pip dependencies.
|
||||||
if "_inject_current_ray" in runtime_env:
|
if "_inject_current_ray" not in self:
|
||||||
self["_inject_current_ray"] = runtime_env["_inject_current_ray"]
|
if "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
|
||||||
elif "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
|
self["_inject_current_ray"] = True
|
||||||
self["_inject_current_ray"] = True
|
|
||||||
if "plugins" in runtime_env:
|
|
||||||
self["plugins"] = dict()
|
|
||||||
for class_path, plugin_field in runtime_env["plugins"].items():
|
|
||||||
plugin_class: RuntimeEnvPlugin = import_attr(class_path)
|
|
||||||
if not issubclass(plugin_class, RuntimeEnvPlugin):
|
|
||||||
# TODO(simon): move the inferface to public once ready.
|
|
||||||
raise TypeError(
|
|
||||||
f"{class_path} must be inherit from "
|
|
||||||
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
|
|
||||||
)
|
|
||||||
# TODO(simon): implement uri support.
|
|
||||||
_ = plugin_class.validate(runtime_env)
|
|
||||||
# Validation passed, add the entry to parsed runtime env.
|
|
||||||
self["plugins"][class_path] = plugin_field
|
|
||||||
|
|
||||||
unknown_fields = set(runtime_env.keys()) - RuntimeEnv.known_fields
|
|
||||||
if len(unknown_fields):
|
|
||||||
logger.warning(
|
|
||||||
"The following unknown entries in the runtime_env dictionary "
|
|
||||||
f"will be ignored: {unknown_fields}. If you intended to use "
|
|
||||||
"them as plugins, they must be nested in the `plugins` field."
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE(architkulkarni): This allows worker caching code in C++ to check
|
# NOTE(architkulkarni): This allows worker caching code in C++ to check
|
||||||
# if a runtime env is empty without deserializing it. This is a catch-
|
# if a runtime env is empty without deserializing it. This is a catch-
|
||||||
|
@ -455,20 +430,17 @@ class RuntimeEnv(dict):
|
||||||
return plugin_uris
|
return plugin_uris
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: Any) -> None:
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
if key not in RuntimeEnv.known_fields:
|
# TODO(SongGuyang): Validate the schemas of plugins by json schema.
|
||||||
logger.warning(
|
|
||||||
"The following unknown entries in the runtime_env dictionary "
|
|
||||||
f"will be ignored: {key}. If you intended to use "
|
|
||||||
"them as plugins, they must be nested in the `plugins` field."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
res_value = value
|
res_value = value
|
||||||
if key in OPTION_TO_VALIDATION_FN:
|
if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
|
||||||
res_value = OPTION_TO_VALIDATION_FN[key](value)
|
res_value = OPTION_TO_VALIDATION_FN[key](value)
|
||||||
if res_value is None:
|
if res_value is None:
|
||||||
return
|
return
|
||||||
return super().__setitem__(key, res_value)
|
return super().__setitem__(key, res_value)
|
||||||
|
|
||||||
|
def set(self, name: str, value: Any) -> None:
|
||||||
|
self.__setitem__(name, value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821
|
def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821
|
||||||
proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())
|
proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())
|
||||||
|
@ -655,16 +627,11 @@ class RuntimeEnv(dict):
|
||||||
return None
|
return None
|
||||||
return self["container"].get("run_options", [])
|
return self["container"].get("run_options", [])
|
||||||
|
|
||||||
def has_plugins(self) -> bool:
|
|
||||||
if self.get("plugins"):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def plugins(self) -> List[Tuple[str, str]]:
|
def plugins(self) -> List[Tuple[str, str]]:
|
||||||
result = list()
|
result = list()
|
||||||
if self.has_plugins():
|
for key, value in self.items():
|
||||||
for class_path, plugin_field in self["plugins"].items():
|
if key not in self.known_fields:
|
||||||
result.append((class_path, json.dumps(plugin_field, sort_keys=True)))
|
result.append((key, json.dumps(value, sort_keys=True)))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _build_proto_pip_runtime_env(self, runtime_env: ProtoRuntimeEnv):
|
def _build_proto_pip_runtime_env(self, runtime_env: ProtoRuntimeEnv):
|
||||||
|
@ -720,8 +687,7 @@ class RuntimeEnv(dict):
|
||||||
|
|
||||||
def _build_proto_plugin_runtime_env(self, runtime_env: ProtoRuntimeEnv):
|
def _build_proto_plugin_runtime_env(self, runtime_env: ProtoRuntimeEnv):
|
||||||
"""Construct plugin runtime env protobuf from runtime env dict."""
|
"""Construct plugin runtime env protobuf from runtime env dict."""
|
||||||
if self.has_plugins():
|
for class_path, plugin_field in self.plugins():
|
||||||
for class_path, plugin_field in self.plugins():
|
plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add()
|
||||||
plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add()
|
plugin.class_path = class_path
|
||||||
plugin.class_path = class_path
|
plugin.config = plugin_field
|
||||||
plugin.config = plugin_field
|
|
||||||
|
|
|
@ -923,3 +923,13 @@ def start_http_proxy(request):
|
||||||
if proxy:
|
if proxy:
|
||||||
proxy.terminate()
|
proxy.terminate()
|
||||||
proxy.wait()
|
proxy.wait()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def set_runtime_env_plugins(request):
|
||||||
|
runtime_env_plugins = getattr(request, "param", "0")
|
||||||
|
try:
|
||||||
|
os.environ["RAY_RUNTIME_ENV_PLUGINS"] = runtime_env_plugins
|
||||||
|
yield runtime_env_plugins
|
||||||
|
finally:
|
||||||
|
del os.environ["RAY_RUNTIME_ENV_PLUGINS"]
|
||||||
|
|
|
@ -19,9 +19,13 @@ from ray._private.runtime_env.plugin import RuntimeEnvPlugin
|
||||||
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH = (
|
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH = (
|
||||||
"ray.tests.test_placement_group_4.MockWorkerStartupSlowlyPlugin" # noqa
|
"ray.tests.test_placement_group_4.MockWorkerStartupSlowlyPlugin" # noqa
|
||||||
)
|
)
|
||||||
|
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME = "MockWorkerStartupSlowlyPlugin"
|
||||||
|
|
||||||
|
|
||||||
class MockWorkerStartupSlowlyPlugin(RuntimeEnvPlugin):
|
class MockWorkerStartupSlowlyPlugin(RuntimeEnvPlugin):
|
||||||
|
|
||||||
|
name = MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME
|
||||||
|
|
||||||
def validate(runtime_env_dict: dict) -> str:
|
def validate(runtime_env_dict: dict) -> str:
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
|
@ -110,7 +114,16 @@ def test_remove_placement_group(ray_start_cluster, connect_to_client):
|
||||||
ray.get(task_ref)
|
ray.get(task_ref)
|
||||||
|
|
||||||
|
|
||||||
def test_remove_placement_group_worker_startup_slowly(ray_start_cluster):
|
@pytest.mark.parametrize(
|
||||||
|
"set_runtime_env_plugins",
|
||||||
|
[
|
||||||
|
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH,
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
def test_remove_placement_group_worker_startup_slowly(
|
||||||
|
set_runtime_env_plugins, ray_start_cluster
|
||||||
|
):
|
||||||
cluster = ray_start_cluster
|
cluster = ray_start_cluster
|
||||||
cluster.add_node(num_cpus=4)
|
cluster.add_node(num_cpus=4)
|
||||||
ray.init(address=cluster.address)
|
ray.init(address=cluster.address)
|
||||||
|
@ -134,7 +147,7 @@ def test_remove_placement_group_worker_startup_slowly(ray_start_cluster):
|
||||||
# runtime env to mock worker start up slowly.
|
# runtime env to mock worker start up slowly.
|
||||||
task_ref = long_running_task.options(
|
task_ref = long_running_task.options(
|
||||||
placement_group=placement_group,
|
placement_group=placement_group,
|
||||||
runtime_env={"plugins": {MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH: {}}},
|
runtime_env={MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME: {}},
|
||||||
).remote()
|
).remote()
|
||||||
a = A.options(placement_group=placement_group).remote()
|
a = A.options(placement_group=placement_group).remote()
|
||||||
assert ray.get(a.f.remote()) == 3
|
assert ray.get(a.f.remote()) == 3
|
||||||
|
|
|
@ -559,6 +559,7 @@ def test_to_make_ensure_runtime_env_api(start_cluster):
|
||||||
|
|
||||||
|
|
||||||
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env.MyPlugin"
|
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env.MyPlugin"
|
||||||
|
MY_PLUGIN_NAME = "MyPlugin"
|
||||||
success_retry_number = 3
|
success_retry_number = 3
|
||||||
runtime_env_retry_times = 0
|
runtime_env_retry_times = 0
|
||||||
|
|
||||||
|
@ -566,9 +567,12 @@ runtime_env_retry_times = 0
|
||||||
# This plugin can make runtime env creation failed before the retry times
|
# This plugin can make runtime env creation failed before the retry times
|
||||||
# increased to `success_retry_number`.
|
# increased to `success_retry_number`.
|
||||||
class MyPlugin(RuntimeEnvPlugin):
|
class MyPlugin(RuntimeEnvPlugin):
|
||||||
|
|
||||||
|
name = MY_PLUGIN_NAME
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate(runtime_env_dict: dict) -> str:
|
def validate(runtime_env_dict: dict) -> str:
|
||||||
return runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH]
|
return runtime_env_dict[MY_PLUGIN_NAME]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def modify_context(
|
def modify_context(
|
||||||
|
@ -592,7 +596,16 @@ class MyPlugin(RuntimeEnvPlugin):
|
||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
|
@pytest.mark.parametrize(
|
||||||
|
"set_runtime_env_plugins",
|
||||||
|
[
|
||||||
|
MY_PLUGIN_CLASS_PATH,
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
def test_runtime_env_retry(
|
||||||
|
set_runtime_env_retry_times, set_runtime_env_plugins, ray_start_regular
|
||||||
|
):
|
||||||
@ray.remote
|
@ray.remote
|
||||||
def f():
|
def f():
|
||||||
return "ok"
|
return "ok"
|
||||||
|
@ -601,9 +614,7 @@ def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
|
||||||
if runtime_env_retry_times >= success_retry_number:
|
if runtime_env_retry_times >= success_retry_number:
|
||||||
# Enough retry times
|
# Enough retry times
|
||||||
output = ray.get(
|
output = ray.get(
|
||||||
f.options(
|
f.options(runtime_env={MY_PLUGIN_NAME: {"key": "value"}}).remote()
|
||||||
runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: {"key": "value"}}}
|
|
||||||
).remote()
|
|
||||||
)
|
)
|
||||||
assert output == "ok"
|
assert output == "ok"
|
||||||
else:
|
else:
|
||||||
|
@ -611,11 +622,7 @@ def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
RuntimeEnvSetupError, match=f"Fault injection {runtime_env_retry_times}"
|
RuntimeEnvSetupError, match=f"Fault injection {runtime_env_retry_times}"
|
||||||
):
|
):
|
||||||
ray.get(
|
ray.get(f.options(runtime_env={MY_PLUGIN_NAME: {"key": "value"}}).remote())
|
||||||
f.options(
|
|
||||||
runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: {"key": "value"}}}
|
|
||||||
).remote()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
|
@ -13,14 +13,16 @@ from ray._private.test_utils import test_external_redis, wait_for_condition
|
||||||
from ray.exceptions import RuntimeEnvSetupError
|
from ray.exceptions import RuntimeEnvSetupError
|
||||||
|
|
||||||
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin"
|
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin"
|
||||||
|
MY_PLUGIN_NAME = "MyPlugin"
|
||||||
|
|
||||||
|
|
||||||
class MyPlugin(RuntimeEnvPlugin):
|
class MyPlugin(RuntimeEnvPlugin):
|
||||||
|
name = MY_PLUGIN_NAME
|
||||||
env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY"
|
env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate(runtime_env_dict: dict) -> str:
|
def validate(runtime_env_dict: dict) -> str:
|
||||||
value = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH]
|
value = runtime_env_dict[MY_PLUGIN_NAME]
|
||||||
if value == "fail":
|
if value == "fail":
|
||||||
raise ValueError("not allowed")
|
raise ValueError("not allowed")
|
||||||
return value
|
return value
|
||||||
|
@ -42,7 +44,14 @@ class MyPlugin(RuntimeEnvPlugin):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_simple_env_modification_plugin(ray_start_regular):
|
@pytest.mark.parametrize(
|
||||||
|
"set_runtime_env_plugins",
|
||||||
|
[
|
||||||
|
MY_PLUGIN_CLASS_PATH,
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
def test_simple_env_modification_plugin(set_runtime_env_plugins, ray_start_regular):
|
||||||
_, tmp_file_path = tempfile.mkstemp()
|
_, tmp_file_path = tempfile.mkstemp()
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
|
@ -57,21 +66,19 @@ def test_simple_env_modification_plugin(ray_start_regular):
|
||||||
"nice": psutil.Process().nice(),
|
"nice": psutil.Process().nice(),
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="not allowed"):
|
with pytest.raises(RuntimeEnvSetupError, match="not allowed"):
|
||||||
f.options(runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: "fail"}}).remote()
|
ray.get(f.options(runtime_env={MY_PLUGIN_NAME: "fail"}).remote())
|
||||||
|
|
||||||
if os.name != "nt":
|
if os.name != "nt":
|
||||||
output = ray.get(
|
output = ray.get(
|
||||||
f.options(
|
f.options(
|
||||||
runtime_env={
|
runtime_env={
|
||||||
"plugins": {
|
MY_PLUGIN_NAME: {
|
||||||
MY_PLUGIN_CLASS_PATH: {
|
"env_value": 42,
|
||||||
"env_value": 42,
|
"tmp_file": tmp_file_path,
|
||||||
"tmp_file": tmp_file_path,
|
"tmp_content": "hello",
|
||||||
"tmp_content": "hello",
|
# See https://en.wikipedia.org/wiki/Nice_(Unix)
|
||||||
# See https://en.wikipedia.org/wiki/Nice_(Unix)
|
"prefix_command": "nice -n 19",
|
||||||
"prefix_command": "nice -n 19",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
).remote()
|
).remote()
|
||||||
|
@ -81,11 +88,13 @@ def test_simple_env_modification_plugin(ray_start_regular):
|
||||||
|
|
||||||
|
|
||||||
MY_PLUGIN_FOR_HANG_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPluginForHang"
|
MY_PLUGIN_FOR_HANG_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPluginForHang"
|
||||||
|
MY_PLUGIN_FOR_HANG_NAME = "MyPluginForHang"
|
||||||
my_plugin_setup_times = 0
|
my_plugin_setup_times = 0
|
||||||
|
|
||||||
|
|
||||||
# This plugin will hang when first setup, second setup will ok
|
# This plugin will hang when first setup, second setup will ok
|
||||||
class MyPluginForHang(RuntimeEnvPlugin):
|
class MyPluginForHang(RuntimeEnvPlugin):
|
||||||
|
name = MY_PLUGIN_FOR_HANG_NAME
|
||||||
env_key = "MY_PLUGIN_FOR_HANG_TEST_ENVIRONMENT_KEY"
|
env_key = "MY_PLUGIN_FOR_HANG_TEST_ENVIRONMENT_KEY"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -112,7 +121,14 @@ class MyPluginForHang(RuntimeEnvPlugin):
|
||||||
ctx.env_vars[MyPluginForHang.env_key] = str(my_plugin_setup_times)
|
ctx.env_vars[MyPluginForHang.env_key] = str(my_plugin_setup_times)
|
||||||
|
|
||||||
|
|
||||||
def test_plugin_hang(ray_start_regular):
|
@pytest.mark.parametrize(
|
||||||
|
"set_runtime_env_plugins",
|
||||||
|
[
|
||||||
|
MY_PLUGIN_FOR_HANG_CLASS_PATH,
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
def test_plugin_hang(set_runtime_env_plugins, ray_start_regular):
|
||||||
env_key = MyPluginForHang.env_key
|
env_key = MyPluginForHang.env_key
|
||||||
|
|
||||||
@ray.remote(num_cpus=0.1)
|
@ray.remote(num_cpus=0.1)
|
||||||
|
@ -122,11 +138,9 @@ def test_plugin_hang(ray_start_regular):
|
||||||
refs = [
|
refs = [
|
||||||
f.options(
|
f.options(
|
||||||
# Avoid hitting the cache of runtime_env
|
# Avoid hitting the cache of runtime_env
|
||||||
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f1"}}}
|
runtime_env={MY_PLUGIN_FOR_HANG_NAME: {"name": "f1"}}
|
||||||
).remote(),
|
|
||||||
f.options(
|
|
||||||
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f2"}}}
|
|
||||||
).remote(),
|
).remote(),
|
||||||
|
f.options(runtime_env={MY_PLUGIN_FOR_HANG_NAME: {"name": "f2"}}).remote(),
|
||||||
]
|
]
|
||||||
|
|
||||||
def condition():
|
def condition():
|
||||||
|
@ -145,19 +159,26 @@ def test_plugin_hang(ray_start_regular):
|
||||||
|
|
||||||
|
|
||||||
DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin"
|
DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin"
|
||||||
|
DUMMY_PLUGIN_NAME = "DummyPlugin"
|
||||||
HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin"
|
HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin"
|
||||||
|
HANG_PLUGIN_NAME = "HangPlugin"
|
||||||
DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = (
|
DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = (
|
||||||
"ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin"
|
"ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin"
|
||||||
)
|
)
|
||||||
|
DISABLE_TIMEOUT_PLUGIN_NAME = "test_plugin_timeout"
|
||||||
|
|
||||||
|
|
||||||
class DummyPlugin(RuntimeEnvPlugin):
|
class DummyPlugin(RuntimeEnvPlugin):
|
||||||
|
name = DUMMY_PLUGIN_NAME
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate(runtime_env_dict: dict) -> str:
|
def validate(runtime_env_dict: dict) -> str:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
class HangPlugin(DummyPlugin):
|
class HangPlugin(DummyPlugin):
|
||||||
|
name = HANG_PLUGIN_NAME
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
||||||
) -> float:
|
) -> float:
|
||||||
|
@ -165,14 +186,25 @@ class HangPlugin(DummyPlugin):
|
||||||
|
|
||||||
|
|
||||||
class DiasbleTimeoutPlugin(DummyPlugin):
|
class DiasbleTimeoutPlugin(DummyPlugin):
|
||||||
|
name = DISABLE_TIMEOUT_PLUGIN_NAME
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
||||||
) -> float:
|
) -> float:
|
||||||
sleep(10)
|
sleep(10)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"set_runtime_env_plugins",
|
||||||
|
[
|
||||||
|
f"{DUMMY_PLUGIN_CLASS_PATH},"
|
||||||
|
f"{HANG_PLUGIN_CLASS_PATH},"
|
||||||
|
f"{DISABLE_TIMEOUT_PLUGIN_CLASS_PATH}",
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
@pytest.mark.skipif(test_external_redis(), reason="Failing in redis mode.")
|
@pytest.mark.skipif(test_external_redis(), reason="Failing in redis mode.")
|
||||||
def test_plugin_timeout(start_cluster):
|
def test_plugin_timeout(set_runtime_env_plugins, start_cluster):
|
||||||
@ray.remote(num_cpus=0.1)
|
@ray.remote(num_cpus=0.1)
|
||||||
def f():
|
def f():
|
||||||
return True
|
return True
|
||||||
|
@ -180,20 +212,14 @@ def test_plugin_timeout(start_cluster):
|
||||||
refs = [
|
refs = [
|
||||||
f.options(
|
f.options(
|
||||||
runtime_env={
|
runtime_env={
|
||||||
"plugins": {
|
HANG_PLUGIN_NAME: {"name": "f1"},
|
||||||
HANG_PLUGIN_CLASS_PATH: {"name": "f1"},
|
|
||||||
},
|
|
||||||
"config": {"setup_timeout_seconds": 10},
|
"config": {"setup_timeout_seconds": 10},
|
||||||
}
|
}
|
||||||
).remote(),
|
).remote(),
|
||||||
f.options(
|
f.options(runtime_env={DUMMY_PLUGIN_NAME: {"name": "f2"}}).remote(),
|
||||||
runtime_env={"plugins": {DUMMY_PLUGIN_CLASS_PATH: {"name": "f2"}}}
|
|
||||||
).remote(),
|
|
||||||
f.options(
|
f.options(
|
||||||
runtime_env={
|
runtime_env={
|
||||||
"plugins": {
|
HANG_PLUGIN_NAME: {"name": "f3"},
|
||||||
HANG_PLUGIN_CLASS_PATH: {"name": "f3"},
|
|
||||||
},
|
|
||||||
"config": {"setup_timeout_seconds": -1},
|
"config": {"setup_timeout_seconds": -1},
|
||||||
}
|
}
|
||||||
).remote(),
|
).remote(),
|
||||||
|
|
Loading…
Add table
Reference in a new issue