[Serve] Rewrite Ray.Serve From Scratch (#5562)

* Commit and format files

* address stylistic concerns

* Replace "Usage" with "Example" in doc

* Rename srv to serve

* Add serve to CI process; Fix 3.5 compat

* Improve determine_tests_to_run.py

* Quick cosmetic fix for determine_tests

* Address comments

* Address comments

* Address comment

* Fix typos and grammar

Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com>

* Update python/ray/experimental/serve/global_state.py

Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com>

* Use __init__ for Query and WorkIntent classes

* Remove dataclasses dependency

* Rename oid to object_id for clarity

* Rename produce->enqueue_request, consume->dequeue_request

* Address last round of comments
Simon Mo 2019-09-13 21:36:56 -07:00 committed by GitHub
parent 4c964c0941
commit 5f88823c49
25 changed files with 1428 additions and 10 deletions

View file

@@ -175,6 +175,7 @@ script:
# ray tests
# Python 3.5+ only. Otherwise we will get a `SyntaxError` regardless of how we configure the test runner.
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then python -c 'import sys;exit(sys.version_info>=(3,5))' || python -m pytest -v --durations=5 --timeout=300 python/ray/experimental/test/async_test.py; fi
- if [ $RAY_CI_SERVE_AFFECTED == "1" ]; then python -c 'import sys;exit(sys.version_info>=(3,5))' || python -m pytest -v --durations=5 --timeout=300 python/ray/experimental/serve/tests; fi
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then python -m pytest -v --durations=10 --timeout=300 python/ray/tests --ignore=python/ray/tests/perf_integration_tests; fi
deploy:
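
A note on the version guard above: `python -c 'import sys;exit(sys.version_info>=(3,5))'` exits with status 1 on Python 3.5+, so the pytest command after `||` only runs on those interpreters. A minimal sketch of the idiom in isolation:

import sys

# sys.exit(True) exits with status 1 (failure) and sys.exit(False) with
# status 0, so in a shell `guard || run_tests` chain the tests run only
# on Python >= 3.5.
sys.exit(sys.version_info >= (3, 5))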

View file

@@ -5,6 +5,9 @@ from __future__ import print_function
import os
import subprocess
import sys
from functools import partial
from pprint import pformat
def list_changed_files(commit_range):
@@ -30,6 +33,7 @@ if __name__ == "__main__":
RAY_CI_TUNE_AFFECTED = 0
RAY_CI_RLLIB_AFFECTED = 0
RAY_CI_SERVE_AFFECTED = 0
RAY_CI_JAVA_AFFECTED = 0
RAY_CI_PYTHON_AFFECTED = 0
RAY_CI_LINUX_WHEELS_AFFECTED = 0
@@ -40,6 +44,8 @@ if __name__ == "__main__":
files = list_changed_files(os.environ["TRAVIS_COMMIT_RANGE"].replace(
"...", ".."))
print(pformat(files), file=sys.stderr)
skip_prefix_list = [
"doc/", "examples/", "dev/", "docker/", "kubernetes/", "site/"
]
@@ -54,9 +60,14 @@ if __name__ == "__main__":
RAY_CI_RLLIB_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
elif changed_file.startswith("python/ray/experimental/serve"):
RAY_CI_SERVE_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
elif changed_file.startswith("python/"):
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_RLLIB_AFFECTED = 1
RAY_CI_SERVE_AFFECTED = 1
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
@@ -70,6 +81,7 @@ if __name__ == "__main__":
elif changed_file.startswith("src/"):
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_RLLIB_AFFECTED = 1
RAY_CI_SERVE_AFFECTED = 1
RAY_CI_JAVA_AFFECTED = 1
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
@@ -77,6 +89,7 @@ if __name__ == "__main__":
else:
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_RLLIB_AFFECTED = 1
RAY_CI_SERVE_AFFECTED = 1
RAY_CI_JAVA_AFFECTED = 1
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
@@ -84,16 +97,22 @@ if __name__ == "__main__":
else:
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_RLLIB_AFFECTED = 1
RAY_CI_SERVE_AFFECTED = 1
RAY_CI_JAVA_AFFECTED = 1
RAY_CI_PYTHON_AFFECTED = 1
RAY_CI_LINUX_WHEELS_AFFECTED = 1
RAY_CI_MACOS_WHEELS_AFFECTED = 1
print("export RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED))
print("export RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED))
print("export RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED))
print("export RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED))
print("export RAY_CI_LINUX_WHEELS_AFFECTED={}"
# Log the exported environment variables so they are visible in the console.
for output_stream in [sys.stdout, sys.stderr]:
_print = partial(print, file=output_stream)
_print("export RAY_CI_TUNE_AFFECTED={}".format(RAY_CI_TUNE_AFFECTED))
_print("export RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED))
_print("export RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED))
_print("export RAY_CI_JAVA_AFFECTED={}".format(RAY_CI_JAVA_AFFECTED))
_print(
"export RAY_CI_PYTHON_AFFECTED={}".format(RAY_CI_PYTHON_AFFECTED))
_print("export RAY_CI_LINUX_WHEELS_AFFECTED={}"
.format(RAY_CI_LINUX_WHEELS_AFFECTED))
print("export RAY_CI_MACOS_WHEELS_AFFECTED={}"
_print("export RAY_CI_MACOS_WHEELS_AFFECTED={}"
.format(RAY_CI_MACOS_WHEELS_AFFECTED))

View file

@@ -34,7 +34,8 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install -q scipy tensorflow cython==0.29.0 gym opencv-python-headless pyyaml pandas==0.24.2 requests \
feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp
feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp \
uvicorn dataclasses pygments werkzeug
elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
# Install miniconda.
wget https://repo.continuum.io/miniconda/Miniconda2-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv
@@ -48,7 +49,8 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install -q cython==0.29.0 tensorflow gym opencv-python-headless pyyaml pandas==0.24.2 requests \
feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp
feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout flaky networkx tabulate psutil aiohttp \
uvicorn dataclasses pygments werkzeug
elif [[ "$LINT" == "1" ]]; then
sudo apt-get update
sudo apt-get install -y build-essential curl unzip

View file

@@ -0,0 +1,12 @@
import sys
if sys.version_info < (3, 0):
raise ImportError("serve is Python 3 only.")
from ray.experimental.serve.api import (init, create_backend, create_endpoint,
link, split, rollback, get_handle,
global_state) # noqa: E402
__all__ = [
"init", "create_backend", "create_endpoint", "link", "split", "rollback",
"get_handle", "global_state"
]

View file

@@ -0,0 +1,187 @@
import inspect
import numpy as np
import ray
from ray.experimental.serve.task_runner import RayServeMixin, TaskRunnerActor
from ray.experimental.serve.utils import pformat_color_json, logger
from ray.experimental.serve.global_state import GlobalState
global_state = GlobalState()
def init(blocking=False, object_store_memory=int(1e8)):
"""Initialize a serve cluster.
Calling `ray.init` before `serve.init` is optional. If no ray cluster
has been initialized, serve will call `ray.init` with the given
`object_store_memory` requirement.
Args:
blocking (bool): If true, the function will wait for the HTTP server to
be healthy before returning.
object_store_memory (int): Allocated shared memory size in bytes. The
default is 100MiB, kept low for latency stability reasons.
"""
if not ray.is_initialized():
ray.init(object_store_memory=object_store_memory)
# NOTE(simon): Currently the initialization order is fixed.
# HTTP server depends on the API server.
global_state.init_api_server()
global_state.init_router()
global_state.init_http_server()
if blocking:
global_state.wait_until_http_ready()
def create_endpoint(endpoint_name, route_expression, blocking=True):
"""Create a service endpoint given route_expression.
Args:
endpoint_name (str): A name to associate to the endpoint. It will be
used as key to set traffic policy.
route_expression (str): A string begin with "/". HTTP server will use
the string to match the path.
blocking (bool): If true, the function will wait for service to be
registered before returning
"""
future = global_state.kv_store_actor_handle.register_service.remote(
route_expression, endpoint_name)
if blocking:
ray.get(future)
global_state.registered_endpoints.add(endpoint_name)
def create_backend(func_or_class, backend_tag, *actor_init_args):
"""Create a backend using func_or_class and assign backend_tag.
Args:
func_or_class (callable, class): a function or a class implements
__call__ protocol.
backend_tag (str): a unique tag assign to this backend. It will be used
to associate services in traffic policy.
*actor_init_args (optional): the argument to pass to the class
initialization method.
"""
if inspect.isfunction(func_or_class):
runner = TaskRunnerActor.remote(func_or_class)
elif inspect.isclass(func_or_class):
# Python resolves methods left-to-right, so we put RayServeMixin
# first to make sure its methods are not overridden.
@ray.remote
class CustomActor(RayServeMixin, func_or_class):
pass
runner = CustomActor.remote(*actor_init_args)
else:
raise TypeError(
"Backend must be a function or class, it is {}.".format(
type(func_or_class)))
global_state.backend_actor_handles.append(runner)
runner._ray_serve_setup.remote(backend_tag,
global_state.router_actor_handle)
runner._ray_serve_main_loop.remote(runner)
global_state.registered_backends.add(backend_tag)
def link(endpoint_name, backend_tag):
"""Associate a service endpoint with backend tag.
Example:
>>> serve.link("service-name", "backend:v1")
Note:
This is equivalent to
>>> serve.split("service-name", {"backend:v1": 1.0})
"""
assert endpoint_name in global_state.registered_endpoints
global_state.router_actor_handle.link.remote(endpoint_name, backend_tag)
global_state.policy_action_history[endpoint_name].append({backend_tag: 1})
def split(endpoint_name, traffic_policy_dictionary):
"""Associate a service endpoint with traffic policy.
Example:
>>> serve.split("service-name", {
"backend:v1": 0.5,
"backend:v2": 0.5
})
Args:
endpoint_name (str): A registered service endpoint.
traffic_policy_dictionary (dict): a dictionary mapping backend names
to their traffic weights. The weights must sum to 1.
"""
# Perform dictionary checks
assert endpoint_name in global_state.registered_endpoints
assert isinstance(traffic_policy_dictionary,
dict), "Traffic policy must be dictionary"
prob = 0
for backend, weight in traffic_policy_dictionary.items():
prob += weight
assert (backend in global_state.registered_backends
), "backend {} is not registered".format(backend)
assert np.isclose(
prob, 1,
atol=0.02), "weights must sum to 1, currently it sums to {}".format(
prob)
global_state.router_actor_handle.set_traffic.remote(
endpoint_name, traffic_policy_dictionary)
global_state.policy_action_history[endpoint_name].append(
traffic_policy_dictionary)
def rollback(endpoint_name):
"""Rollback a traffic policy decision.
Args:
endpoint_name (str): A registered service endpoint.
"""
assert endpoint_name in global_state.registered_endpoints
action_queues = global_state.policy_action_history[endpoint_name]
cur_policy, prev_policy = action_queues[-1], action_queues[-2]
logger.warning("""
Current traffic policy is:
{cur_policy}
Will rollback to:
{prev_policy}
""".format(
cur_policy=pformat_color_json(cur_policy),
prev_policy=pformat_color_json(prev_policy)))
action_queues.pop()
global_state.router_actor_handle.set_traffic.remote(
endpoint_name, prev_policy)
def get_handle(endpoint_name):
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
Args:
endpoint_name (str): A registered service endpoint.
Returns:
RayServeHandle
"""
assert endpoint_name in global_state.registered_endpoints
# Delayed import due to its dependency on global_state
from ray.experimental.serve.handle import RayServeHandle
return RayServeHandle(global_state.router_actor_handle, endpoint_name)
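
Taken together, a minimal usage sketch of this API from the Python side (names are illustrative; it mirrors the HTTP examples below, but invokes the endpoint through a handle instead):

import ray
from ray.experimental import serve

def echo(context):
    return context

serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_backend(echo, "echo:v1")
serve.link("my_endpoint", "echo:v1")

# Invoke the endpoint from Python rather than over HTTP.
handle = serve.get_handle("my_endpoint")
print(ray.get(handle.remote("hello")))  # -> "hello"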

View file

@@ -0,0 +1,2 @@
#: The interval at which the HTTP server refreshes its routing table
HTTP_ROUTER_CHECKER_INTERVAL_S = 2

View file

@@ -0,0 +1,28 @@
"""
Example service that prints out the HTTP context.
"""
import time
import requests
from ray.experimental import serve
from ray.experimental.serve.utils import pformat_color_json
def echo(context):
return context
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_backend(echo, "echo:v1")
serve.link("my_endpoint", "echo:v1")
while True:
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print("...Sleeping for 2 seconds...")
time.sleep(2)

View file

@@ -0,0 +1,41 @@
"""
Example actor that appends a message to the end of the query string.
"""
import time
import requests
from werkzeug import urls
from ray.experimental import serve
from ray.experimental.serve.utils import pformat_color_json
class EchoActor:
def __init__(self, message):
self.message = message
def __call__(self, context):
query_string_dict = urls.url_decode(context["query_string"])
message = ""
message += query_string_dict.get("message", "")
message += " "
message += self.message
return message
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_backend(EchoActor, "echo:v1", "world")
serve.link("my_endpoint", "echo:v1")
while True:
resp = requests.get("http://127.0.0.1:8000/echo?message=hello").json()
print(pformat_color_json(resp))
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print("...Sleeping for 2 seconds...")
time.sleep(2)

View file

@@ -0,0 +1,44 @@
"""
Example of error handling mechanism in ray serve.
We are going to define a buggy function that raises an exception:
>>> def echo(_):
raise Exception("oh no")
The expected behavior is:
- HTTP server should respond with "internal error" in the response JSON
- ray.get(handle.remote(33)) should raise RayTaskError with traceback.
This shows that the error is hidden from the HTTP side but always visible
when calling from Python.
"""
import time
import requests
import ray
from ray.experimental import serve
from ray.experimental.serve.utils import pformat_color_json
def echo(_):
raise Exception("Something went wrong...")
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_backend(echo, "echo:v1")
serve.link("my_endpoint", "echo:v1")
for _ in range(2):
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print("...Sleeping for 2 seconds...")
time.sleep(2)
handle = serve.get_handle("my_endpoint")
ray.get(handle.remote(33))

View file

@@ -0,0 +1,50 @@
"""
Example rollback action in ray serve. We first deploy only v1, then set a
50/50 deployment between v1 and v2, and finally roll back to only v1.
"""
import time
import requests
from ray.experimental import serve
from ray.experimental.serve.utils import pformat_color_json
def echo_v1(_):
return "v1"
def echo_v2(_):
return "v2"
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_backend(echo_v1, "echo:v1")
serve.link("my_endpoint", "echo:v1")
for _ in range(3):
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print("...Sleeping for 2 seconds...")
time.sleep(2)
serve.create_backend(echo_v2, "echo:v2")
serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
for _ in range(6):
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print("...Sleeping for 2 seconds...")
time.sleep(2)
serve.rollback("my_endpoint")
for _ in range(6):
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print("...Sleeping for 2 seconds...")
time.sleep(2)

View file

@@ -0,0 +1,41 @@
"""
Example of traffic splitting. We will first use echo:v1. Then v1 and v2
will split the incoming traffic evenly.
"""
import time
import requests
from ray.experimental import serve
from ray.experimental.serve.utils import pformat_color_json
def echo_v1(_):
return "v1"
def echo_v2(_):
return "v2"
serve.init(blocking=True)
serve.create_endpoint("my_endpoint", "/echo", blocking=True)
serve.create_backend(echo_v1, "echo:v1")
serve.link("my_endpoint", "echo:v1")
for _ in range(3):
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print("...Sleeping for 2 seconds...")
time.sleep(2)
serve.create_backend(echo_v2, "echo:v2")
serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
while True:
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print("...Sleeping for 2 seconds...")
time.sleep(2)

View file

@@ -0,0 +1,84 @@
import time
from collections import defaultdict, deque
import ray
from ray.experimental.serve.kv_store_service import KVStoreProxyActor
from ray.experimental.serve.queues import CentralizedQueuesActor
from ray.experimental.serve.utils import logger
from ray.experimental.serve.server import HTTPActor
# TODO(simon): Global state is currently designed to reside in the driver
# process. In the next iteration, we will move all mutable state into
# two actors: (1) a namespaced key-value store backed by a persistent store
# (2) an actor supervisor holding all actor handles, responsible
# for instantiating new actors and terminating dead ones.
LOG_PREFIX = "[Global State] "
class GlobalState:
"""Encapsulate all global state in the serving system.
Warning:
Currently the state resides inside the driver process. It will be
moved into a key-value store service AND a supervisor service.
"""
def __init__(self):
#: holds all actor handles.
self.backend_actor_handles = []
#: actor handle to KV store actor
self.kv_store_actor_handle = None
#: actor handle to HTTP server
self.http_actor_handle = None
#: actor handle to the router actor
self.router_actor_handle = None
#: Set[str] of backend names, used for deduplication
self.registered_backends = set()
#: Set[str] of service endpoint names, used for deduplication
self.registered_endpoints = set()
#: Mapping of endpoint name -> a stack of traffic policies
self.policy_action_history = defaultdict(deque)
#: HTTP address. Currently it's hard-coded to localhost with port 8000.
# In a future iteration, the HTTP server will be started on every node
# and use a random/available port in a pre-defined port range. TODO(simon)
self.http_address = ""
def init_api_server(self):
logger.info(LOG_PREFIX + "Initalizing routing table")
self.kv_store_actor_handle = KVStoreProxyActor.remote()
logger.info((LOG_PREFIX + "Health checking routing table {}").format(
ray.get(self.kv_store_actor_handle.get_request_count.remote())), )
def init_http_server(self):
logger.info(LOG_PREFIX + "Initializing HTTP server")
self.http_actor_handle = HTTPActor.remote(self.kv_store_actor_handle,
self.router_actor_handle)
self.http_actor_handle.run.remote(host="0.0.0.0", port=8000)
self.http_address = "http://localhost:8000"
def init_router(self):
logger.info(LOG_PREFIX + "Initializing queuing system")
self.router_actor_handle = CentralizedQueuesActor.remote()
self.router_actor_handle.register_self_handle.remote(
self.router_actor_handle)
def wait_until_http_ready(self, num_retries=5, backoff_time_s=1):
routing_table_request_count = 0
retries = num_retries
while not routing_table_request_count:
routing_table_request_count = (ray.get(
self.kv_store_actor_handle.get_request_count.remote()))
logger.debug((LOG_PREFIX + "Checking if HTTP server is ready. "
"{} retries left.").format(retries))
time.sleep(backoff_time_s)
retries -= 1
if retries == 0:
raise Exception(
"HTTP server not ready after {} retries.".format(
num_retries))

View file

@@ -0,0 +1,64 @@
import ray
from ray.experimental import serve
class RayServeHandle:
"""A handle to a service endpoint.
Invoking this endpoint with .remote is equivalent to pinging
an HTTP endpoint.
Example:
>>> handle = serve.get_handle("my_endpoint")
>>> handle
RayServeHandle(
Endpoint="my_endpoint",
URL="...",
Traffic=...
)
>>> handle.remote(my_request_content)
ObjectID(...)
>>> ray.get(handle.remote(...))
# result
>>> ray.get(handle.remote(let_it_crash_request))
# raises RayTaskError Exception
"""
def __init__(self, router_handle, endpoint_name):
self.router_handle = router_handle
self.endpoint_name = endpoint_name
def remote(self, *args):
# TODO(simon): Support kwargs once #5606 is merged.
result_object_id_bytes = ray.get(
self.router_handle.enqueue_request.remote(self.endpoint_name,
*args))
return ray.ObjectID(result_object_id_bytes)
def get_traffic_policy(self):
# TODO(simon): This method is implemented via checking global state
# because we are sure handle and global_state are in the same process.
# However, once global_state is deprecated, this method needs to be
# updated accordingly.
history = serve.global_state.policy_action_history[self.endpoint_name]
if len(history):
return history[-1]
else:
return None
def get_http_endpoint(self):
return serve.global_state.http_address
def __repr__(self):
return """
RayServeHandle(
Endpoint="{endpoint_name}",
URL="{http_endpoint}/{endpoint_name},
Traffic={traffic_policy}
)
""".format(endpoint_name=self.endpoint_name,
http_endpoint=self.get_http_endpoint(),
traffic_policy=self.get_traffic_policy())
# TODO(simon): a convenience function that dumps equivalent requests
# code for a given call.

View file

@@ -0,0 +1,173 @@
import json
from abc import ABC
import ray
import ray.experimental.internal_kv as ray_kv
from ray.experimental.serve.utils import logger
class NamespacedKVStore(ABC):
"""Abstract base class for a namespaced key-value store.
The idea is that multiple key-value stores can be created while sharing
the same storage system. The keys of each instance are namespaced to avoid
key collisions.
Example:
>>> store_ns1 = NamespacedKVStore(namespace="ns1")
>>> store_ns2 = NamespacedKVStore(namespace="ns2")
# Two stores can share the same connection, e.g. Redis or a SQL table
>>> store_ns1.put("same-key", 1)
>>> store_ns1.get("same-key")
1
>>> store_ns2.put("same-key", 2)
>>> store_ns2.get("same-key", 2)
2
"""
def __init__(self, namespace):
raise NotImplementedError()
def get(self, key):
"""Retrieve the value for the given key.
Args:
key (str)
"""
raise NotImplementedError()
def put(self, key, value):
"""Serialize the value and store it under the given key.
Args:
key (str)
value (object): any serializable object. The serialization method
is determined by the subclass implementation.
"""
raise NotImplementedError()
def as_dict(self):
"""Return the entire namespace as a dictionary.
Returns:
data (dict): key-value pairs in the current namespace
"""
raise NotImplementedError()
class InMemoryKVStore(NamespacedKVStore):
"""A reference implementation used for testing."""
def __init__(self, namespace):
self.data = dict()
# Namespace is ignored, because each namespace is backed by
# an in-memory Python dictionary.
self.namespace = namespace
def get(self, key):
return self.data[key]
def put(self, key, value):
self.data[key] = value
def as_dict(self):
return self.data.copy()
class RayInternalKVStore(NamespacedKVStore):
"""A NamespacedKVStore implementation using ray's `internal_kv`."""
def __init__(self, namespace):
assert ray_kv._internal_kv_initialized()
self.index_key = "RAY_SERVE_INDEX"
self.namespace = namespace
self._put(self.index_key, [])
def _format_key(self, key):
return "{ns}-{key}".format(ns=self.namespace, key=key)
def _remove_format_key(self, formatted_key):
return formatted_key.replace(self.namespace + "-", "", 1)
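# For illustration (hypothetical values): with namespace "ns",
# _format_key("k") returns "ns-k" and _remove_format_key("ns-k") returns
# "k". Note that _put and _get additionally JSON-serialize the key.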
def _serialize(self, obj):
return json.dumps(obj)
def _deserialize(self, buffer):
return json.loads(buffer)
def _put(self, key, value):
ray_kv._internal_kv_put(
self._format_key(self._serialize(key)),
self._serialize(value),
overwrite=True,
)
def _get(self, key):
return self._deserialize(
ray_kv._internal_kv_get(self._format_key(self._serialize(key))))
def get(self, key):
return self._get(key)
def put(self, key, value):
assert isinstance(key, str), "Key must be a string."
self._put(key, value)
all_keys = set(self._get(self.index_key))
all_keys.add(key)
self._put(self.index_key, list(all_keys))
def as_dict(self):
data = {}
all_keys = self._get(self.index_key)
for key in all_keys:
data[self._remove_format_key(key)] = self._get(key)
return data
class KVStoreProxy:
def __init__(self, kv_class=InMemoryKVStore):
self.routing_table = kv_class(namespace="routes")
self.request_count = 0
def register_service(self, route: str, service: str):
"""Create an entry in the routing table
Args:
route: http path name. Must begin with '/'.
service: service name. This is the name http actor will push
the request to.
"""
logger.debug("[KV] Registering route {} to service {}.".format(
route, service))
self.routing_table.put(route, service)
def list_service(self):
"""Returns the routing table."""
self.request_count += 1
table = self.routing_table.as_dict()
return table
def get_request_count(self):
"""Return the number of requests that fetched the routing table.
This method is used for two purposes:
1. Make sure the HTTP server has started and is healthy. An increasing
request count means the HTTP server is actively fetching the routing
table.
2. Make sure the HTTP server does not have a stale routing table. This
number should increase every HTTP_ROUTER_CHECKER_INTERVAL_S seconds.
The supervisor should check this number as an indirect indicator of
the HTTP server's health.
"""
return self.request_count
@ray.remote
class KVStoreProxyActor(KVStoreProxy):
def __init__(self, kv_class=RayInternalKVStore):
super().__init__(kv_class=kv_class)

View file

@@ -0,0 +1,155 @@
from collections import defaultdict, deque
import numpy as np
import ray
from ray.experimental.serve.utils import get_custom_object_id, logger
class Query:
def __init__(self, request_body, result_object_id=None):
self.request_body = request_body
if result_object_id is None:
self.result_object_id = get_custom_object_id()
else:
self.result_object_id = result_object_id
class WorkIntent:
def __init__(self, work_object_id=None):
if work_object_id is None:
self.work_object_id = get_custom_object_id()
else:
self.work_object_id = work_object_id
class CentralizedQueues:
"""A router that routes request to available workers.
The router accepts each request via the `enqueue_request` method and
enqueues it. It also accepts workers' requests for work (called
work_intention in the code) via the `dequeue_request` method. The traffic
policy is used to match requests with their corresponding workers.
Behavior:
>>> # psuedo-code
>>> queue = CentralizedQueues()
>>> queue.enqueue_request('service-name', data)
# nothing happens, request is queued.
# returns result ObjectID, which will contains the final result
>>> queue.dequeue_request('backend-1')
# nothing happens, work intention is queued.
# return work ObjectID, which will contains the future request payload
>>> queue.link('service-name', 'backend-1')
# here the enqueue_requester is matched with worker, request
# data is put into work ObjectID, and the worker processes the request
# and store the result into result ObjectID
Traffic policy splits the traffic among different workers
probabilistically:
1. When all backends are ready to receive traffic, we will randomly
choose a backend based on the weights assigned by the traffic policy
dictionary.
2. When more than 1 but not all backends are ready, we will normalize the
weights of the ready backends to 1 and choose a backend via sampling.
3. When there is only 1 backend ready, we will only use that backend.
"""
def __init__(self):
# service_name -> request queue
self.queues = defaultdict(deque)
# service_name -> traffic_policy
self.traffic = defaultdict(dict)
# backend_name -> worker queue
self.workers = defaultdict(deque)
def enqueue_request(self, service, request_data):
query = Query(request_data)
self.queues[service].append(query)
self.flush()
return query.result_object_id.binary()
def dequeue_request(self, backend):
intention = WorkIntent()
self.workers[backend].append(intention)
self.flush()
return intention.work_object_id.binary()
def link(self, service, backend):
logger.debug("Link %s with %s", service, backend)
self.traffic[service][backend] = 1.0
self.flush()
def set_traffic(self, service, traffic_dict):
logger.debug("Setting traffic for service %s to %s", service,
traffic_dict)
self.traffic[service] = traffic_dict
self.flush()
def flush(self):
"""In the default case, flush calls ._flush.
When this class is a Ray actor, .flush can be scheduled as a remote
method invocation.
"""
self._flush()
def _get_available_backends(self, service):
backends_in_policy = set(self.traffic[service].keys())
available_workers = set((backend
for backend, queues in self.workers.items()
if len(queues) > 0))
return list(backends_in_policy.intersection(available_workers))
def _flush(self):
for service, queue in self.queues.items():
ready_backends = self._get_available_backends(service)
while len(queue) and len(ready_backends):
# Fast path, only one backend available.
if len(ready_backends) == 1:
backend = ready_backends[0]
request, work = (queue.popleft(),
self.workers[backend].popleft())
ray.worker.global_worker.put_object(
work.work_object_id, request)
# We have more than one backend available.
# We will roll the dice among the available backends.
else:
backend_weights = np.array([
self.traffic[service][backend_name]
for backend_name in ready_backends
])
# Normalize the weights to 1.
backend_weights /= backend_weights.sum()
chosen_backend = np.random.choice(
ready_backends, p=backend_weights).squeeze()
request, work = (
queue.popleft(),
self.workers[chosen_backend].popleft(),
)
ray.worker.global_worker.put_object(
work.work_object_id, request)
ready_backends = self._get_available_backends(service)
@ray.remote
class CentralizedQueuesActor(CentralizedQueues):
self_handle = None
def register_self_handle(self, handle_to_this_actor):
self.self_handle = handle_to_this_actor
def flush(self):
if self.self_handle:
self.self_handle._flush.remote()
else:
self._flush()
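
A minimal sketch of case 2 of the traffic policy described above, i.e. renormalizing the weights when only some of the backends in the policy have idle workers (the numbers are illustrative):

import numpy as np

traffic = {"v1": 0.2, "v2": 0.3, "v3": 0.5}  # full traffic policy
ready_backends = ["v1", "v3"]  # suppose only v1 and v3 have idle workers

# Keep only the ready backends' weights and renormalize them to sum to 1.
backend_weights = np.array([traffic[b] for b in ready_backends])
backend_weights /= backend_weights.sum()  # [0.2, 0.5] -> [0.286, 0.714]
chosen_backend = np.random.choice(ready_backends, p=backend_weights)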

View file

@@ -0,0 +1,125 @@
import asyncio
import json
import uvicorn
import ray
from ray.experimental.async_api import _async_init, as_future
from ray.experimental.serve.utils import BytesEncoder
from ray.experimental.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S
class JSONResponse:
"""ASGI compliant response class.
It is expected to be called in async context and pass along
`scope, receive, send` as in ASGI spec.
>>> await JSONResponse({"k": "v"})(scope, receive, send)
"""
def __init__(self, content=None, status_code=200):
"""Construct a JSON HTTP Response.
Args:
content (optional): Any JSON serializable object.
status_code (int, optional): Default status code is 200.
"""
self.body = self.render(content)
self.status_code = status_code
self.raw_headers = [[b"content-type", b"application/json"]]
def render(self, content):
if content is None:
return b""
if isinstance(content, bytes):
return content
return json.dumps(content, cls=BytesEncoder, indent=2).encode()
async def __call__(self, scope, receive, send):
await send({
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
})
await send({"type": "http.response.body", "body": self.body})
class HTTPProxy:
"""
This class should be instantiated and run by an ASGI server.
>>> import uvicorn
>>> uvicorn.run(HTTPProxy(kv_store_actor_handle, router_handle))
# blocks forever
"""
def __init__(self, kv_store_actor_handle, router_handle):
"""
Args:
kv_store_actor_handle (ray.actor.ActorHandle): handle to the routing
table actor. It will be used to populate the routing table. It
should implement `handle.list_service()`
router_handle (ray.actor.ActorHandle): actor handle to push requests
to. It should implement
`handle.enqueue_request.remote(endpoint, body)`
"""
assert ray.is_initialized()
self.admin_actor = kv_store_actor_handle
self.router = router_handle
self.route_table = dict()
async def route_checker(self, interval):
while True:
try:
self.route_table = await as_future(
self.admin_actor.list_service.remote())
except ray.exceptions.RayletError: # Gracefully handle termination
return
await asyncio.sleep(interval)
async def __call__(self, scope, receive, send):
# NOTE: This implements ASGI protocol specified in
# https://asgi.readthedocs.io/en/latest/specs/index.html
if scope["type"] == "lifespan":
await _async_init()
asyncio.ensure_future(
self.route_checker(interval=HTTP_ROUTER_CHECKER_INTERVAL_S))
return
current_path = scope["path"]
if current_path == "/":
await JSONResponse(self.route_table)(scope, receive, send)
elif current_path in self.route_table:
endpoint_name = self.route_table[current_path]
result_object_id_bytes = await as_future(
self.router.enqueue_request.remote(endpoint_name, scope))
result = await as_future(ray.ObjectID(result_object_id_bytes))
if isinstance(result, ray.exceptions.RayTaskError):
await JSONResponse({
"error": "internal error, please use python API to debug"
})(scope, receive, send)
else:
await JSONResponse({"result": result})(scope, receive, send)
else:
error_message = ("Path {} not found. "
"Please ping http://.../ for routing table"
).format(current_path)
await JSONResponse(
{
"error": error_message
}, status_code=404)(scope, receive, send)
@ray.remote
class HTTPActor:
def __init__(self, kv_store_actor_handle, router_handle):
self.app = HTTPProxy(kv_store_actor_handle, router_handle)
def run(self, host="0.0.0.0", port=8000):
uvicorn.run(self.app, host=host, port=port, lifespan="on")

View file

@@ -0,0 +1,96 @@
import traceback
import ray
class TaskRunner:
"""A simple class that runs a function.
The purpose of this class is to model what the most basic actor could be.
That is, a ray serve actor should implement the TaskRunner interface.
"""
def __init__(self, func_to_run):
self.func = func_to_run
def __call__(self, *args):
return self.func(*args)
def wrap_to_ray_error(callable_obj, *args):
"""Utility method that catch and seal exceptions in execution"""
try:
return callable_obj(*args)
except Exception:
traceback_str = ray.utils.format_error_message(traceback.format_exc())
return ray.exceptions.RayTaskError(str(callable_obj), traceback_str)
class RayServeMixin:
"""This mixin class adds the functionality to fetch from router queues.
Warning:
It assumes the main execution method is `__call__` of the user defined
class. This means that serve will call `your_instance.__call__` when
each request comes in. This behavior will be fixed in the future to
allow assigning arbitrary methods.
Example:
>>> # Use ray.remote decorator and RayServeMixin
>>> # to make MyClass servable.
>>> @ray.remote
class RayServeActor(RayServeMixin, MyClass):
pass
"""
_ray_serve_self_handle = None
_ray_serve_router_handle = None
_ray_serve_setup_completed = False
_ray_serve_dequeue_requester_name = None
def _ray_serve_setup(self, my_name, _ray_serve_router_handle):
self._ray_serve_dequeue_requester_name = my_name
self._ray_serve_router_handle = _ray_serve_router_handle
self._ray_serve_setup_completed = True
def _ray_serve_main_loop(self, my_handle):
assert self._ray_serve_setup_completed
self._ray_serve_self_handle = my_handle
work_token = ray.get(
self._ray_serve_router_handle.dequeue_request.remote(
self._ray_serve_dequeue_requester_name))
work_item = ray.get(ray.ObjectID(work_token))
# TODO(simon):
# __call__ should be able to take multiple *args and **kwargs.
result = wrap_to_ray_error(self.__call__, work_item.request_body)
result_object_id = work_item.result_object_id
ray.worker.global_worker.put_object(result_object_id, result)
# The worker finished one unit of work.
# It will now tail recursively schedule the main_loop again.
# TODO(simon): remove tail recursion, ask router to callback instead
self._ray_serve_self_handle._ray_serve_main_loop.remote(my_handle)
class TaskRunnerBackend(TaskRunner, RayServeMixin):
"""A simple function serving backend
Note that this is not yet an actor. To make it an actor:
>>> @ray.remote
class TaskRunnerActor(TaskRunnerBackend):
pass
Note:
This class is not used in the actual ray serve system. It exists
for documentation purposes.
"""
pass
@ray.remote
class TaskRunnerActor(TaskRunnerBackend):
pass

View file

@@ -0,0 +1,21 @@
import pytest
import ray
from ray.experimental import serve
@pytest.fixture(scope="session")
def serve_instance():
serve.init()
serve.global_state.wait_until_http_ready()
yield
@pytest.fixture(scope="session")
def ray_instance():
ray_already_initialized = ray.is_initialized()
if not ray_already_initialized:
ray.init(object_store_memory=int(1e8))
yield
if not ray_already_initialized:
ray.shutdown()

View file

@@ -0,0 +1,33 @@
import time
import requests
from flaky import flaky
import ray
from ray.experimental import serve
def delay_rerun(*_):
time.sleep(1)
return True
# flaky test because the routing table might not be populated
@flaky(rerun_filter=delay_rerun)
def test_e2e(serve_instance):
serve.create_endpoint("endpoint", "/api")
result = ray.get(
serve.global_state.kv_store_actor_handle.list_service.remote())
assert result == {"/api": "endpoint"}
assert requests.get("http://127.0.0.1:8000/").json() == result
def echo(i):
return i
serve.create_backend(echo, "echo:v1")
serve.link("endpoint", "echo:v1")
resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
assert resp["path"] == "/api"
assert resp["method"] == "GET"

View file

@@ -0,0 +1,72 @@
import ray
from ray.experimental.serve.queues import CentralizedQueues
def test_single_prod_cons_queue(serve_instance):
q = CentralizedQueues()
q.link("svc", "backend")
result_object_id = q.enqueue_request("svc", 1)
work_object_id = q.dequeue_request("backend")
got_work = ray.get(ray.ObjectID(work_object_id))
assert got_work.request_body == 1
ray.worker.global_worker.put_object(got_work.result_object_id, 2)
assert ray.get(ray.ObjectID(result_object_id)) == 2
def test_alter_backend(serve_instance):
q = CentralizedQueues()
result_object_id = q.enqueue_request("svc", 1)
work_object_id = q.dequeue_request("backend-1")
q.set_traffic("svc", {"backend-1": 1})
got_work = ray.get(ray.ObjectID(work_object_id))
assert got_work.request_body == 1
ray.worker.global_worker.put_object(got_work.result_object_id, 2)
assert ray.get(ray.ObjectID(result_object_id)) == 2
result_object_id = q.enqueue_request("svc", 1)
work_object_id = q.dequeue_request("backend-2")
q.set_traffic("svc", {"backend-2": 1})
got_work = ray.get(ray.ObjectID(work_object_id))
assert got_work.request_body == 1
ray.worker.global_worker.put_object(got_work.result_object_id, 2)
assert ray.get(ray.ObjectID(result_object_id)) == 2
def test_split_traffic(serve_instance):
q = CentralizedQueues()
q.enqueue_request("svc", 1)
q.enqueue_request("svc", 1)
q.set_traffic("svc", {})
work_object_id_1 = q.dequeue_request("backend-1")
work_object_id_2 = q.dequeue_request("backend-2")
q.set_traffic("svc", {"backend-1": 0.5, "backend-2": 0.5})
got_work = ray.get(
[ray.ObjectID(work_object_id_1),
ray.ObjectID(work_object_id_2)])
assert [g.request_body for g in got_work] == [1, 1]
def test_probabilities(serve_instance):
q = CentralizedQueues()
[q.enqueue_request("svc", 1) for i in range(100)]
work_object_id_1_s = [
ray.ObjectID(q.dequeue_request("backend-1")) for i in range(100)
]
work_object_id_2_s = [
ray.ObjectID(q.dequeue_request("backend-2")) for i in range(100)
]
q.set_traffic("svc", {"backend-1": 0.1, "backend-2": 0.9})
backend_1_ready_object_ids, _ = ray.wait(
work_object_id_1_s, num_returns=100, timeout=0.0)
backend_2_ready_object_ids, _ = ray.wait(
work_object_id_2_s, num_returns=100, timeout=0.0)
assert len(backend_1_ready_object_ids) < len(backend_2_ready_object_ids)

View file

@@ -0,0 +1,27 @@
from ray.experimental.serve.kv_store_service import (InMemoryKVStore,
RayInternalKVStore)
def test_default_in_memory_kv():
kv = InMemoryKVStore("")
kv.put("1", 2)
assert kv.get("1") == 2
kv.put("1", 3)
assert kv.get("1") == 3
assert kv.as_dict() == {"1": 3}
def test_ray_internal_kv(ray_instance):
kv = RayInternalKVStore("")
kv.put("1", 2)
assert kv.get("1") == 2
kv.put("1", 3)
assert kv.get("1") == 3
assert kv.as_dict() == {"1": 3}
kv = RayInternalKVStore("othernamespace")
kv.put("1", 2)
assert kv.get("1") == 2
kv.put("1", 3)
assert kv.get("1") == 3
assert kv.as_dict() == {"1": 3}

View file

@@ -0,0 +1,80 @@
import ray
from ray.experimental.serve.queues import CentralizedQueuesActor
from ray.experimental.serve.task_runner import (
RayServeMixin,
TaskRunner,
TaskRunnerActor,
wrap_to_ray_error,
)
def test_runner_basic():
def echo(i):
return i
r = TaskRunner(echo)
assert r(1) == 1
def test_runner_wraps_error():
def echo(i):
return i
assert wrap_to_ray_error(echo, 2) == 2
def error(_):
return 1 / 0
assert isinstance(wrap_to_ray_error(error, 1), ray.exceptions.RayTaskError)
def test_runner_actor(serve_instance):
q = CentralizedQueuesActor.remote()
def echo(i):
return i
CONSUMER_NAME = "runner"
PRODUCER_NAME = "prod"
runner = TaskRunnerActor.remote(echo)
runner._ray_serve_setup.remote(CONSUMER_NAME, q)
runner._ray_serve_main_loop.remote(runner)
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
for query in [333, 444, 555]:
result_token = ray.ObjectID(
ray.get(q.enqueue_request.remote(PRODUCER_NAME, query)))
assert ray.get(result_token) == query
def test_ray_serve_mixin(serve_instance):
q = CentralizedQueuesActor.remote()
CONSUMER_NAME = "runner-cls"
PRODUCER_NAME = "prod-cls"
class MyAdder:
def __init__(self, inc):
self.increment = inc
def __call__(self, context):
return context + self.increment
@ray.remote
class CustomActor(MyAdder, RayServeMixin):
pass
runner = CustomActor.remote(3)
runner._ray_serve_setup.remote(CONSUMER_NAME, q)
runner._ray_serve_main_loop.remote(runner)
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
for query in [333, 444, 555]:
result_token = ray.ObjectID(
ray.get(q.enqueue_request.remote(PRODUCER_NAME, query)))
assert ray.get(result_token) == query + 3

View file

@@ -0,0 +1,9 @@
import json
from ray.experimental.serve.utils import BytesEncoder
def test_bytes_encoder():
data_before = {"inp": {"nest": b"bytes"}}
data_after = {"inp": {"nest": "bytes"}}
assert json.loads(json.dumps(data_before, cls=BytesEncoder)) == data_after

View file

@@ -0,0 +1,51 @@
import json
import logging
from pygments import formatters, highlight, lexers
import ray
def _get_logger():
logger = logging.getLogger("ray.serve")
# TODO(simon): Make logging level configurable.
logger.setLevel(logging.INFO)
return logger
logger = _get_logger()
class BytesEncoder(json.JSONEncoder):
"""Allow bytes to be part of the JSON document.
BytesEncoder will walk the JSON tree and decode bytes with the utf-8 codec.
Note that only values are handled; JSON keys must already be strings.
Example:
>>> json.dumps({'a': b'c'}, cls=BytesEncoder)
'{"a": "c"}'
"""
def default(self, o): # pylint: disable=E0202
if isinstance(o, bytes):
return o.decode("utf-8")
return super().default(o)
def get_custom_object_id():
"""Use ray worker API to get computed ObjectID"""
worker = ray.worker.global_worker
object_id = ray._raylet.compute_put_id(worker.current_task_id,
worker.task_context.put_index)
worker.task_context.put_index += 1
return object_id
def pformat_color_json(d):
"""Use pygments to pretty format and colroize dictionary"""
formatted_json = json.dumps(d, sort_keys=True, indent=4)
colorful_json = highlight(formatted_json, lexers.JsonLexer(),
formatters.TerminalFormatter())
return colorful_json

View file

@@ -68,6 +68,7 @@ extras = {
],
"debug": ["psutil", "setproctitle", "py-spy"],
"dashboard": ["psutil", "aiohttp"],
"serve": ["uvicorn", "pygments", "werkzeug"],
}
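
With this extras entry, the serve-specific dependencies can be installed through the standard setuptools extras syntax, e.g. `pip install ray[serve]` for a published wheel.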