[runtime env] plugin refactor[1/n] (#26077)

This commit is contained in:
Guyang Song 2022-06-28 14:09:05 +08:00 committed by GitHub
parent 876fef0fcf
commit 58bfad84d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 178 additions and 110 deletions

View file

@ -20,10 +20,10 @@ from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.java_jars import JavaJarsPlugin from ray._private.runtime_env.java_jars import JavaJarsPlugin
from ray._private.runtime_env.pip import PipPlugin from ray._private.runtime_env.pip import PipPlugin
from ray._private.runtime_env.plugin import PluginCacheManager from ray._private.runtime_env.plugin import PluginCacheManager
from ray._private.runtime_env.plugin import RuntimeEnvPluginManager
from ray._private.runtime_env.py_modules import PyModulesPlugin from ray._private.runtime_env.py_modules import PyModulesPlugin
from ray._private.runtime_env.uri_cache import URICache from ray._private.runtime_env.uri_cache import URICache
from ray._private.runtime_env.working_dir import WorkingDirPlugin from ray._private.runtime_env.working_dir import WorkingDirPlugin
from ray._private.utils import import_attr
from ray.core.generated import ( from ray.core.generated import (
agent_manager_pb2, agent_manager_pb2,
runtime_env_agent_pb2, runtime_env_agent_pb2,
@ -224,6 +224,7 @@ class RuntimeEnvAgent(
self.unused_uris_processor, self.unused_uris_processor,
self.unused_runtime_env_processor, self.unused_runtime_env_processor,
) )
self._runtime_env_plugin_manager = RuntimeEnvPluginManager()
self._logger = default_logger self._logger = default_logger
@ -295,13 +296,13 @@ class RuntimeEnvAgent(
def setup_plugins(): def setup_plugins():
# Run setup function from all the plugins # Run setup function from all the plugins
for plugin_class_path, config in runtime_env.plugins(): for name, config in runtime_env.plugins():
per_job_logger.debug( per_job_logger.debug(f"Setting up runtime env plugin {name}")
f"Setting up runtime env plugin {plugin_class_path}" plugin = self._runtime_env_plugin_manager.get_plugin(name)
) if plugin is None:
plugin_class = import_attr(plugin_class_path) raise RuntimeError(f"runtime env plugin {name} not found.")
plugin = plugin_class()
# TODO(architkulkarni): implement uri support # TODO(architkulkarni): implement uri support
plugin.validate(runtime_env)
plugin.create("uri not implemented", json.loads(config), context) plugin.create("uri not implemented", json.loads(config), context)
plugin.modify_context( plugin.modify_context(
"uri not implemented", "uri not implemented",

View file

@ -1,2 +1,5 @@
# Env var set by job manager to pass runtime env and metadata to subprocess # Env var set by job manager to pass runtime env and metadata to subprocess
RAY_JOB_CONFIG_JSON_ENV_VAR = "RAY_JOB_CONFIG_JSON_ENV_VAR" RAY_JOB_CONFIG_JSON_ENV_VAR = "RAY_JOB_CONFIG_JSON_ENV_VAR"
# The plugins which should be loaded when ray cluster starts.
RAY_RUNTIME_ENV_PLUGINS_ENV_VAR = "RAY_RUNTIME_ENV_PLUGINS"

View file

@ -1,10 +1,13 @@
import logging import logging
import os
from abc import ABC from abc import ABC
from typing import List from typing import List
from ray._private.runtime_env.context import RuntimeEnvContext from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.uri_cache import URICache from ray._private.runtime_env.uri_cache import URICache
from ray._private.runtime_env.constants import RAY_RUNTIME_ENV_PLUGINS_ENV_VAR
from ray.util.annotations import DeveloperAPI from ray.util.annotations import DeveloperAPI
from ray._private.utils import import_attr
default_logger = logging.getLogger(__name__) default_logger = logging.getLogger(__name__)
@ -13,7 +16,7 @@ default_logger = logging.getLogger(__name__)
class RuntimeEnvPlugin(ABC): class RuntimeEnvPlugin(ABC):
"""Abstract base class for runtime environment plugins.""" """Abstract base class for runtime environment plugins."""
name: str name: str = None
@staticmethod @staticmethod
def validate(runtime_env_dict: dict) -> str: def validate(runtime_env_dict: dict) -> str:
@ -88,6 +91,45 @@ class RuntimeEnvPlugin(ABC):
return 0 return 0
class RuntimeEnvPluginManager:
    """This manager is used to load plugins in runtime env agent.

    Plugins are registered by their class-level ``name`` attribute. On
    construction, the comma-separated class paths found in the environment
    variable named by ``RAY_RUNTIME_ENV_PLUGINS_ENV_VAR`` are loaded.
    """

    def __init__(self):
        # Maps plugin name -> plugin instance.
        self.plugins = {}
        plugins_config = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
        if plugins_config:
            self.load_plugins(plugins_config.split(","))

    def load_plugins(self, plugin_classes: List[str]):
        """Load runtime env plugins.

        Args:
            plugin_classes: Import paths of the plugin classes to load.
                Surrounding whitespace is stripped and blank entries are
                skipped, so values such as ``"a.B, c.D,"`` are accepted.

        Entries that are not ``RuntimeEnvPlugin`` subclasses, have no
        ``name``, or whose ``name`` is already registered are skipped with
        a warning. Errors raised by ``import_attr`` are not caught here.
        """
        for raw_class_path in plugin_classes:
            plugin_class_path = raw_class_path.strip()
            if not plugin_class_path:
                # Produced by a trailing/double comma; nothing to import.
                continue
            plugin_class = import_attr(plugin_class_path)
            # `issubclass` raises TypeError for non-class objects, so make
            # sure we got a class before asking about its bases.
            if not (
                isinstance(plugin_class, type)
                and issubclass(plugin_class, RuntimeEnvPlugin)
            ):
                default_logger.warning(
                    "Invalid runtime env plugin class %s. "
                    "The plugin class must inherit "
                    "ray._private.runtime_env.plugin.RuntimeEnvPlugin.",
                    plugin_class,
                )
                continue
            if not plugin_class.name:
                default_logger.warning(
                    "No valid name in runtime env plugin %s", plugin_class
                )
                continue
            if plugin_class.name in self.plugins:
                # First registration wins; later duplicates are ignored.
                default_logger.warning(
                    "The name of runtime env plugin %s conflicts with %s",
                    plugin_class,
                    self.plugins[plugin_class.name],
                )
                continue
            self.plugins[plugin_class.name] = plugin_class()

    def get_plugin(self, name: str):
        """Return the plugin instance registered under `name`, or None."""
        return self.plugins.get(name)
@DeveloperAPI @DeveloperAPI
class PluginCacheManager: class PluginCacheManager:
"""Manages a plugin and a cache for its local resources.""" """Manages a plugin and a cache for its local resources."""

View file

@ -10,9 +10,7 @@ import ray
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
from ray._private.runtime_env.conda import get_uri as get_conda_uri from ray._private.runtime_env.conda import get_uri as get_conda_uri
from ray._private.runtime_env.pip import get_uri as get_pip_uri from ray._private.runtime_env.pip import get_uri as get_pip_uri
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
from ray._private.utils import import_attr
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
from ray.core.generated.runtime_env_common_pb2 import ( from ray.core.generated.runtime_env_common_pb2 import (
RuntimeEnvConfig as ProtoRuntimeEnvConfig, RuntimeEnvConfig as ProtoRuntimeEnvConfig,
@ -85,9 +83,8 @@ def _parse_proto_plugin_runtime_env(
): ):
"""Parse plugin runtime env protobuf to runtime env dict.""" """Parse plugin runtime env protobuf to runtime env dict."""
if runtime_env.python_runtime_env.HasField("plugin_runtime_env"): if runtime_env.python_runtime_env.HasField("plugin_runtime_env"):
runtime_env_dict["plugins"] = dict()
for plugin in runtime_env.python_runtime_env.plugin_runtime_env.plugins: for plugin in runtime_env.python_runtime_env.plugin_runtime_env.plugins:
runtime_env_dict["plugins"][plugin.class_path] = json.loads(plugin.config) runtime_env_dict[plugin.class_path] = json.loads(plugin.config)
@PublicAPI(stability="beta") @PublicAPI(stability="beta")
@ -319,8 +316,12 @@ class RuntimeEnv(dict):
"_ray_release", "_ray_release",
"_ray_commit", "_ray_commit",
"_inject_current_ray", "_inject_current_ray",
"plugins",
"config", "config",
# TODO(SongGuyang): We add this because the test
# `test_experimental_package_github` sets a `docker`
# field, which is not supported. We should remove this
# entry together with that test.
"docker",
} }
extensions_fields: Set[str] = { extensions_fields: Set[str] = {
@ -363,19 +364,20 @@ class RuntimeEnv(dict):
if runtime_env.get("java_jars"): if runtime_env.get("java_jars"):
runtime_env["java_jars"] = runtime_env.get("java_jars") runtime_env["java_jars"] = runtime_env.get("java_jars")
self.update(runtime_env)
# Blindly trust that the runtime_env has already been validated. # Blindly trust that the runtime_env has already been validated.
# This is dangerous and should only be used internally (e.g., on the # This is dangerous and should only be used internally (e.g., on the
# deserialization codepath. # deserialization codepath.
if not _validate: if not _validate:
self.update(runtime_env)
return return
if runtime_env.get("conda") and runtime_env.get("pip"): if self.get("conda") and self.get("pip"):
raise ValueError( raise ValueError(
"The 'pip' field and 'conda' field of " "The 'pip' field and 'conda' field of "
"runtime_env cannot both be specified.\n" "runtime_env cannot both be specified.\n"
f"specified pip field: {runtime_env['pip']}\n" f"specified pip field: {self['pip']}\n"
f"specified conda field: {runtime_env['conda']}\n" f"specified conda field: {self['conda']}\n"
"To use pip with conda, please only set the 'conda' " "To use pip with conda, please only set the 'conda' "
"field, and specify your pip dependencies " "field, and specify your pip dependencies "
"within the conda YAML config dict: see " "within the conda YAML config dict: see "
@ -385,48 +387,21 @@ class RuntimeEnv(dict):
) )
for option, validate_fn in OPTION_TO_VALIDATION_FN.items(): for option, validate_fn in OPTION_TO_VALIDATION_FN.items():
option_val = runtime_env.get(option) option_val = self.get(option)
if option_val is not None: if option_val is not None:
del self[option]
self[option] = option_val self[option] = option_val
if "_ray_release" in runtime_env: if "_ray_commit" not in self:
self["_ray_release"] = runtime_env["_ray_release"]
if "_ray_commit" in runtime_env:
self["_ray_commit"] = runtime_env["_ray_commit"]
else:
if self.get("pip") or self.get("conda"): if self.get("pip") or self.get("conda"):
self["_ray_commit"] = ray.__commit__ self["_ray_commit"] = ray.__commit__
# Used for testing wheels that have not yet been merged into master. # Used for testing wheels that have not yet been merged into master.
# If this is set to True, then we do not inject Ray into the conda # If this is set to True, then we do not inject Ray into the conda
# or pip dependencies. # or pip dependencies.
if "_inject_current_ray" in runtime_env: if "_inject_current_ray" not in self:
self["_inject_current_ray"] = runtime_env["_inject_current_ray"] if "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
elif "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ: self["_inject_current_ray"] = True
self["_inject_current_ray"] = True
if "plugins" in runtime_env:
self["plugins"] = dict()
for class_path, plugin_field in runtime_env["plugins"].items():
plugin_class: RuntimeEnvPlugin = import_attr(class_path)
if not issubclass(plugin_class, RuntimeEnvPlugin):
# TODO(simon): move the inferface to public once ready.
raise TypeError(
f"{class_path} must be inherit from "
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
)
# TODO(simon): implement uri support.
_ = plugin_class.validate(runtime_env)
# Validation passed, add the entry to parsed runtime env.
self["plugins"][class_path] = plugin_field
unknown_fields = set(runtime_env.keys()) - RuntimeEnv.known_fields
if len(unknown_fields):
logger.warning(
"The following unknown entries in the runtime_env dictionary "
f"will be ignored: {unknown_fields}. If you intended to use "
"them as plugins, they must be nested in the `plugins` field."
)
# NOTE(architkulkarni): This allows worker caching code in C++ to check # NOTE(architkulkarni): This allows worker caching code in C++ to check
# if a runtime env is empty without deserializing it. This is a catch- # if a runtime env is empty without deserializing it. This is a catch-
@ -455,20 +430,17 @@ class RuntimeEnv(dict):
return plugin_uris return plugin_uris
def __setitem__(self, key: str, value: Any) -> None: def __setitem__(self, key: str, value: Any) -> None:
if key not in RuntimeEnv.known_fields: # TODO(SongGuyang): Validate the schemas of plugins by json schema.
logger.warning(
"The following unknown entries in the runtime_env dictionary "
f"will be ignored: {key}. If you intended to use "
"them as plugins, they must be nested in the `plugins` field."
)
return
res_value = value res_value = value
if key in OPTION_TO_VALIDATION_FN: if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
res_value = OPTION_TO_VALIDATION_FN[key](value) res_value = OPTION_TO_VALIDATION_FN[key](value)
if res_value is None: if res_value is None:
return return
return super().__setitem__(key, res_value) return super().__setitem__(key, res_value)
def set(self, name: str, value: Any) -> None:
    """Assign `value` to the field `name`.

    Equivalent to item assignment on this object, so the write goes
    through ``__setitem__``.
    """
    self[name] = value
@classmethod @classmethod
def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821 def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821
proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv()) proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())
@ -655,16 +627,11 @@ class RuntimeEnv(dict):
return None return None
return self["container"].get("run_options", []) return self["container"].get("run_options", [])
def has_plugins(self) -> bool:
if self.get("plugins"):
return True
return False
def plugins(self) -> List[Tuple[str, str]]: def plugins(self) -> List[Tuple[str, str]]:
result = list() result = list()
if self.has_plugins(): for key, value in self.items():
for class_path, plugin_field in self["plugins"].items(): if key not in self.known_fields:
result.append((class_path, json.dumps(plugin_field, sort_keys=True))) result.append((key, json.dumps(value, sort_keys=True)))
return result return result
def _build_proto_pip_runtime_env(self, runtime_env: ProtoRuntimeEnv): def _build_proto_pip_runtime_env(self, runtime_env: ProtoRuntimeEnv):
@ -720,8 +687,7 @@ class RuntimeEnv(dict):
def _build_proto_plugin_runtime_env(self, runtime_env: ProtoRuntimeEnv): def _build_proto_plugin_runtime_env(self, runtime_env: ProtoRuntimeEnv):
"""Construct plugin runtime env protobuf from runtime env dict.""" """Construct plugin runtime env protobuf from runtime env dict."""
if self.has_plugins(): for class_path, plugin_field in self.plugins():
for class_path, plugin_field in self.plugins(): plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add()
plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add() plugin.class_path = class_path
plugin.class_path = class_path plugin.config = plugin_field
plugin.config = plugin_field

View file

@ -923,3 +923,13 @@ def start_http_proxy(request):
if proxy: if proxy:
proxy.terminate() proxy.terminate()
proxy.wait() proxy.wait()
@pytest.fixture
def set_runtime_env_plugins(request):
    """Set the ``RAY_RUNTIME_ENV_PLUGINS`` env var for the duration of a test.

    The value comes from indirect parametrization (``request.param``) and is
    expected to be a comma-separated list of plugin class paths. Without a
    parameter the variable is set to an empty string.

    Yields:
        The string written to the environment variable.
    """
    # An empty default keeps the variable falsy rather than handing a
    # non-path placeholder to whatever parses it.
    runtime_env_plugins = getattr(request, "param", "")
    # Remember any pre-existing value so teardown does not clobber it.
    previous_value = os.environ.get("RAY_RUNTIME_ENV_PLUGINS")
    os.environ["RAY_RUNTIME_ENV_PLUGINS"] = runtime_env_plugins
    try:
        yield runtime_env_plugins
    finally:
        if previous_value is None:
            # `pop` with a default tolerates the test having removed it.
            os.environ.pop("RAY_RUNTIME_ENV_PLUGINS", None)
        else:
            os.environ["RAY_RUNTIME_ENV_PLUGINS"] = previous_value

View file

@ -19,9 +19,13 @@ from ray._private.runtime_env.plugin import RuntimeEnvPlugin
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH = ( MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH = (
"ray.tests.test_placement_group_4.MockWorkerStartupSlowlyPlugin" # noqa "ray.tests.test_placement_group_4.MockWorkerStartupSlowlyPlugin" # noqa
) )
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME = "MockWorkerStartupSlowlyPlugin"
class MockWorkerStartupSlowlyPlugin(RuntimeEnvPlugin): class MockWorkerStartupSlowlyPlugin(RuntimeEnvPlugin):
name = MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME
def validate(runtime_env_dict: dict) -> str: def validate(runtime_env_dict: dict) -> str:
return "success" return "success"
@ -110,7 +114,16 @@ def test_remove_placement_group(ray_start_cluster, connect_to_client):
ray.get(task_ref) ray.get(task_ref)
def test_remove_placement_group_worker_startup_slowly(ray_start_cluster): @pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH,
],
indirect=True,
)
def test_remove_placement_group_worker_startup_slowly(
set_runtime_env_plugins, ray_start_cluster
):
cluster = ray_start_cluster cluster = ray_start_cluster
cluster.add_node(num_cpus=4) cluster.add_node(num_cpus=4)
ray.init(address=cluster.address) ray.init(address=cluster.address)
@ -134,7 +147,7 @@ def test_remove_placement_group_worker_startup_slowly(ray_start_cluster):
# runtime env to mock worker start up slowly. # runtime env to mock worker start up slowly.
task_ref = long_running_task.options( task_ref = long_running_task.options(
placement_group=placement_group, placement_group=placement_group,
runtime_env={"plugins": {MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH: {}}}, runtime_env={MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME: {}},
).remote() ).remote()
a = A.options(placement_group=placement_group).remote() a = A.options(placement_group=placement_group).remote()
assert ray.get(a.f.remote()) == 3 assert ray.get(a.f.remote()) == 3

View file

@ -559,6 +559,7 @@ def test_to_make_ensure_runtime_env_api(start_cluster):
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env.MyPlugin" MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env.MyPlugin"
MY_PLUGIN_NAME = "MyPlugin"
success_retry_number = 3 success_retry_number = 3
runtime_env_retry_times = 0 runtime_env_retry_times = 0
@ -566,9 +567,12 @@ runtime_env_retry_times = 0
# This plugin can make runtime env creation failed before the retry times # This plugin can make runtime env creation failed before the retry times
# increased to `success_retry_number`. # increased to `success_retry_number`.
class MyPlugin(RuntimeEnvPlugin): class MyPlugin(RuntimeEnvPlugin):
name = MY_PLUGIN_NAME
@staticmethod @staticmethod
def validate(runtime_env_dict: dict) -> str: def validate(runtime_env_dict: dict) -> str:
return runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH] return runtime_env_dict[MY_PLUGIN_NAME]
@staticmethod @staticmethod
def modify_context( def modify_context(
@ -592,7 +596,16 @@ class MyPlugin(RuntimeEnvPlugin):
], ],
indirect=True, indirect=True,
) )
def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular): @pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MY_PLUGIN_CLASS_PATH,
],
indirect=True,
)
def test_runtime_env_retry(
set_runtime_env_retry_times, set_runtime_env_plugins, ray_start_regular
):
@ray.remote @ray.remote
def f(): def f():
return "ok" return "ok"
@ -601,9 +614,7 @@ def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
if runtime_env_retry_times >= success_retry_number: if runtime_env_retry_times >= success_retry_number:
# Enough retry times # Enough retry times
output = ray.get( output = ray.get(
f.options( f.options(runtime_env={MY_PLUGIN_NAME: {"key": "value"}}).remote()
runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: {"key": "value"}}}
).remote()
) )
assert output == "ok" assert output == "ok"
else: else:
@ -611,11 +622,7 @@ def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
with pytest.raises( with pytest.raises(
RuntimeEnvSetupError, match=f"Fault injection {runtime_env_retry_times}" RuntimeEnvSetupError, match=f"Fault injection {runtime_env_retry_times}"
): ):
ray.get( ray.get(f.options(runtime_env={MY_PLUGIN_NAME: {"key": "value"}}).remote())
f.options(
runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: {"key": "value"}}}
).remote()
)
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -13,14 +13,16 @@ from ray._private.test_utils import test_external_redis, wait_for_condition
from ray.exceptions import RuntimeEnvSetupError from ray.exceptions import RuntimeEnvSetupError
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin" MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin"
MY_PLUGIN_NAME = "MyPlugin"
class MyPlugin(RuntimeEnvPlugin): class MyPlugin(RuntimeEnvPlugin):
name = MY_PLUGIN_NAME
env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY" env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY"
@staticmethod @staticmethod
def validate(runtime_env_dict: dict) -> str: def validate(runtime_env_dict: dict) -> str:
value = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH] value = runtime_env_dict[MY_PLUGIN_NAME]
if value == "fail": if value == "fail":
raise ValueError("not allowed") raise ValueError("not allowed")
return value return value
@ -42,7 +44,14 @@ class MyPlugin(RuntimeEnvPlugin):
) )
def test_simple_env_modification_plugin(ray_start_regular): @pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MY_PLUGIN_CLASS_PATH,
],
indirect=True,
)
def test_simple_env_modification_plugin(set_runtime_env_plugins, ray_start_regular):
_, tmp_file_path = tempfile.mkstemp() _, tmp_file_path = tempfile.mkstemp()
@ray.remote @ray.remote
@ -57,21 +66,19 @@ def test_simple_env_modification_plugin(ray_start_regular):
"nice": psutil.Process().nice(), "nice": psutil.Process().nice(),
} }
with pytest.raises(ValueError, match="not allowed"): with pytest.raises(RuntimeEnvSetupError, match="not allowed"):
f.options(runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: "fail"}}).remote() ray.get(f.options(runtime_env={MY_PLUGIN_NAME: "fail"}).remote())
if os.name != "nt": if os.name != "nt":
output = ray.get( output = ray.get(
f.options( f.options(
runtime_env={ runtime_env={
"plugins": { MY_PLUGIN_NAME: {
MY_PLUGIN_CLASS_PATH: { "env_value": 42,
"env_value": 42, "tmp_file": tmp_file_path,
"tmp_file": tmp_file_path, "tmp_content": "hello",
"tmp_content": "hello", # See https://en.wikipedia.org/wiki/Nice_(Unix)
# See https://en.wikipedia.org/wiki/Nice_(Unix) "prefix_command": "nice -n 19",
"prefix_command": "nice -n 19",
}
} }
} }
).remote() ).remote()
@ -81,11 +88,13 @@ def test_simple_env_modification_plugin(ray_start_regular):
MY_PLUGIN_FOR_HANG_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPluginForHang" MY_PLUGIN_FOR_HANG_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPluginForHang"
MY_PLUGIN_FOR_HANG_NAME = "MyPluginForHang"
my_plugin_setup_times = 0 my_plugin_setup_times = 0
# This plugin will hang when first setup, second setup will ok # This plugin will hang when first setup, second setup will ok
class MyPluginForHang(RuntimeEnvPlugin): class MyPluginForHang(RuntimeEnvPlugin):
name = MY_PLUGIN_FOR_HANG_NAME
env_key = "MY_PLUGIN_FOR_HANG_TEST_ENVIRONMENT_KEY" env_key = "MY_PLUGIN_FOR_HANG_TEST_ENVIRONMENT_KEY"
@staticmethod @staticmethod
@ -112,7 +121,14 @@ class MyPluginForHang(RuntimeEnvPlugin):
ctx.env_vars[MyPluginForHang.env_key] = str(my_plugin_setup_times) ctx.env_vars[MyPluginForHang.env_key] = str(my_plugin_setup_times)
def test_plugin_hang(ray_start_regular): @pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MY_PLUGIN_FOR_HANG_CLASS_PATH,
],
indirect=True,
)
def test_plugin_hang(set_runtime_env_plugins, ray_start_regular):
env_key = MyPluginForHang.env_key env_key = MyPluginForHang.env_key
@ray.remote(num_cpus=0.1) @ray.remote(num_cpus=0.1)
@ -122,11 +138,9 @@ def test_plugin_hang(ray_start_regular):
refs = [ refs = [
f.options( f.options(
# Avoid hitting the cache of runtime_env # Avoid hitting the cache of runtime_env
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f1"}}} runtime_env={MY_PLUGIN_FOR_HANG_NAME: {"name": "f1"}}
).remote(),
f.options(
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f2"}}}
).remote(), ).remote(),
f.options(runtime_env={MY_PLUGIN_FOR_HANG_NAME: {"name": "f2"}}).remote(),
] ]
def condition(): def condition():
@ -145,19 +159,26 @@ def test_plugin_hang(ray_start_regular):
DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin" DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin"
DUMMY_PLUGIN_NAME = "DummyPlugin"
HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin" HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin"
HANG_PLUGIN_NAME = "HangPlugin"
DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = ( DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = (
"ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin" "ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin"
) )
DISABLE_TIMEOUT_PLUGIN_NAME = "test_plugin_timeout"
class DummyPlugin(RuntimeEnvPlugin): class DummyPlugin(RuntimeEnvPlugin):
name = DUMMY_PLUGIN_NAME
@staticmethod @staticmethod
def validate(runtime_env_dict: dict) -> str: def validate(runtime_env_dict: dict) -> str:
return 1 return 1
class HangPlugin(DummyPlugin): class HangPlugin(DummyPlugin):
name = HANG_PLUGIN_NAME
def create( def create(
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821 self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
) -> float: ) -> float:
@ -165,14 +186,25 @@ class HangPlugin(DummyPlugin):
class DiasbleTimeoutPlugin(DummyPlugin): class DiasbleTimeoutPlugin(DummyPlugin):
name = DISABLE_TIMEOUT_PLUGIN_NAME
def create( def create(
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821 self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
) -> float: ) -> float:
sleep(10) sleep(10)
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
f"{DUMMY_PLUGIN_CLASS_PATH},"
f"{HANG_PLUGIN_CLASS_PATH},"
f"{DISABLE_TIMEOUT_PLUGIN_CLASS_PATH}",
],
indirect=True,
)
@pytest.mark.skipif(test_external_redis(), reason="Failing in redis mode.") @pytest.mark.skipif(test_external_redis(), reason="Failing in redis mode.")
def test_plugin_timeout(start_cluster): def test_plugin_timeout(set_runtime_env_plugins, start_cluster):
@ray.remote(num_cpus=0.1) @ray.remote(num_cpus=0.1)
def f(): def f():
return True return True
@ -180,20 +212,14 @@ def test_plugin_timeout(start_cluster):
refs = [ refs = [
f.options( f.options(
runtime_env={ runtime_env={
"plugins": { HANG_PLUGIN_NAME: {"name": "f1"},
HANG_PLUGIN_CLASS_PATH: {"name": "f1"},
},
"config": {"setup_timeout_seconds": 10}, "config": {"setup_timeout_seconds": 10},
} }
).remote(), ).remote(),
f.options( f.options(runtime_env={DUMMY_PLUGIN_NAME: {"name": "f2"}}).remote(),
runtime_env={"plugins": {DUMMY_PLUGIN_CLASS_PATH: {"name": "f2"}}}
).remote(),
f.options( f.options(
runtime_env={ runtime_env={
"plugins": { HANG_PLUGIN_NAME: {"name": "f3"},
HANG_PLUGIN_CLASS_PATH: {"name": "f3"},
},
"config": {"setup_timeout_seconds": -1}, "config": {"setup_timeout_seconds": -1},
} }
).remote(), ).remote(),