[Serve] Make REST API deployments inherit top-level runtime_env (#25502)

This commit is contained in:
shrekris-anyscale 2022-06-08 15:58:00 -07:00 committed by GitHub
parent 7616435ed0
commit f3c2bd6718
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 302 additions and 21 deletions

View file

@ -155,18 +155,22 @@ def test_put_get_success(ray_start_stop):
def test_put_new_rest_api(ray_start_stop):
config = {
"import_path": "ray.serve.tests.test_config_files.pizza.serve_dag",
"import_path": "conditional_dag.serve_dag",
"runtime_env": {
"working_dir": (
"https://github.com/ray-project/test_dag/archive/"
"cc246509ba3c9371f8450f74fdc18018428630bd.zip"
)
},
"deployments": [
{
"name": "Multiplier",
"user_config": {
"factor": 1,
},
"user_config": {"factor": 1},
},
{
"name": "Adder",
"user_config": {
"increment": 1,
"ray_actor_options": {
"runtime_env": {"env_vars": {"override_increment": "1"}}
},
},
],
@ -177,12 +181,24 @@ def test_put_new_rest_api(ray_start_stop):
wait_for_condition(
lambda: requests.post("http://localhost:8000/", json=["ADD", 2]).json()
== "3 pizzas please!",
timeout=30,
timeout=15,
)
wait_for_condition(
lambda: requests.post("http://localhost:8000/", json=["MUL", 2]).json()
== "2 pizzas please!",
timeout=30,
== "-4 pizzas please!",
timeout=15,
)
# Make Adder's ray_actor_options an empty dictionary.
config["deployments"][1]["ray_actor_options"] = {}
# Check that Adder's empty config ray_actor_options override its code options
put_response = requests.put(GET_OR_PUT_URL, json=config, timeout=30)
assert put_response.status_code == 200
wait_for_condition(
lambda: requests.post("http://localhost:8000/", json=["ADD", 2]).json()
== "4 pizzas please!",
timeout=15,
)

View file

@ -39,6 +39,7 @@ from ray.serve.logging_utils import configure_component_logger
from ray.serve.long_poll import LongPollHost
from ray.serve.storage.checkpoint_path import make_kv_store
from ray.serve.storage.kv_store import RayInternalKVStore
from ray.serve.utils import override_runtime_envs_except_env_vars
logger = logging.getLogger(SERVE_LOGGER_NAME)
@ -423,7 +424,7 @@ class ServeController:
def deploy_app(
self,
import_path: str,
runtime_env: str,
runtime_env: Dict,
deployment_override_options: List[Dict],
) -> None:
"""Kicks off a task that deploys a Serve application.
@ -446,7 +447,7 @@ class ServeController:
self.config_deployment_request_ref = run_graph.options(
runtime_env=runtime_env
).remote(import_path, deployment_override_options)
).remote(import_path, runtime_env, deployment_override_options)
self.deployment_timestamp = time.time()
@ -570,7 +571,9 @@ class ServeController:
@ray.remote(max_calls=1)
def run_graph(import_path: str, deployment_override_options: List[Dict]):
def run_graph(
import_path: str, graph_env: dict, deployment_override_options: List[Dict]
):
"""Deploys a Serve application to the controller's Ray cluster."""
from ray import serve
from ray.serve.api import build
@ -580,9 +583,26 @@ def run_graph(import_path: str, deployment_override_options: List[Dict]):
app = build(graph)
# Override options for each deployment
for options_dict in deployment_override_options:
name = options_dict["name"]
app.deployments[name].set_options(**options_dict)
for options in deployment_override_options:
name = options["name"]
# Merge graph-level and deployment-level runtime_envs
if "ray_actor_options" in options:
# If specified, get ray_actor_options from config
ray_actor_options = options["ray_actor_options"]
else:
# Otherwise, get options from graph code (and default to {} if code
# sets options to None)
ray_actor_options = app.deployments[name].ray_actor_options or {}
deployment_env = ray_actor_options.get("runtime_env", {})
merged_env = override_runtime_envs_except_env_vars(graph_env, deployment_env)
ray_actor_options.update({"runtime_env": merged_env})
options["ray_actor_options"] = ray_actor_options
# Update the deployment's options
app.deployments[name].set_options(**options)
# Run the graph locally on the cluster
serve.start(_override_controller_namespace="serve")

View file

@ -1,6 +1,7 @@
from contextlib import contextmanager
import sys
import os
import time
import subprocess
from tempfile import NamedTemporaryFile
import requests
@ -468,11 +469,10 @@ class TestDeployApp:
config = ServeApplicationSchema.parse_obj(self.get_test_config())
client.deploy_app(config)
wait_for_condition(
lambda: client.get_serve_status().app_status.deployment_timestamp > 0
)
assert client.get_serve_status().app_status.deployment_timestamp > 0
first_deploy_time = client.get_serve_status().app_status.deployment_timestamp
time.sleep(0.1)
config = self.get_test_config()
config["deployments"] = [
@ -483,8 +483,8 @@ class TestDeployApp:
]
client.deploy_app(ServeApplicationSchema.parse_obj(config))
wait_for_condition(
lambda: client.get_serve_status().app_status.deployment_timestamp
assert (
client.get_serve_status().app_status.deployment_timestamp
> first_deploy_time
)
assert client.get_serve_status().app_status.status in {
@ -521,6 +521,42 @@ class TestDeployApp:
== "4 pizzas please!"
)
def test_deploy_app_runtime_env(self, client: ServeControllerClient):
    """Deploy an app twice and check the Adder's response after each deploy.

    The first config only sets a top-level runtime_env (a remote
    working_dir). The second config additionally gives the "Adder"
    deployment its own runtime_env containing an env var, and the
    response to the same request is expected to change.
    """

    working_dir_uri = (
        "https://github.com/ray-project/test_dag/archive/"
        "cc246509ba3c9371f8450f74fdc18018428630bd.zip"
    )
    app_config = {
        "import_path": "conditional_dag.serve_dag",
        "runtime_env": {"working_dir": working_dir_uri},
    }

    def add_two():
        # Same request for both phases; only the deployed config differs.
        return requests.post("http://localhost:8000/", json=["ADD", 2]).json()

    # Phase 1: only the top-level runtime_env is configured.
    client.deploy_app(ServeApplicationSchema.parse_obj(app_config))
    wait_for_condition(lambda: add_two() == "0 pizzas please!")

    # Phase 2: redeploy with a deployment-level runtime_env for Adder.
    adder_override = {
        "name": "Adder",
        "ray_actor_options": {
            "runtime_env": {"env_vars": {"override_increment": "1"}}
        },
    }
    app_config["deployments"] = [adder_override]
    client.deploy_app(ServeApplicationSchema.parse_obj(app_config))
    wait_for_condition(lambda: add_two() == "3 pizzas please!")
def test_controller_recover_and_delete():
"""Ensure that in-progress deletion can finish even after controller dies."""

View file

@ -14,6 +14,7 @@ from ray.serve.utils import (
serve_encoders,
get_deployment_import_path,
node_id_to_ip_addr,
override_runtime_envs_except_env_vars,
)
@ -142,6 +143,166 @@ class TestGetDeploymentImportPath:
subprocess.check_output(["python", full_fname])
class TestOverrideRuntimeEnvsExceptEnvVars:
    """Unit tests for override_runtime_envs_except_env_vars."""

    def test_merge_empty(self):
        assert {"env_vars": {}} == override_runtime_envs_except_env_vars({}, {})

    def test_merge_empty_parent(self):
        child = {"env_vars": {"test1": "test_val"}, "working_dir": "."}
        assert child == override_runtime_envs_except_env_vars({}, child)

    def test_merge_empty_child(self):
        parent = {"env_vars": {"test1": "test_val"}, "working_dir": "."}
        assert parent == override_runtime_envs_except_env_vars(parent, {})

    @pytest.mark.parametrize("invalid_env", [None, 0, "runtime_env", set()])
    def test_invalid_type(self, invalid_env):
        with pytest.raises(TypeError):
            override_runtime_envs_except_env_vars(invalid_env, {})
        with pytest.raises(TypeError):
            override_runtime_envs_except_env_vars({}, invalid_env)
        with pytest.raises(TypeError):
            override_runtime_envs_except_env_vars(invalid_env, invalid_env)

    def test_basic_merge(self):
        import copy

        parent = {
            "py_modules": ["http://test.com/test0.zip", "s3://path/test1.zip"],
            "working_dir": "gs://path/test2.zip",
            "env_vars": {"test": "val", "trial": "val2"},
            "pip": ["pandas", "numpy"],
            "excludes": ["my_file.txt"],
        }
        # Deep snapshot: a shallow dict.copy() would share the nested lists and
        # dicts with `parent`, hiding any in-place mutation of those values.
        original_parent = copy.deepcopy(parent)

        child = {
            "py_modules": [],
            "working_dir": "s3://path/test1.zip",
            "env_vars": {"test": "val", "trial": "val2"},
            "pip": ["numpy"],
        }
        original_child = copy.deepcopy(child)

        merged = override_runtime_envs_except_env_vars(parent, child)

        assert original_parent == parent
        assert original_child == child
        assert merged == {
            "py_modules": [],
            "working_dir": "s3://path/test1.zip",
            "env_vars": {"test": "val", "trial": "val2"},
            "pip": ["numpy"],
            "excludes": ["my_file.txt"],
        }

    def test_merge_deep_copy(self):
        """Check that the env values are actually deep-copied."""
        import copy

        parent_env_vars = {"parent": "pval"}
        child_env_vars = {"child": "cval"}

        parent = {"env_vars": parent_env_vars}
        child = {"env_vars": child_env_vars}
        # Deep snapshots so that mutation of the nested env_vars dicts is
        # detectable; shallow copies would alias parent_env_vars/child_env_vars.
        original_parent = copy.deepcopy(parent)
        original_child = copy.deepcopy(child)

        merged = override_runtime_envs_except_env_vars(parent, child)

        assert merged["env_vars"] == {"parent": "pval", "child": "cval"}
        assert original_parent == parent
        assert original_child == child

        # The merged env_vars must be a new object, and the inputs' nested
        # dicts must be left untouched.
        assert merged["env_vars"] is not parent_env_vars
        assert merged["env_vars"] is not child_env_vars
        assert parent_env_vars == {"parent": "pval"}
        assert child_env_vars == {"child": "cval"}

    def test_merge_empty_env_vars(self):
        env_vars = {"test": "val", "trial": "val2"}

        non_empty = {"env_vars": {"test": "val", "trial": "val2"}}
        empty = {}

        assert (
            env_vars
            == override_runtime_envs_except_env_vars(non_empty, empty)["env_vars"]
        )
        assert (
            env_vars
            == override_runtime_envs_except_env_vars(empty, non_empty)["env_vars"]
        )
        assert {} == override_runtime_envs_except_env_vars(empty, empty)["env_vars"]

    def test_merge_env_vars(self):
        parent = {
            "py_modules": ["http://test.com/test0.zip", "s3://path/test1.zip"],
            "working_dir": "gs://path/test2.zip",
            "env_vars": {"parent": "pval", "override": "old"},
            "pip": ["pandas", "numpy"],
            "excludes": ["my_file.txt"],
        }

        child = {
            "py_modules": [],
            "working_dir": "s3://path/test1.zip",
            "env_vars": {"child": "cval", "override": "new"},
            "pip": ["numpy"],
        }

        merged = override_runtime_envs_except_env_vars(parent, child)
        assert merged == {
            "py_modules": [],
            "working_dir": "s3://path/test1.zip",
            "env_vars": {"parent": "pval", "child": "cval", "override": "new"},
            "pip": ["numpy"],
            "excludes": ["my_file.txt"],
        }

    def test_inheritance_regression(self):
        """Check if the general Ray runtime_env inheritance behavior matches.

        override_runtime_envs_except_env_vars should match the general Ray
        runtime_env inheritance behavior. This test checks if that behavior
        has changed, which would indicate a regression in
        override_runtime_envs_except_env_vars. If the runtime_env inheritance
        behavior changes, override_runtime_envs_except_env_vars should also
        change to match.
        """

        with ray.init(
            runtime_env={
                "py_modules": [
                    "https://github.com/ray-project/test_dag/archive/"
                    "cc246509ba3c9371f8450f74fdc18018428630bd.zip"
                ],
                "env_vars": {"var1": "hello"},
            }
        ):

            @ray.remote
            def check_module():
                # Check that Ray job's py_module loaded correctly
                from conditional_dag import serve_dag  # noqa: F401

                return os.getenv("var1")

            assert ray.get(check_module.remote()) == "hello"

            @ray.remote(
                runtime_env={
                    "py_modules": [
                        "https://github.com/ray-project/test_deploy_group/archive/"
                        "67971777e225600720f91f618cdfe71fc47f60ee.zip"
                    ],
                    "env_vars": {"var2": "world"},
                }
            )
            def test_task():
                with pytest.raises(ImportError):
                    # Check that Ray job's py_module was overwritten
                    from conditional_dag import serve_dag  # noqa: F401

                from test_env.shallow_import import ShallowClass

                if ShallowClass()() == "Hello shallow world!":
                    return os.getenv("var1") + " " + os.getenv("var2")

            assert ray.get(test_task.remote()) == "hello world"
if __name__ == "__main__":
import sys

View file

@ -6,8 +6,9 @@ import pickle
import random
import string
import time
from typing import Iterable, List, Tuple
from typing import Iterable, List, Dict, Tuple
import os
import copy
import traceback
from enum import Enum
import __main__
@ -327,6 +328,53 @@ def parse_import_path(import_path: str):
return ".".join(nodes[:-1]), nodes[-1]
def override_runtime_envs_except_env_vars(parent_env: Dict, child_env: Dict) -> Dict:
    """Creates a runtime_env dict by merging a parent and child environment.

    This method is not destructive. It leaves the parent and child envs
    the same.

    The merge is a shallow update where the child environment inherits the
    parent environment's settings. If the child environment specifies any
    env settings, those settings take precedence over the parent.
        - Note: env_vars are a special case. The child's env_vars are combined
            with the parent.

    Args:
        parent_env: The environment to inherit settings from.
        child_env: The environment with override settings.

    Returns: A new dictionary containing the merged runtime_env settings.
        The result always contains an "env_vars" key, even when neither
        input defines one.

    Raises:
        TypeError: If a dictionary is not passed in for parent_env or child_env.
    """

    # Validate against the builtin dict rather than the typing.Dict alias,
    # which is meant for annotations and not for runtime type checks.
    for arg_name, env in (("parent_env", parent_env), ("child_env", child_env)):
        if not isinstance(env, dict):
            raise TypeError(
                f'Got unexpected type "{type(env)}" for {arg_name}. '
                f"{arg_name} must be a dictionary."
            )

    # Work on deep copies so neither input (nor its nested values) is mutated
    # or aliased by the returned dictionary.
    merged = copy.deepcopy(parent_env)
    overrides = copy.deepcopy(child_env)

    # env_vars are combined key by key, with the child winning on conflicts.
    # Compute them before the shallow update below replaces the parent's
    # "env_vars" entry wholesale.
    merged_env_vars = merged.get("env_vars", {})
    merged_env_vars.update(overrides.get("env_vars", {}))

    # Every other top-level setting is replaced outright by the child's value.
    merged.update(overrides)
    merged["env_vars"] = merged_env_vars

    return merged
class JavaActorHandleProxy:
"""Wraps actor handle and translate snake_case to camelCase."""