[Runtime Env] Implement basic runtime env plugin mechanism (#19044)

This commit is contained in:
Simon Mo 2021-10-01 17:22:54 -07:00 committed by GitHub
parent cac6f9d75c
commit 9b2a368c8c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 191 additions and 1 deletions

View file

@ -182,6 +182,7 @@ test_python() {
-python/ray/tests:test_ray_init # test_redis_port() seems to fail here, but pass in isolation
-python/ray/tests:test_resource_demand_scheduler
-python/ray/tests:test_reference_counting # too flaky 9/25/21
-python/ray/tests:test_runtime_env_plugin # runtime_env not supported on Windows
-python/ray/tests:test_runtime_env_env_vars # runtime_env not supported on Windows
-python/ray/tests:test_runtime_env_complicated # conda install slow leading to timeout
-python/ray/tests:test_stress # timeout

View file

@ -6,6 +6,7 @@ import logging
import os
import time
from typing import Dict, Set
from ray._private.utils import import_attr
from ray.core.generated import runtime_env_agent_pb2
from ray.core.generated import runtime_env_agent_pb2_grpc
@ -98,6 +99,15 @@ class RuntimeEnvAgent(dashboard_utils.DashboardAgentModule,
self._working_dir_uri_to_envs[uri].add(
serialized_runtime_env)
# Run setup function from all the plugins
for plugin_class_path in runtime_env.get("plugins", {}).keys():
plugin_class = import_attr(plugin_class_path)
# TODO(simon): implement uri support
plugin_class.create("uri not implemented", runtime_env,
context)
plugin_class.modify_context("uri not implemented",
runtime_env, context)
return context
loop = asyncio.get_event_loop()

View file

@ -4,9 +4,12 @@ import os
import sys
from typing import Dict, List, Optional
from ray.util.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
@DeveloperAPI
class RuntimeEnvContext:
"""A context used to describe the created runtime env."""

View file

@ -0,0 +1,70 @@
from abc import ABC, abstractstaticmethod
from ray._private.runtime_env.context import RuntimeEnvContext
from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class RuntimeEnvPlugin(ABC):
@abstractstaticmethod
def validate(runtime_env_dict: dict) -> str:
"""Validate user entry and returns a URI uniquely describing resource.
This method will be called at ``f.options(runtime_env=...)`` or
``ray.init(runtime_env=...)`` time and it should check the runtime env
dictionary for any errors. For example, it can raise "TypeError:
expected string for "conda" field".
Args:
runtime_env_dict(dict): the entire dictionary passed in by user.
Returns:
uri(str): a URI uniquely describing this resource (e.g., a hash of
the conda spec).
"""
raise NotImplementedError()
def create(uri: str, runtime_env_dict: dict,
ctx: RuntimeEnvContext) -> float:
"""Create and install the runtime environment.
Gets called in the runtime env agent at install time. The URI can be
used as a caching mechanism.
Args:
uri(str): a URI uniquely describing this resource.
runtime_env_dict(dict): the entire dictionary passed in by user.
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
Returns:
the disk space taken up by this plugin installation for this
environment. e.g. for working_dir, this downloads the files to the
local node.
"""
return 0
def modify_context(uri: str, runtime_env_dict: dict,
ctx: RuntimeEnvContext) -> None:
"""Modify context to change worker startup behavior.
For example, you can use this to preprend "cd <dir>" command to worker
startup, or add new environment variables.
Args:
uri(str): a URI uniquely describing this resource.
runtime_env_dict(dict): the entire dictionary passed in by user.
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
"""
return
def delete(uri: str, ctx: RuntimeEnvContext) -> float:
"""Delete the the runtime environment given uri.
Args:
uri(str): a URI uniquely describing this resource.
ctx(RuntimeEnvContext): auxiliary information supplied by Ray.
Returns:
the amount of space reclaimed by the deletion.
"""
return 0

View file

@ -3,7 +3,9 @@ import logging
import os
from pathlib import Path
import sys
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Set
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.utils import import_attr
import yaml
import ray
@ -64,6 +66,11 @@ class RuntimeEnvDict:
{"OMP_NUM_THREADS": "32", "TF_WARNINGS": "none"}
"""
known_fields: Set[str] = {
"working_dir", "conda", "pip", "uris", "containers", "env_vars",
"_ray_release", "_ray_commit", "_inject_current_ray", "plugins"
}
def __init__(self,
runtime_env_json: dict,
working_dir: Optional[str] = None):
@ -165,6 +172,29 @@ class RuntimeEnvDict:
# TODO(ekl) support py_modules
# TODO(architkulkarni) support docker
if "plugins" in runtime_env_json:
self._dict["plugins"] = dict()
for class_path, plugin_field in runtime_env_json[
"plugins"].items():
plugin_class: RuntimeEnvPlugin = import_attr(class_path)
if not issubclass(plugin_class, RuntimeEnvPlugin):
# TODO(simon): move the inferface to public once ready.
raise TypeError(
f"{class_path} must be inherit from "
"ray._private.runtime_env.plugin.RuntimeEnvPlugin.")
# TODO(simon): implement uri support.
_ = plugin_class.validate(runtime_env_json)
# Validation passed, add the entry to parsed runtime env.
self._dict["plugins"][class_path] = plugin_field
unknown_fields = (
set(runtime_env_json.keys()) - RuntimeEnvDict.known_fields)
if len(unknown_fields):
logger.warning(
"The following unknown entries in the runtime_env dictionary "
f"will be ignored: {unknown_fields}. If you are intended to "
"use plugin, make sure to nest them in the ``plugins`` field.")
# TODO(architkulkarni) This is to make it easy for the worker caching
# code in C++ to check if the env is empty without deserializing and
# parsing it. We should use a less confusing approach here.

View file

@ -82,6 +82,7 @@ py_test_module_list(
"test_reference_counting.py",
"test_resource_demand_scheduler.py",
"test_runtime_env_env_vars.py",
"test_runtime_env_plugin.py",
"test_runtime_env_fork_process.py",
"test_serialization.py",
"test_shuffle.py",

View file

@ -0,0 +1,75 @@
import os
import tempfile
import pytest
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
import ray
MY_PLUGIN_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPlugin"
class MyPlugin(RuntimeEnvPlugin):
env_key = "MY_PLUGIN_TEST_ENVIRONMENT_KEY"
@staticmethod
def validate(runtime_env_dict: dict) -> str:
value = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH]
if value == "fail":
raise ValueError("not allowed")
return value
@staticmethod
def modify_context(uri: str, runtime_env_dict: dict,
ctx: RuntimeEnvContext) -> None:
plugin_config_dict = runtime_env_dict["plugins"][MY_PLUGIN_CLASS_PATH]
ctx.env_vars[MyPlugin.env_key] = str(plugin_config_dict["env_value"])
ctx.command_prefix.append(
f"echo {plugin_config_dict['tmp_content']} > "
f"{plugin_config_dict['tmp_file']}")
ctx.py_executable = (
plugin_config_dict["prefix_command"] + " " + ctx.py_executable)
def test_simple_env_modification_plugin(ray_start_regular):
_, tmp_file_path = tempfile.mkstemp()
@ray.remote
def f():
import psutil
with open(tmp_file_path, "r") as f:
content = f.read().strip()
return {
"env_value": os.environ[MyPlugin.env_key],
"tmp_content": content,
"nice": psutil.Process().nice(),
}
with pytest.raises(ValueError, match="not allowed"):
f.options(runtime_env={
"plugins": {
MY_PLUGIN_CLASS_PATH: "fail"
}
}).remote()
output = ray.get(
f.options(
runtime_env={
"plugins": {
MY_PLUGIN_CLASS_PATH: {
"env_value": 42,
"tmp_file": tmp_file_path,
"tmp_content": "hello",
# See https://en.wikipedia.org/wiki/Nice_(Unix)
"prefix_command": "nice -n 19",
}
}
}).remote())
assert output == {"env_value": "42", "tmp_content": "hello", "nice": 19}
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-sv", __file__]))