[serve] Call serve.init() before initializing backends (#7922)

This commit is contained in:
Edward Oakes 2020-04-07 17:22:52 -05:00 committed by GitHub
parent 1be87c7fbb
commit 85481d635d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 16 deletions

View file

@ -244,6 +244,8 @@ def create_backend(func_or_class,
class CustomActor(RayServeMixin, func_or_class):
@wraps(func_or_class.__init__)
def __init__(self, *args, **kwargs):
# Initialize serve so it can be used in backends.
init()
super().__init__(*args, **kwargs)
arg_list = actor_init_args

View file

@ -16,7 +16,7 @@ class ServeMaster:
"""Initialize and store all actor handles.
Note:
This actor is necessary because ray will destory actors when the
This actor is necessary because ray will destroy actors when the
original actor handle goes out of scope (when the driver exits). Therefore
we need to initialize and store actor handles in a separate actor.
"""
@ -72,7 +72,13 @@ class ServeMaster:
def _list_replicas(self, backend_tag):
return self.backend_table.list_replicas(backend_tag)
def scale_replicas(self, backend_tag, num_replicas):
async def scale_replicas(self, backend_tag, num_replicas):
"""Scale the given backend to the number of replicas.
This requires the master actor to be an async actor because we wait
synchronously for backends to start up and they may make calls into
the master actor while initializing (e.g., by calling get_handle()).
"""
assert (backend_tag in self.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
assert num_replicas >= 0, ("Number of replicas must be"
@ -83,12 +89,12 @@ class ServeMaster:
if delta_num_replicas > 0:
for _ in range(delta_num_replicas):
self._start_backend_replica(backend_tag)
await self._start_backend_replica(backend_tag)
elif delta_num_replicas < 0:
for _ in range(-delta_num_replicas):
self._remove_backend_replica(backend_tag)
def _start_backend_replica(self, backend_tag):
async def _start_backend_replica(self, backend_tag):
assert (backend_tag in self.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
@ -105,10 +111,10 @@ class ServeMaster:
self.tag_to_actor_handles[replica_tag] = runner_handle
# Set up the worker.
ray.get(
runner_handle._ray_serve_setup.remote(backend_tag,
self.get_router()[0],
runner_handle))
await runner_handle._ray_serve_setup.remote(backend_tag,
self.get_router()[0],
runner_handle)
ray.get(runner_handle._ray_serve_fetch.remote())
# Register the worker in config tables and metric monitor.
@ -134,7 +140,7 @@ class ServeMaster:
# This will also destroy the actor handle.
[router] = self.get_router()
ray.get(
router.remove_and_destory_replica.remote(backend_tag,
router.remove_and_destroy_replica.remote(backend_tag,
replica_handle))
def get_all_handles(self):
@ -175,7 +181,8 @@ class ServeMaster:
self.route_table.list_service(
include_methods=True, include_headless=False)))
def create_backend(self, backend_tag, creator, backend_config, arg_list):
async def create_backend(self, backend_tag, creator, backend_config,
arg_list):
backend_config_dict = dict(backend_config)
# Save creator which starts replicas.
@ -192,9 +199,10 @@ class ServeMaster:
[router] = self.get_router()
ray.get(
router.set_backend_config.remote(backend_tag, backend_config_dict))
self.scale_replicas(backend_tag, backend_config_dict["num_replicas"])
await self.scale_replicas(backend_tag,
backend_config_dict["num_replicas"])
def set_backend_config(self, backend_tag, backend_config):
async def set_backend_config(self, backend_tag, backend_config):
assert (backend_tag in self.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(backend_config,
@ -222,10 +230,11 @@ class ServeMaster:
for k in BackendConfig.restart_on_change_fields)
if need_to_restart_replicas:
# Kill all the replicas for restarting with new configurations.
self.scale_replicas(backend_tag, 0)
await self.scale_replicas(backend_tag, 0)
# Scale the replicas with the new configuration.
self.scale_replicas(backend_tag, backend_config_dict["num_replicas"])
await self.scale_replicas(backend_tag,
backend_config_dict["num_replicas"])
def get_backend_config(self, backend_tag):
assert (backend_tag in self.backend_table.list_backends()

View file

@ -203,7 +203,7 @@ class CentralizedQueues:
await self.worker_queues[backend].put(replica_handle)
await self.flush()
async def remove_and_destory_replica(self, backend, replica_handle):
async def remove_and_destroy_replica(self, backend, replica_handle):
# We need this lock because we modify worker_queue here.
async with self.flush_lock:
old_queue = self.worker_queues[backend]

View file

@ -0,0 +1,29 @@
import ray
from ray import serve
import requests
def test_handle_in_endpoint(serve_instance):
serve.init()
class Endpoint1:
def __call__(self, flask_request):
return "hello"
class Endpoint2:
def __init__(self):
self.handle = serve.get_handle("endpoint1", missing_ok=True)
def __call__(self):
return ray.get(self.handle.remote())
serve.create_endpoint("endpoint1", "/endpoint1", methods=["GET", "POST"])
serve.create_backend(Endpoint1, "endpoint1:v0")
serve.link("endpoint1", "endpoint1:v0")
serve.create_endpoint("endpoint2", "/endpoint2", methods=["GET", "POST"])
serve.create_backend(Endpoint2, "endpoint2:v0")
serve.link("endpoint2", "endpoint2:v0")
assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"

View file

@ -188,5 +188,5 @@ async def test_queue_remove_replicas(serve_instance):
temp_actor = make_task_runner_mock()
q = RandomPolicyQueue()
await q.dequeue_request("backend", temp_actor)
await q.remove_and_destory_replica("backend", temp_actor)
await q.remove_and_destroy_replica("backend", temp_actor)
assert q.worker_queues["backend"].qsize() == 0