[serve] Support for imported backends (#12923)

This commit is contained in:
Edward Oakes 2020-12-18 15:49:24 -06:00 committed by GitHub
parent 92812f2e8a
commit 3521e74f3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 151 additions and 3 deletions

View file

@ -327,3 +327,11 @@ as shown below.
:mod:`client.create_backend <ray.serve.api.Client.create_backend>` by
default.
The dependencies required in the backend may be different than
the dependencies installed in the driver program (the one running Serve API
calls). In this case, you can use an
:mod:`ImportedBackend <ray.serve.backends.ImportedBackend>` to specify a
backend based on a class that is installed in the Python environment that
the workers will run in. Example:
.. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py

View file

@ -31,3 +31,6 @@ objects instead of Flask requests.
Batching Requests
-----------------
.. autofunction:: ray.serve.accept_batch
Built-in Backends
.. autoclass:: ray.serve.backends.ImportedBackend

View file

@ -119,6 +119,14 @@ py_test(
deps = [":serve_lib"],
)
py_test(
name = "test_imported_backend",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
# Runs test_api and test_failure with injected failures in the controller.
# TODO(simon): Tests are disabled until #11683 is fixed.

View file

@ -0,0 +1,33 @@
from ray.serve.utils import import_class
class ImportedBackend:
"""Factory for a class that will dynamically import a backend class.
This is intended to be used when the source code for a backend is
installed in the worker environment but not the driver.
Intended usage:
>>> client = serve.connect()
>>> client.create_backend("b", ImportedBackend("module.Class"), *args)
This will import module.Class on the worker and proxy all relevant methods
to it.
"""
def __new__(cls, class_path):
class ImportedBackend:
def __init__(self, *args, **kwargs):
self.wrapped = import_class(class_path)(*args, **kwargs)
def reconfigure(self, *args, **kwargs):
# NOTE(edoakes): we check that the reconfigure method is
# present if the user specifies a user_config, so we need to
# proxy it manually.
return self.wrapped.reconfigure(*args, **kwargs)
def __getattr__(self, attr):
"""Proxy all other methods to the wrapper class."""
return getattr(self.wrapped, attr)
return ImportedBackend

View file

@ -1,10 +1,8 @@
import requests
import ray
from ray import serve
from ray.serve import CondaEnv
import tensorflow as tf
ray.init()
client = serve.start()

View file

@ -0,0 +1,12 @@
import requests
from ray import serve
from ray.serve.backends import ImportedBackend
client = serve.start()
backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend")
client.create_backend("imported", backend_class, "input_arg")
client.create_endpoint("imported", backend="imported", route="/imported")
print(requests.get("http://127.0.0.1:8000/imported").text)

View file

@ -0,0 +1,29 @@
import ray
from ray.serve.backends import ImportedBackend
from ray.serve.config import BackendConfig
def test_imported_backend(serve_instance):
client = serve_instance
backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend")
config = BackendConfig(user_config="config")
client.create_backend(
"imported", backend_class, "input_arg", config=config)
client.create_endpoint("imported", backend="imported")
# Basic sanity check.
handle = client.get_handle("imported")
assert ray.get(handle.remote()) == {"arg": "input_arg", "config": "config"}
# Check that updating backend config works.
client.update_backend_config(
"imported", BackendConfig(user_config="new_config"))
assert ray.get(handle.remote()) == {
"arg": "input_arg",
"config": "new_config"
}
# Check that other call methods work.
handle = handle.options(method_name="other_method")
assert ray.get(handle.remote("hello")) == "hello"

View file

@ -6,9 +6,10 @@ from copy import deepcopy
import numpy as np
import pytest
import ray
from ray.serve.utils import (ServeEncoder, chain_future, unpack_future,
try_schedule_resources_on_nodes,
get_conda_env_dir)
get_conda_env_dir, import_class)
def test_bytes_encoder():
@ -125,6 +126,21 @@ def test_get_conda_env_dir(tmp_path):
os.environ["CONDA_PREFIX"] = ""
def test_import_class():
assert import_class("ray.serve.Client") == ray.serve.api.Client
assert import_class("ray.serve.api.Client") == ray.serve.api.Client
policy_cls = import_class("ray.serve.controller.TrafficPolicy")
assert policy_cls == ray.serve.controller.TrafficPolicy
policy = policy_cls({"endpoint1": 0.5, "endpoint2": 0.5})
with pytest.raises(ValueError):
policy.set_traffic_dict({"endpoint1": 0.5, "endpoint2": 0.6})
policy.set_traffic_dict({"endpoint1": 0.4, "endpoint2": 0.6})
print(repr(policy))
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))

View file

@ -1,5 +1,6 @@
import asyncio
from functools import singledispatch
import importlib
from itertools import groupby
import json
import logging
@ -342,3 +343,43 @@ def get_node_id_for_actor(actor_handle):
"""Given an actor handle, return the node id it's placed on."""
return ray.actors()[actor_handle._actor_id.hex()]["Address"]["NodeID"]
def import_class(full_path: str):
"""Given a full import path to a class name, return the imported class.
For example, the following are equivalent:
MyClass = import_class("module.submodule.MyClass")
from module.submodule import MyClass
Returns:
Imported class
"""
last_period_idx = full_path.rfind(".")
class_name = full_path[last_period_idx + 1:]
module_name = full_path[:last_period_idx]
module = importlib.import_module(module_name)
return getattr(module, class_name)
class MockImportedBackend:
"""Used for testing backends.ImportedBackend.
This is necessary because we need the class to be installed in the worker
processes. We could instead mock out importlib but doing so is messier and
reduces confidence in the test (it isn't truly end-to-end).
"""
def __init__(self, arg):
self.arg = arg
self.config = None
def reconfigure(self, config):
self.config = config
def __call__(self, *args):
return {"arg": self.arg, "config": self.config}
def other_method(self, request):
return request.data