mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[runtime env] plugin refactor[1/n] (#26077)
This commit is contained in:
parent
876fef0fcf
commit
58bfad84d3
8 changed files with 178 additions and 110 deletions
|
@ -20,10 +20,10 @@ from ray._private.runtime_env.context import RuntimeEnvContext
|
|||
from ray._private.runtime_env.java_jars import JavaJarsPlugin
|
||||
from ray._private.runtime_env.pip import PipPlugin
|
||||
from ray._private.runtime_env.plugin import PluginCacheManager
|
||||
from ray._private.runtime_env.plugin import RuntimeEnvPluginManager
|
||||
from ray._private.runtime_env.py_modules import PyModulesPlugin
|
||||
from ray._private.runtime_env.uri_cache import URICache
|
||||
from ray._private.runtime_env.working_dir import WorkingDirPlugin
|
||||
from ray._private.utils import import_attr
|
||||
from ray.core.generated import (
|
||||
agent_manager_pb2,
|
||||
runtime_env_agent_pb2,
|
||||
|
@ -224,6 +224,7 @@ class RuntimeEnvAgent(
|
|||
self.unused_uris_processor,
|
||||
self.unused_runtime_env_processor,
|
||||
)
|
||||
self._runtime_env_plugin_manager = RuntimeEnvPluginManager()
|
||||
|
||||
self._logger = default_logger
|
||||
|
||||
|
@ -295,13 +296,13 @@ class RuntimeEnvAgent(
|
|||
|
||||
def setup_plugins():
|
||||
# Run setup function from all the plugins
|
||||
for plugin_class_path, config in runtime_env.plugins():
|
||||
per_job_logger.debug(
|
||||
f"Setting up runtime env plugin {plugin_class_path}"
|
||||
)
|
||||
plugin_class = import_attr(plugin_class_path)
|
||||
plugin = plugin_class()
|
||||
for name, config in runtime_env.plugins():
|
||||
per_job_logger.debug(f"Setting up runtime env plugin {name}")
|
||||
plugin = self._runtime_env_plugin_manager.get_plugin(name)
|
||||
if plugin is None:
|
||||
raise RuntimeError(f"runtime env plugin {name} not found.")
|
||||
# TODO(architkulkarni): implement uri support
|
||||
plugin.validate(runtime_env)
|
||||
plugin.create("uri not implemented", json.loads(config), context)
|
||||
plugin.modify_context(
|
||||
"uri not implemented",
|
||||
|
|
|
@ -1,2 +1,5 @@
|
|||
# Env var set by job manager to pass runtime env and metadata to subprocess
|
||||
RAY_JOB_CONFIG_JSON_ENV_VAR = "RAY_JOB_CONFIG_JSON_ENV_VAR"
|
||||
|
||||
# The plugins which should be loaded when ray cluster starts.
|
||||
RAY_RUNTIME_ENV_PLUGINS_ENV_VAR = "RAY_RUNTIME_ENV_PLUGINS"
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import logging
|
||||
import os
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
|
||||
from ray._private.runtime_env.context import RuntimeEnvContext
|
||||
from ray._private.runtime_env.uri_cache import URICache
|
||||
from ray._private.runtime_env.constants import RAY_RUNTIME_ENV_PLUGINS_ENV_VAR
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
from ray._private.utils import import_attr
|
||||
|
||||
default_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -13,7 +16,7 @@ default_logger = logging.getLogger(__name__)
|
|||
class RuntimeEnvPlugin(ABC):
|
||||
"""Abstract base class for runtime environment plugins."""
|
||||
|
||||
name: str
|
||||
name: str = None
|
||||
|
||||
@staticmethod
|
||||
def validate(runtime_env_dict: dict) -> str:
|
||||
|
@ -88,6 +91,45 @@ class RuntimeEnvPlugin(ABC):
|
|||
return 0
|
||||
|
||||
|
||||
class RuntimeEnvPluginManager:
|
||||
"""This mananger is used to load plugins in runtime env agent."""
|
||||
|
||||
def __init__(self):
|
||||
self.plugins = {}
|
||||
plugins_config = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
|
||||
if plugins_config:
|
||||
self.load_plugins(plugins_config.split(","))
|
||||
|
||||
def load_plugins(self, plugin_classes: List[str]):
|
||||
"""Load runtime env plugins"""
|
||||
for plugin_class_path in plugin_classes:
|
||||
plugin_class = import_attr(plugin_class_path)
|
||||
if not issubclass(plugin_class, RuntimeEnvPlugin):
|
||||
default_logger.warning(
|
||||
"Invalid runtime env plugin class %s. "
|
||||
"The plugin class must inherit "
|
||||
"ray._private.runtime_env.plugin.RuntimeEnvPlugin.",
|
||||
plugin_class,
|
||||
)
|
||||
continue
|
||||
if not plugin_class.name:
|
||||
default_logger.warning(
|
||||
"No valid name in runtime env plugin %s", plugin_class
|
||||
)
|
||||
continue
|
||||
if plugin_class.name in self.plugins:
|
||||
default_logger.warning(
|
||||
"The name of runtime env plugin %s conflicts with %s",
|
||||
plugin_class,
|
||||
self.plugins[plugin_class.name],
|
||||
)
|
||||
continue
|
||||
self.plugins[plugin_class.name] = plugin_class()
|
||||
|
||||
def get_plugin(self, name: str):
|
||||
return self.plugins.get(name)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PluginCacheManager:
|
||||
"""Manages a plugin and a cache for its local resources."""
|
||||
|
|
|
@ -10,9 +10,7 @@ import ray
|
|||
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
|
||||
from ray._private.runtime_env.conda import get_uri as get_conda_uri
|
||||
from ray._private.runtime_env.pip import get_uri as get_pip_uri
|
||||
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
|
||||
from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
|
||||
from ray._private.utils import import_attr
|
||||
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
|
||||
from ray.core.generated.runtime_env_common_pb2 import (
|
||||
RuntimeEnvConfig as ProtoRuntimeEnvConfig,
|
||||
|
@ -85,9 +83,8 @@ def _parse_proto_plugin_runtime_env(
|
|||
):
|
||||
"""Parse plugin runtime env protobuf to runtime env dict."""
|
||||
if runtime_env.python_runtime_env.HasField("plugin_runtime_env"):
|
||||
runtime_env_dict["plugins"] = dict()
|
||||
for plugin in runtime_env.python_runtime_env.plugin_runtime_env.plugins:
|
||||
runtime_env_dict["plugins"][plugin.class_path] = json.loads(plugin.config)
|
||||
runtime_env_dict[plugin.class_path] = json.loads(plugin.config)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
|
@ -319,8 +316,12 @@ class RuntimeEnv(dict):
|
|||
"_ray_release",
|
||||
"_ray_commit",
|
||||
"_inject_current_ray",
|
||||
"plugins",
|
||||
"config",
|
||||
# TODO(SongGuyang): We add this because the test
|
||||
# `test_experimental_package_github` set a `docker`
|
||||
# field which is not supported. We should remove it
|
||||
# with the test.
|
||||
"docker",
|
||||
}
|
||||
|
||||
extensions_fields: Set[str] = {
|
||||
|
@ -363,19 +364,20 @@ class RuntimeEnv(dict):
|
|||
if runtime_env.get("java_jars"):
|
||||
runtime_env["java_jars"] = runtime_env.get("java_jars")
|
||||
|
||||
self.update(runtime_env)
|
||||
|
||||
# Blindly trust that the runtime_env has already been validated.
|
||||
# This is dangerous and should only be used internally (e.g., on the
|
||||
# deserialization codepath.
|
||||
if not _validate:
|
||||
self.update(runtime_env)
|
||||
return
|
||||
|
||||
if runtime_env.get("conda") and runtime_env.get("pip"):
|
||||
if self.get("conda") and self.get("pip"):
|
||||
raise ValueError(
|
||||
"The 'pip' field and 'conda' field of "
|
||||
"runtime_env cannot both be specified.\n"
|
||||
f"specified pip field: {runtime_env['pip']}\n"
|
||||
f"specified conda field: {runtime_env['conda']}\n"
|
||||
f"specified pip field: {self['pip']}\n"
|
||||
f"specified conda field: {self['conda']}\n"
|
||||
"To use pip with conda, please only set the 'conda' "
|
||||
"field, and specify your pip dependencies "
|
||||
"within the conda YAML config dict: see "
|
||||
|
@ -385,48 +387,21 @@ class RuntimeEnv(dict):
|
|||
)
|
||||
|
||||
for option, validate_fn in OPTION_TO_VALIDATION_FN.items():
|
||||
option_val = runtime_env.get(option)
|
||||
option_val = self.get(option)
|
||||
if option_val is not None:
|
||||
del self[option]
|
||||
self[option] = option_val
|
||||
|
||||
if "_ray_release" in runtime_env:
|
||||
self["_ray_release"] = runtime_env["_ray_release"]
|
||||
|
||||
if "_ray_commit" in runtime_env:
|
||||
self["_ray_commit"] = runtime_env["_ray_commit"]
|
||||
else:
|
||||
if "_ray_commit" not in self:
|
||||
if self.get("pip") or self.get("conda"):
|
||||
self["_ray_commit"] = ray.__commit__
|
||||
|
||||
# Used for testing wheels that have not yet been merged into master.
|
||||
# If this is set to True, then we do not inject Ray into the conda
|
||||
# or pip dependencies.
|
||||
if "_inject_current_ray" in runtime_env:
|
||||
self["_inject_current_ray"] = runtime_env["_inject_current_ray"]
|
||||
elif "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
|
||||
self["_inject_current_ray"] = True
|
||||
if "plugins" in runtime_env:
|
||||
self["plugins"] = dict()
|
||||
for class_path, plugin_field in runtime_env["plugins"].items():
|
||||
plugin_class: RuntimeEnvPlugin = import_attr(class_path)
|
||||
if not issubclass(plugin_class, RuntimeEnvPlugin):
|
||||
# TODO(simon): move the inferface to public once ready.
|
||||
raise TypeError(
|
||||
f"{class_path} must be inherit from "
|
||||
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
|
||||
)
|
||||
# TODO(simon): implement uri support.
|
||||
_ = plugin_class.validate(runtime_env)
|
||||
# Validation passed, add the entry to parsed runtime env.
|
||||
self["plugins"][class_path] = plugin_field
|
||||
|
||||
unknown_fields = set(runtime_env.keys()) - RuntimeEnv.known_fields
|
||||
if len(unknown_fields):
|
||||
logger.warning(
|
||||
"The following unknown entries in the runtime_env dictionary "
|
||||
f"will be ignored: {unknown_fields}. If you intended to use "
|
||||
"them as plugins, they must be nested in the `plugins` field."
|
||||
)
|
||||
if "_inject_current_ray" not in self:
|
||||
if "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
|
||||
self["_inject_current_ray"] = True
|
||||
|
||||
# NOTE(architkulkarni): This allows worker caching code in C++ to check
|
||||
# if a runtime env is empty without deserializing it. This is a catch-
|
||||
|
@ -455,20 +430,17 @@ class RuntimeEnv(dict):
|
|||
return plugin_uris
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
if key not in RuntimeEnv.known_fields:
|
||||
logger.warning(
|
||||
"The following unknown entries in the runtime_env dictionary "
|
||||
f"will be ignored: {key}. If you intended to use "
|
||||
"them as plugins, they must be nested in the `plugins` field."
|
||||
)
|
||||
return
|
||||
# TODO(SongGuyang): Validate the schemas of plugins by json schema.
|
||||
res_value = value
|
||||
if key in OPTION_TO_VALIDATION_FN:
|
||||
if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
|
||||
res_value = OPTION_TO_VALIDATION_FN[key](value)
|
||||
if res_value is None:
|
||||
return
|
||||
return super().__setitem__(key, res_value)
|
||||
|
||||
def set(self, name: str, value: Any) -> None:
|
||||
self.__setitem__(name, value)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821
|
||||
proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())
|
||||
|
@ -655,16 +627,11 @@ class RuntimeEnv(dict):
|
|||
return None
|
||||
return self["container"].get("run_options", [])
|
||||
|
||||
def has_plugins(self) -> bool:
|
||||
if self.get("plugins"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def plugins(self) -> List[Tuple[str, str]]:
|
||||
result = list()
|
||||
if self.has_plugins():
|
||||
for class_path, plugin_field in self["plugins"].items():
|
||||
result.append((class_path, json.dumps(plugin_field, sort_keys=True)))
|
||||
for key, value in self.items():
|
||||
if key not in self.known_fields:
|
||||
result.append((key, json.dumps(value, sort_keys=True)))
|
||||
return result
|
||||
|
||||
def _build_proto_pip_runtime_env(self, runtime_env: ProtoRuntimeEnv):
|
||||
|
@ -720,8 +687,7 @@ class RuntimeEnv(dict):
|
|||
|
||||
def _build_proto_plugin_runtime_env(self, runtime_env: ProtoRuntimeEnv):
|
||||
"""Construct plugin runtime env protobuf from runtime env dict."""
|
||||
if self.has_plugins():
|
||||
for class_path, plugin_field in self.plugins():
|
||||
plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add()
|
||||
plugin.class_path = class_path
|
||||
plugin.config = plugin_field
|
||||
for class_path, plugin_field in self.plugins():
|
||||
plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add()
|
||||
plugin.class_path = class_path
|
||||
plugin.config = plugin_field
|
||||
|
|
|
@ -923,3 +923,13 @@ def start_http_proxy(request):
|
|||
if proxy:
|
||||
proxy.terminate()
|
||||
proxy.wait()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_runtime_env_plugins(request):
|
||||
runtime_env_plugins = getattr(request, "param", "0")
|
||||
try:
|
||||
os.environ["RAY_RUNTIME_ENV_PLUGINS"] = runtime_env_plugins
|
||||
yield runtime_env_plugins
|
||||
finally:
|
||||
del os.environ["RAY_RUNTIME_ENV_PLUGINS"]
|
||||
|
|
|
@ -19,9 +19,13 @@ from ray._private.runtime_env.plugin import RuntimeEnvPlugin
|
|||
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH = (
|
||||
"ray.tests.test_placement_group_4.MockWorkerStartupSlowlyPlugin" # noqa
|
||||
)
|
||||
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME = "MockWorkerStartupSlowlyPlugin"
|
||||
|
||||
|
||||
class MockWorkerStartupSlowlyPlugin(RuntimeEnvPlugin):
|
||||
|
||||
name = MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME
|
||||
|
||||
def validate(runtime_env_dict: dict) -> str:
|
||||
return "success"
|
||||
|
||||
|
@ -110,7 +114,16 @@ def test_remove_placement_group(ray_start_cluster, connect_to_client):
|
|||
ray.get(task_ref)
|
||||
|
||||
|
||||
def test_remove_placement_group_worker_startup_slowly(ray_start_cluster):
|
||||
@pytest.mark.parametrize(
|
||||
"set_runtime_env_plugins",
|
||||
[
|
||||
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH,
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_remove_placement_group_worker_startup_slowly(
|
||||
set_runtime_env_plugins, ray_start_cluster
|
||||
):
|
||||
cluster = ray_start_cluster
|
||||
cluster.add_node(num_cpus=4)
|
||||
ray.init(address=cluster.address)
|
||||
|
@ -134,7 +147,7 @@ def test_remove_placement_group_worker_startup_slowly(ray_start_cluster):
|
|||
# runtime env to mock worker start up slowly.
|
||||
task_ref = long_running_task.options(
|
||||
placement_group=placement_group,
|
||||
runtime_env={"plugins": {MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH: {}}},
|
||||
runtime_env={MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME: {}},
|
||||
).remote()
|
||||
a = A.options(placement_group=placement_group).remote()
|
||||
assert ray.get(a.f.remote()) == 3
|
||||
|
|
|
@ -559,6 +559,7 @@ def test_to_make_ensure_runtime_env_api(start_cluster):
|
|||
|
||||
|
||||
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env.MyPlugin"
|
||||
MY_PLUGIN_NAME = "MyPlugin"
|
||||
success_retry_number = 3
|
||||
runtime_env_retry_times = 0
|
||||
|
||||
|
@ -566,9 +567,12 @@ runtime_env_retry_times = 0
|
|||
# This plugin can make runtime env creation failed before the retry times
|
||||
# increased to `success_retry_number`.
|
||||
class MyPlugin(RuntimeEnvPlugin):
|
||||
|
||||
name = MY_PLUGIN_NAME
|
||||
|
||||
@staticmethod
|
||||
def validate(runtime_env_dict: dict) -> str:
|
||||
return runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH]
|
||||
return runtime_env_dict[MY_PLUGIN_NAME]
|
||||
|
||||
@staticmethod
|
||||
def modify_context(
|
||||
|
@ -592,7 +596,16 @@ class MyPlugin(RuntimeEnvPlugin):
|
|||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
|
||||
@pytest.mark.parametrize(
|
||||
"set_runtime_env_plugins",
|
||||
[
|
||||
MY_PLUGIN_CLASS_PATH,
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_runtime_env_retry(
|
||||
set_runtime_env_retry_times, set_runtime_env_plugins, ray_start_regular
|
||||
):
|
||||
@ray.remote
|
||||
def f():
|
||||
return "ok"
|
||||
|
@ -601,9 +614,7 @@ def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
|
|||
if runtime_env_retry_times >= success_retry_number:
|
||||
# Enough retry times
|
||||
output = ray.get(
|
||||
f.options(
|
||||
runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: {"key": "value"}}}
|
||||
).remote()
|
||||
f.options(runtime_env={MY_PLUGIN_NAME: {"key": "value"}}).remote()
|
||||
)
|
||||
assert output == "ok"
|
||||
else:
|
||||
|
@ -611,11 +622,7 @@ def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
|
|||
with pytest.raises(
|
||||
RuntimeEnvSetupError, match=f"Fault injection {runtime_env_retry_times}"
|
||||
):
|
||||
ray.get(
|
||||
f.options(
|
||||
runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: {"key": "value"}}}
|
||||
).remote()
|
||||
)
|
||||
ray.get(f.options(runtime_env={MY_PLUGIN_NAME: {"key": "value"}}).remote())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -13,14 +13,16 @@ from ray._private.test_utils import test_external_redis, wait_for_condition
|
|||
from ray.exceptions import RuntimeEnvSetupError
|
||||
|
||||
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin"
|
||||
MY_PLUGIN_NAME = "MyPlugin"
|
||||
|
||||
|
||||
class MyPlugin(RuntimeEnvPlugin):
|
||||
name = MY_PLUGIN_NAME
|
||||
env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY"
|
||||
|
||||
@staticmethod
|
||||
def validate(runtime_env_dict: dict) -> str:
|
||||
value = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH]
|
||||
value = runtime_env_dict[MY_PLUGIN_NAME]
|
||||
if value == "fail":
|
||||
raise ValueError("not allowed")
|
||||
return value
|
||||
|
@ -42,7 +44,14 @@ class MyPlugin(RuntimeEnvPlugin):
|
|||
)
|
||||
|
||||
|
||||
def test_simple_env_modification_plugin(ray_start_regular):
|
||||
@pytest.mark.parametrize(
|
||||
"set_runtime_env_plugins",
|
||||
[
|
||||
MY_PLUGIN_CLASS_PATH,
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_simple_env_modification_plugin(set_runtime_env_plugins, ray_start_regular):
|
||||
_, tmp_file_path = tempfile.mkstemp()
|
||||
|
||||
@ray.remote
|
||||
|
@ -57,21 +66,19 @@ def test_simple_env_modification_plugin(ray_start_regular):
|
|||
"nice": psutil.Process().nice(),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
f.options(runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: "fail"}}).remote()
|
||||
with pytest.raises(RuntimeEnvSetupError, match="not allowed"):
|
||||
ray.get(f.options(runtime_env={MY_PLUGIN_NAME: "fail"}).remote())
|
||||
|
||||
if os.name != "nt":
|
||||
output = ray.get(
|
||||
f.options(
|
||||
runtime_env={
|
||||
"plugins": {
|
||||
MY_PLUGIN_CLASS_PATH: {
|
||||
"env_value": 42,
|
||||
"tmp_file": tmp_file_path,
|
||||
"tmp_content": "hello",
|
||||
# See https://en.wikipedia.org/wiki/Nice_(Unix)
|
||||
"prefix_command": "nice -n 19",
|
||||
}
|
||||
MY_PLUGIN_NAME: {
|
||||
"env_value": 42,
|
||||
"tmp_file": tmp_file_path,
|
||||
"tmp_content": "hello",
|
||||
# See https://en.wikipedia.org/wiki/Nice_(Unix)
|
||||
"prefix_command": "nice -n 19",
|
||||
}
|
||||
}
|
||||
).remote()
|
||||
|
@ -81,11 +88,13 @@ def test_simple_env_modification_plugin(ray_start_regular):
|
|||
|
||||
|
||||
MY_PLUGIN_FOR_HANG_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPluginForHang"
|
||||
MY_PLUGIN_FOR_HANG_NAME = "MyPluginForHang"
|
||||
my_plugin_setup_times = 0
|
||||
|
||||
|
||||
# This plugin will hang when first setup, second setup will ok
|
||||
class MyPluginForHang(RuntimeEnvPlugin):
|
||||
name = MY_PLUGIN_FOR_HANG_NAME
|
||||
env_key = "MY_PLUGIN_FOR_HANG_TEST_ENVIRONMENT_KEY"
|
||||
|
||||
@staticmethod
|
||||
|
@ -112,7 +121,14 @@ class MyPluginForHang(RuntimeEnvPlugin):
|
|||
ctx.env_vars[MyPluginForHang.env_key] = str(my_plugin_setup_times)
|
||||
|
||||
|
||||
def test_plugin_hang(ray_start_regular):
|
||||
@pytest.mark.parametrize(
|
||||
"set_runtime_env_plugins",
|
||||
[
|
||||
MY_PLUGIN_FOR_HANG_CLASS_PATH,
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_plugin_hang(set_runtime_env_plugins, ray_start_regular):
|
||||
env_key = MyPluginForHang.env_key
|
||||
|
||||
@ray.remote(num_cpus=0.1)
|
||||
|
@ -122,11 +138,9 @@ def test_plugin_hang(ray_start_regular):
|
|||
refs = [
|
||||
f.options(
|
||||
# Avoid hitting the cache of runtime_env
|
||||
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f1"}}}
|
||||
).remote(),
|
||||
f.options(
|
||||
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f2"}}}
|
||||
runtime_env={MY_PLUGIN_FOR_HANG_NAME: {"name": "f1"}}
|
||||
).remote(),
|
||||
f.options(runtime_env={MY_PLUGIN_FOR_HANG_NAME: {"name": "f2"}}).remote(),
|
||||
]
|
||||
|
||||
def condition():
|
||||
|
@ -145,19 +159,26 @@ def test_plugin_hang(ray_start_regular):
|
|||
|
||||
|
||||
DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin"
|
||||
DUMMY_PLUGIN_NAME = "DummyPlugin"
|
||||
HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin"
|
||||
HANG_PLUGIN_NAME = "HangPlugin"
|
||||
DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = (
|
||||
"ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin"
|
||||
)
|
||||
DISABLE_TIMEOUT_PLUGIN_NAME = "test_plugin_timeout"
|
||||
|
||||
|
||||
class DummyPlugin(RuntimeEnvPlugin):
|
||||
name = DUMMY_PLUGIN_NAME
|
||||
|
||||
@staticmethod
|
||||
def validate(runtime_env_dict: dict) -> str:
|
||||
return 1
|
||||
|
||||
|
||||
class HangPlugin(DummyPlugin):
|
||||
name = HANG_PLUGIN_NAME
|
||||
|
||||
def create(
|
||||
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
||||
) -> float:
|
||||
|
@ -165,14 +186,25 @@ class HangPlugin(DummyPlugin):
|
|||
|
||||
|
||||
class DiasbleTimeoutPlugin(DummyPlugin):
|
||||
name = DISABLE_TIMEOUT_PLUGIN_NAME
|
||||
|
||||
def create(
|
||||
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
|
||||
) -> float:
|
||||
sleep(10)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"set_runtime_env_plugins",
|
||||
[
|
||||
f"{DUMMY_PLUGIN_CLASS_PATH},"
|
||||
f"{HANG_PLUGIN_CLASS_PATH},"
|
||||
f"{DISABLE_TIMEOUT_PLUGIN_CLASS_PATH}",
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.skipif(test_external_redis(), reason="Failing in redis mode.")
|
||||
def test_plugin_timeout(start_cluster):
|
||||
def test_plugin_timeout(set_runtime_env_plugins, start_cluster):
|
||||
@ray.remote(num_cpus=0.1)
|
||||
def f():
|
||||
return True
|
||||
|
@ -180,20 +212,14 @@ def test_plugin_timeout(start_cluster):
|
|||
refs = [
|
||||
f.options(
|
||||
runtime_env={
|
||||
"plugins": {
|
||||
HANG_PLUGIN_CLASS_PATH: {"name": "f1"},
|
||||
},
|
||||
HANG_PLUGIN_NAME: {"name": "f1"},
|
||||
"config": {"setup_timeout_seconds": 10},
|
||||
}
|
||||
).remote(),
|
||||
f.options(
|
||||
runtime_env={"plugins": {DUMMY_PLUGIN_CLASS_PATH: {"name": "f2"}}}
|
||||
).remote(),
|
||||
f.options(runtime_env={DUMMY_PLUGIN_NAME: {"name": "f2"}}).remote(),
|
||||
f.options(
|
||||
runtime_env={
|
||||
"plugins": {
|
||||
HANG_PLUGIN_CLASS_PATH: {"name": "f3"},
|
||||
},
|
||||
HANG_PLUGIN_NAME: {"name": "f3"},
|
||||
"config": {"setup_timeout_seconds": -1},
|
||||
}
|
||||
).remote(),
|
||||
|
|
Loading…
Add table
Reference in a new issue