[runtime env] plugin refactor[1/n] (#26077)

This commit is contained in:
Guyang Song 2022-06-28 14:09:05 +08:00 committed by GitHub
parent 876fef0fcf
commit 58bfad84d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 178 additions and 110 deletions

View file

@ -20,10 +20,10 @@ from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.java_jars import JavaJarsPlugin
from ray._private.runtime_env.pip import PipPlugin
from ray._private.runtime_env.plugin import PluginCacheManager
from ray._private.runtime_env.plugin import RuntimeEnvPluginManager
from ray._private.runtime_env.py_modules import PyModulesPlugin
from ray._private.runtime_env.uri_cache import URICache
from ray._private.runtime_env.working_dir import WorkingDirPlugin
from ray._private.utils import import_attr
from ray.core.generated import (
agent_manager_pb2,
runtime_env_agent_pb2,
@ -224,6 +224,7 @@ class RuntimeEnvAgent(
self.unused_uris_processor,
self.unused_runtime_env_processor,
)
self._runtime_env_plugin_manager = RuntimeEnvPluginManager()
self._logger = default_logger
@ -295,13 +296,13 @@ class RuntimeEnvAgent(
def setup_plugins():
# Run setup function from all the plugins
for plugin_class_path, config in runtime_env.plugins():
per_job_logger.debug(
f"Setting up runtime env plugin {plugin_class_path}"
)
plugin_class = import_attr(plugin_class_path)
plugin = plugin_class()
for name, config in runtime_env.plugins():
per_job_logger.debug(f"Setting up runtime env plugin {name}")
plugin = self._runtime_env_plugin_manager.get_plugin(name)
if plugin is None:
raise RuntimeError(f"runtime env plugin {name} not found.")
# TODO(architkulkarni): implement uri support
plugin.validate(runtime_env)
plugin.create("uri not implemented", json.loads(config), context)
plugin.modify_context(
"uri not implemented",

View file

@ -1,2 +1,5 @@
# Env var set by job manager to pass runtime env and metadata to subprocess
RAY_JOB_CONFIG_JSON_ENV_VAR = "RAY_JOB_CONFIG_JSON_ENV_VAR"
# The plugins which should be loaded when ray cluster starts.
RAY_RUNTIME_ENV_PLUGINS_ENV_VAR = "RAY_RUNTIME_ENV_PLUGINS"

View file

@ -1,10 +1,13 @@
import logging
import os
from abc import ABC
from typing import List
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.uri_cache import URICache
from ray._private.runtime_env.constants import RAY_RUNTIME_ENV_PLUGINS_ENV_VAR
from ray.util.annotations import DeveloperAPI
from ray._private.utils import import_attr
default_logger = logging.getLogger(__name__)
@ -13,7 +16,7 @@ default_logger = logging.getLogger(__name__)
class RuntimeEnvPlugin(ABC):
"""Abstract base class for runtime environment plugins."""
name: str
name: str = None
@staticmethod
def validate(runtime_env_dict: dict) -> str:
@ -88,6 +91,45 @@ class RuntimeEnvPlugin(ABC):
return 0
class RuntimeEnvPluginManager:
"""This mananger is used to load plugins in runtime env agent."""
def __init__(self):
self.plugins = {}
plugins_config = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
if plugins_config:
self.load_plugins(plugins_config.split(","))
def load_plugins(self, plugin_classes: List[str]):
"""Load runtime env plugins"""
for plugin_class_path in plugin_classes:
plugin_class = import_attr(plugin_class_path)
if not issubclass(plugin_class, RuntimeEnvPlugin):
default_logger.warning(
"Invalid runtime env plugin class %s. "
"The plugin class must inherit "
"ray._private.runtime_env.plugin.RuntimeEnvPlugin.",
plugin_class,
)
continue
if not plugin_class.name:
default_logger.warning(
"No valid name in runtime env plugin %s", plugin_class
)
continue
if plugin_class.name in self.plugins:
default_logger.warning(
"The name of runtime env plugin %s conflicts with %s",
plugin_class,
self.plugins[plugin_class.name],
)
continue
self.plugins[plugin_class.name] = plugin_class()
def get_plugin(self, name: str):
return self.plugins.get(name)
@DeveloperAPI
class PluginCacheManager:
"""Manages a plugin and a cache for its local resources."""

View file

@ -10,9 +10,7 @@ import ray
from ray._private.ray_constants import DEFAULT_RUNTIME_ENV_TIMEOUT_SECONDS
from ray._private.runtime_env.conda import get_uri as get_conda_uri
from ray._private.runtime_env.pip import get_uri as get_pip_uri
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
from ray._private.utils import import_attr
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
from ray.core.generated.runtime_env_common_pb2 import (
RuntimeEnvConfig as ProtoRuntimeEnvConfig,
@ -85,9 +83,8 @@ def _parse_proto_plugin_runtime_env(
):
"""Parse plugin runtime env protobuf to runtime env dict."""
if runtime_env.python_runtime_env.HasField("plugin_runtime_env"):
runtime_env_dict["plugins"] = dict()
for plugin in runtime_env.python_runtime_env.plugin_runtime_env.plugins:
runtime_env_dict["plugins"][plugin.class_path] = json.loads(plugin.config)
runtime_env_dict[plugin.class_path] = json.loads(plugin.config)
@PublicAPI(stability="beta")
@ -319,8 +316,12 @@ class RuntimeEnv(dict):
"_ray_release",
"_ray_commit",
"_inject_current_ray",
"plugins",
"config",
# TODO(SongGuyang): We add this because the test
# `test_experimental_package_github` set a `docker`
# field which is not supported. We should remove it
# with the test.
"docker",
}
extensions_fields: Set[str] = {
@ -363,19 +364,20 @@ class RuntimeEnv(dict):
if runtime_env.get("java_jars"):
runtime_env["java_jars"] = runtime_env.get("java_jars")
self.update(runtime_env)
# Blindly trust that the runtime_env has already been validated.
# This is dangerous and should only be used internally (e.g., on the
# deserialization codepath.
if not _validate:
self.update(runtime_env)
return
if runtime_env.get("conda") and runtime_env.get("pip"):
if self.get("conda") and self.get("pip"):
raise ValueError(
"The 'pip' field and 'conda' field of "
"runtime_env cannot both be specified.\n"
f"specified pip field: {runtime_env['pip']}\n"
f"specified conda field: {runtime_env['conda']}\n"
f"specified pip field: {self['pip']}\n"
f"specified conda field: {self['conda']}\n"
"To use pip with conda, please only set the 'conda' "
"field, and specify your pip dependencies "
"within the conda YAML config dict: see "
@ -385,48 +387,21 @@ class RuntimeEnv(dict):
)
for option, validate_fn in OPTION_TO_VALIDATION_FN.items():
option_val = runtime_env.get(option)
option_val = self.get(option)
if option_val is not None:
del self[option]
self[option] = option_val
if "_ray_release" in runtime_env:
self["_ray_release"] = runtime_env["_ray_release"]
if "_ray_commit" in runtime_env:
self["_ray_commit"] = runtime_env["_ray_commit"]
else:
if "_ray_commit" not in self:
if self.get("pip") or self.get("conda"):
self["_ray_commit"] = ray.__commit__
# Used for testing wheels that have not yet been merged into master.
# If this is set to True, then we do not inject Ray into the conda
# or pip dependencies.
if "_inject_current_ray" in runtime_env:
self["_inject_current_ray"] = runtime_env["_inject_current_ray"]
elif "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
self["_inject_current_ray"] = True
if "plugins" in runtime_env:
self["plugins"] = dict()
for class_path, plugin_field in runtime_env["plugins"].items():
plugin_class: RuntimeEnvPlugin = import_attr(class_path)
if not issubclass(plugin_class, RuntimeEnvPlugin):
# TODO(simon): move the inferface to public once ready.
raise TypeError(
f"{class_path} must be inherit from "
"ray._private.runtime_env.plugin.RuntimeEnvPlugin."
)
# TODO(simon): implement uri support.
_ = plugin_class.validate(runtime_env)
# Validation passed, add the entry to parsed runtime env.
self["plugins"][class_path] = plugin_field
unknown_fields = set(runtime_env.keys()) - RuntimeEnv.known_fields
if len(unknown_fields):
logger.warning(
"The following unknown entries in the runtime_env dictionary "
f"will be ignored: {unknown_fields}. If you intended to use "
"them as plugins, they must be nested in the `plugins` field."
)
if "_inject_current_ray" not in self:
if "RAY_RUNTIME_ENV_LOCAL_DEV_MODE" in os.environ:
self["_inject_current_ray"] = True
# NOTE(architkulkarni): This allows worker caching code in C++ to check
# if a runtime env is empty without deserializing it. This is a catch-
@ -455,20 +430,17 @@ class RuntimeEnv(dict):
return plugin_uris
def __setitem__(self, key: str, value: Any) -> None:
if key not in RuntimeEnv.known_fields:
logger.warning(
"The following unknown entries in the runtime_env dictionary "
f"will be ignored: {key}. If you intended to use "
"them as plugins, they must be nested in the `plugins` field."
)
return
# TODO(SongGuyang): Validate the schemas of plugins by json schema.
res_value = value
if key in OPTION_TO_VALIDATION_FN:
if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
res_value = OPTION_TO_VALIDATION_FN[key](value)
if res_value is None:
return
return super().__setitem__(key, res_value)
def set(self, name: str, value: Any) -> None:
self.__setitem__(name, value)
@classmethod
def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821
proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())
@ -655,16 +627,11 @@ class RuntimeEnv(dict):
return None
return self["container"].get("run_options", [])
def has_plugins(self) -> bool:
if self.get("plugins"):
return True
return False
def plugins(self) -> List[Tuple[str, str]]:
result = list()
if self.has_plugins():
for class_path, plugin_field in self["plugins"].items():
result.append((class_path, json.dumps(plugin_field, sort_keys=True)))
for key, value in self.items():
if key not in self.known_fields:
result.append((key, json.dumps(value, sort_keys=True)))
return result
def _build_proto_pip_runtime_env(self, runtime_env: ProtoRuntimeEnv):
@ -720,8 +687,7 @@ class RuntimeEnv(dict):
def _build_proto_plugin_runtime_env(self, runtime_env: ProtoRuntimeEnv):
"""Construct plugin runtime env protobuf from runtime env dict."""
if self.has_plugins():
for class_path, plugin_field in self.plugins():
plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add()
plugin.class_path = class_path
plugin.config = plugin_field
for class_path, plugin_field in self.plugins():
plugin = runtime_env.python_runtime_env.plugin_runtime_env.plugins.add()
plugin.class_path = class_path
plugin.config = plugin_field

View file

@ -923,3 +923,13 @@ def start_http_proxy(request):
if proxy:
proxy.terminate()
proxy.wait()
@pytest.fixture
def set_runtime_env_plugins(request):
runtime_env_plugins = getattr(request, "param", "0")
try:
os.environ["RAY_RUNTIME_ENV_PLUGINS"] = runtime_env_plugins
yield runtime_env_plugins
finally:
del os.environ["RAY_RUNTIME_ENV_PLUGINS"]

View file

@ -19,9 +19,13 @@ from ray._private.runtime_env.plugin import RuntimeEnvPlugin
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH = (
"ray.tests.test_placement_group_4.MockWorkerStartupSlowlyPlugin" # noqa
)
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME = "MockWorkerStartupSlowlyPlugin"
class MockWorkerStartupSlowlyPlugin(RuntimeEnvPlugin):
name = MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME
def validate(runtime_env_dict: dict) -> str:
return "success"
@ -110,7 +114,16 @@ def test_remove_placement_group(ray_start_cluster, connect_to_client):
ray.get(task_ref)
def test_remove_placement_group_worker_startup_slowly(ray_start_cluster):
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH,
],
indirect=True,
)
def test_remove_placement_group_worker_startup_slowly(
set_runtime_env_plugins, ray_start_cluster
):
cluster = ray_start_cluster
cluster.add_node(num_cpus=4)
ray.init(address=cluster.address)
@ -134,7 +147,7 @@ def test_remove_placement_group_worker_startup_slowly(ray_start_cluster):
# runtime env to mock worker start up slowly.
task_ref = long_running_task.options(
placement_group=placement_group,
runtime_env={"plugins": {MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_CLASS_PATH: {}}},
runtime_env={MOCK_WORKER_STARTUP_SLOWLY_PLUGIN_NAME: {}},
).remote()
a = A.options(placement_group=placement_group).remote()
assert ray.get(a.f.remote()) == 3

View file

@ -559,6 +559,7 @@ def test_to_make_ensure_runtime_env_api(start_cluster):
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env.MyPlugin"
MY_PLUGIN_NAME = "MyPlugin"
success_retry_number = 3
runtime_env_retry_times = 0
@ -566,9 +567,12 @@ runtime_env_retry_times = 0
# This plugin can make runtime env creation failed before the retry times
# increased to `success_retry_number`.
class MyPlugin(RuntimeEnvPlugin):
name = MY_PLUGIN_NAME
@staticmethod
def validate(runtime_env_dict: dict) -> str:
return runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH]
return runtime_env_dict[MY_PLUGIN_NAME]
@staticmethod
def modify_context(
@ -592,7 +596,16 @@ class MyPlugin(RuntimeEnvPlugin):
],
indirect=True,
)
def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MY_PLUGIN_CLASS_PATH,
],
indirect=True,
)
def test_runtime_env_retry(
set_runtime_env_retry_times, set_runtime_env_plugins, ray_start_regular
):
@ray.remote
def f():
return "ok"
@ -601,9 +614,7 @@ def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
if runtime_env_retry_times >= success_retry_number:
# Enough retry times
output = ray.get(
f.options(
runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: {"key": "value"}}}
).remote()
f.options(runtime_env={MY_PLUGIN_NAME: {"key": "value"}}).remote()
)
assert output == "ok"
else:
@ -611,11 +622,7 @@ def test_runtime_env_retry(set_runtime_env_retry_times, ray_start_regular):
with pytest.raises(
RuntimeEnvSetupError, match=f"Fault injection {runtime_env_retry_times}"
):
ray.get(
f.options(
runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: {"key": "value"}}}
).remote()
)
ray.get(f.options(runtime_env={MY_PLUGIN_NAME: {"key": "value"}}).remote())
@pytest.mark.parametrize(

View file

@ -13,14 +13,16 @@ from ray._private.test_utils import test_external_redis, wait_for_condition
from ray.exceptions import RuntimeEnvSetupError
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin"
MY_PLUGIN_NAME = "MyPlugin"
class MyPlugin(RuntimeEnvPlugin):
name = MY_PLUGIN_NAME
env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY"
@staticmethod
def validate(runtime_env_dict: dict) -> str:
value = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH]
value = runtime_env_dict[MY_PLUGIN_NAME]
if value == "fail":
raise ValueError("not allowed")
return value
@ -42,7 +44,14 @@ class MyPlugin(RuntimeEnvPlugin):
)
def test_simple_env_modification_plugin(ray_start_regular):
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MY_PLUGIN_CLASS_PATH,
],
indirect=True,
)
def test_simple_env_modification_plugin(set_runtime_env_plugins, ray_start_regular):
_, tmp_file_path = tempfile.mkstemp()
@ray.remote
@ -57,21 +66,19 @@ def test_simple_env_modification_plugin(ray_start_regular):
"nice": psutil.Process().nice(),
}
with pytest.raises(ValueError, match="not allowed"):
f.options(runtime_env={"plugins": {MY_PLUGIN_CLASS_PATH: "fail"}}).remote()
with pytest.raises(RuntimeEnvSetupError, match="not allowed"):
ray.get(f.options(runtime_env={MY_PLUGIN_NAME: "fail"}).remote())
if os.name != "nt":
output = ray.get(
f.options(
runtime_env={
"plugins": {
MY_PLUGIN_CLASS_PATH: {
"env_value": 42,
"tmp_file": tmp_file_path,
"tmp_content": "hello",
# See https://en.wikipedia.org/wiki/Nice_(Unix)
"prefix_command": "nice -n 19",
}
MY_PLUGIN_NAME: {
"env_value": 42,
"tmp_file": tmp_file_path,
"tmp_content": "hello",
# See https://en.wikipedia.org/wiki/Nice_(Unix)
"prefix_command": "nice -n 19",
}
}
).remote()
@ -81,11 +88,13 @@ def test_simple_env_modification_plugin(ray_start_regular):
MY_PLUGIN_FOR_HANG_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPluginForHang"
MY_PLUGIN_FOR_HANG_NAME = "MyPluginForHang"
my_plugin_setup_times = 0
# This plugin will hang when first setup, second setup will ok
class MyPluginForHang(RuntimeEnvPlugin):
name = MY_PLUGIN_FOR_HANG_NAME
env_key = "MY_PLUGIN_FOR_HANG_TEST_ENVIRONMENT_KEY"
@staticmethod
@ -112,7 +121,14 @@ class MyPluginForHang(RuntimeEnvPlugin):
ctx.env_vars[MyPluginForHang.env_key] = str(my_plugin_setup_times)
def test_plugin_hang(ray_start_regular):
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
MY_PLUGIN_FOR_HANG_CLASS_PATH,
],
indirect=True,
)
def test_plugin_hang(set_runtime_env_plugins, ray_start_regular):
env_key = MyPluginForHang.env_key
@ray.remote(num_cpus=0.1)
@ -122,11 +138,9 @@ def test_plugin_hang(ray_start_regular):
refs = [
f.options(
# Avoid hitting the cache of runtime_env
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f1"}}}
).remote(),
f.options(
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f2"}}}
runtime_env={MY_PLUGIN_FOR_HANG_NAME: {"name": "f1"}}
).remote(),
f.options(runtime_env={MY_PLUGIN_FOR_HANG_NAME: {"name": "f2"}}).remote(),
]
def condition():
@ -145,19 +159,26 @@ def test_plugin_hang(ray_start_regular):
DUMMY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.DummyPlugin"
DUMMY_PLUGIN_NAME = "DummyPlugin"
HANG_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.HangPlugin"
HANG_PLUGIN_NAME = "HangPlugin"
DISABLE_TIMEOUT_PLUGIN_CLASS_PATH = (
"ray.tests.test_runtime_env_plugin.DiasbleTimeoutPlugin"
)
DISABLE_TIMEOUT_PLUGIN_NAME = "test_plugin_timeout"
class DummyPlugin(RuntimeEnvPlugin):
name = DUMMY_PLUGIN_NAME
@staticmethod
def validate(runtime_env_dict: dict) -> str:
return 1
class HangPlugin(DummyPlugin):
name = HANG_PLUGIN_NAME
def create(
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
) -> float:
@ -165,14 +186,25 @@ class HangPlugin(DummyPlugin):
class DiasbleTimeoutPlugin(DummyPlugin):
name = DISABLE_TIMEOUT_PLUGIN_NAME
def create(
self, uri: str, runtime_env: "RuntimeEnv", ctx: RuntimeEnvContext # noqa: F821
) -> float:
sleep(10)
@pytest.mark.parametrize(
"set_runtime_env_plugins",
[
f"{DUMMY_PLUGIN_CLASS_PATH},"
f"{HANG_PLUGIN_CLASS_PATH},"
f"{DISABLE_TIMEOUT_PLUGIN_CLASS_PATH}",
],
indirect=True,
)
@pytest.mark.skipif(test_external_redis(), reason="Failing in redis mode.")
def test_plugin_timeout(start_cluster):
def test_plugin_timeout(set_runtime_env_plugins, start_cluster):
@ray.remote(num_cpus=0.1)
def f():
return True
@ -180,20 +212,14 @@ def test_plugin_timeout(start_cluster):
refs = [
f.options(
runtime_env={
"plugins": {
HANG_PLUGIN_CLASS_PATH: {"name": "f1"},
},
HANG_PLUGIN_NAME: {"name": "f1"},
"config": {"setup_timeout_seconds": 10},
}
).remote(),
f.options(
runtime_env={"plugins": {DUMMY_PLUGIN_CLASS_PATH: {"name": "f2"}}}
).remote(),
f.options(runtime_env={DUMMY_PLUGIN_NAME: {"name": "f2"}}).remote(),
f.options(
runtime_env={
"plugins": {
HANG_PLUGIN_CLASS_PATH: {"name": "f3"},
},
HANG_PLUGIN_NAME: {"name": "f3"},
"config": {"setup_timeout_seconds": -1},
}
).remote(),