mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[serve] Call serve.init() before initializing backends (#7922)
This commit is contained in:
parent
1be87c7fbb
commit
85481d635d
5 changed files with 56 additions and 16 deletions
|
@ -244,6 +244,8 @@ def create_backend(func_or_class,
|
|||
class CustomActor(RayServeMixin, func_or_class):
|
||||
@wraps(func_or_class.__init__)
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Initialize serve so it can be used in backends.
|
||||
init()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
arg_list = actor_init_args
|
||||
|
|
|
@ -16,7 +16,7 @@ class ServeMaster:
|
|||
"""Initialize and store all actor handles.
|
||||
|
||||
Note:
|
||||
This actor is necessary because ray will destory actors when the
|
||||
This actor is necessary because ray will destroy actors when the
|
||||
original actor handle goes out of scope (when the driver exits). Therefore
|
||||
we need to initialize and store actor handles in a separate actor.
|
||||
"""
|
||||
|
@ -72,7 +72,13 @@ class ServeMaster:
|
|||
def _list_replicas(self, backend_tag):
|
||||
return self.backend_table.list_replicas(backend_tag)
|
||||
|
||||
def scale_replicas(self, backend_tag, num_replicas):
|
||||
async def scale_replicas(self, backend_tag, num_replicas):
|
||||
"""Scale the given backend to the number of replicas.
|
||||
|
||||
This requires the master actor to be an async actor because we wait
|
||||
synchronously for backends to start up and they may make calls into
|
||||
the master actor while initializing (e.g., by calling get_handle()).
|
||||
"""
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert num_replicas >= 0, ("Number of replicas must be"
|
||||
|
@ -83,12 +89,12 @@ class ServeMaster:
|
|||
|
||||
if delta_num_replicas > 0:
|
||||
for _ in range(delta_num_replicas):
|
||||
self._start_backend_replica(backend_tag)
|
||||
await self._start_backend_replica(backend_tag)
|
||||
elif delta_num_replicas < 0:
|
||||
for _ in range(-delta_num_replicas):
|
||||
self._remove_backend_replica(backend_tag)
|
||||
|
||||
def _start_backend_replica(self, backend_tag):
|
||||
async def _start_backend_replica(self, backend_tag):
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
|
||||
|
@ -105,10 +111,10 @@ class ServeMaster:
|
|||
self.tag_to_actor_handles[replica_tag] = runner_handle
|
||||
|
||||
# Set up the worker.
|
||||
ray.get(
|
||||
runner_handle._ray_serve_setup.remote(backend_tag,
|
||||
|
||||
await runner_handle._ray_serve_setup.remote(backend_tag,
|
||||
self.get_router()[0],
|
||||
runner_handle))
|
||||
runner_handle)
|
||||
ray.get(runner_handle._ray_serve_fetch.remote())
|
||||
|
||||
# Register the worker in config tables and metric monitor.
|
||||
|
@ -134,7 +140,7 @@ class ServeMaster:
|
|||
# This will also destroy the actor handle.
|
||||
[router] = self.get_router()
|
||||
ray.get(
|
||||
router.remove_and_destory_replica.remote(backend_tag,
|
||||
router.remove_and_destroy_replica.remote(backend_tag,
|
||||
replica_handle))
|
||||
|
||||
def get_all_handles(self):
|
||||
|
@ -175,7 +181,8 @@ class ServeMaster:
|
|||
self.route_table.list_service(
|
||||
include_methods=True, include_headless=False)))
|
||||
|
||||
def create_backend(self, backend_tag, creator, backend_config, arg_list):
|
||||
async def create_backend(self, backend_tag, creator, backend_config,
|
||||
arg_list):
|
||||
backend_config_dict = dict(backend_config)
|
||||
|
||||
# Save creator which starts replicas.
|
||||
|
@ -192,9 +199,10 @@ class ServeMaster:
|
|||
[router] = self.get_router()
|
||||
ray.get(
|
||||
router.set_backend_config.remote(backend_tag, backend_config_dict))
|
||||
self.scale_replicas(backend_tag, backend_config_dict["num_replicas"])
|
||||
await self.scale_replicas(backend_tag,
|
||||
backend_config_dict["num_replicas"])
|
||||
|
||||
def set_backend_config(self, backend_tag, backend_config):
|
||||
async def set_backend_config(self, backend_tag, backend_config):
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert isinstance(backend_config,
|
||||
|
@ -222,10 +230,11 @@ class ServeMaster:
|
|||
for k in BackendConfig.restart_on_change_fields)
|
||||
if need_to_restart_replicas:
|
||||
# Kill all the replicas for restarting with new configurations.
|
||||
self.scale_replicas(backend_tag, 0)
|
||||
await self.scale_replicas(backend_tag, 0)
|
||||
|
||||
# Scale the replicas with the new configuration.
|
||||
self.scale_replicas(backend_tag, backend_config_dict["num_replicas"])
|
||||
await self.scale_replicas(backend_tag,
|
||||
backend_config_dict["num_replicas"])
|
||||
|
||||
def get_backend_config(self, backend_tag):
|
||||
assert (backend_tag in self.backend_table.list_backends()
|
||||
|
|
|
@ -203,7 +203,7 @@ class CentralizedQueues:
|
|||
await self.worker_queues[backend].put(replica_handle)
|
||||
await self.flush()
|
||||
|
||||
async def remove_and_destory_replica(self, backend, replica_handle):
|
||||
async def remove_and_destroy_replica(self, backend, replica_handle):
|
||||
# We need this lock because we modify worker_queue here.
|
||||
async with self.flush_lock:
|
||||
old_queue = self.worker_queues[backend]
|
||||
|
|
29
python/ray/serve/tests/test_handle.py
Normal file
29
python/ray/serve/tests/test_handle.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import ray
|
||||
from ray import serve
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_handle_in_endpoint(serve_instance):
    """A backend should be able to fetch a handle to another endpoint.

    Endpoint2's backend resolves a handle to endpoint1 inside __init__,
    which only works because serve.init() is called in backend workers
    before the user class is constructed.
    """
    serve.init()

    class Endpoint1:
        def __call__(self, flask_request):
            return "hello"

    class Endpoint2:
        def __init__(self):
            # missing_ok=True: endpoint1's backend may not be fully wired
            # up yet when this worker starts.
            self.handle = serve.get_handle("endpoint1", missing_ok=True)

        def __call__(self):
            return ray.get(self.handle.remote())

    # Create each endpoint, register its backend, and link them together.
    endpoints = (
        ("endpoint1", "/endpoint1", "endpoint1:v0", Endpoint1),
        ("endpoint2", "/endpoint2", "endpoint2:v0", Endpoint2),
    )
    for endpoint_name, route, backend_tag, backend_cls in endpoints:
        serve.create_endpoint(endpoint_name, route, methods=["GET", "POST"])
        serve.create_backend(backend_cls, backend_tag)
        serve.link(endpoint_name, backend_tag)

    # Hitting endpoint2 over HTTP should transparently proxy to endpoint1.
    assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello"
|
|
@ -188,5 +188,5 @@ async def test_queue_remove_replicas(serve_instance):
|
|||
temp_actor = make_task_runner_mock()
|
||||
q = RandomPolicyQueue()
|
||||
await q.dequeue_request("backend", temp_actor)
|
||||
await q.remove_and_destory_replica("backend", temp_actor)
|
||||
await q.remove_and_destroy_replica("backend", temp_actor)
|
||||
assert q.worker_queues["backend"].qsize() == 0
|
||||
|
|
Loading…
Add table
Reference in a new issue