[Serve] Fix router-worker communication (#5961)

* Halfway there, needs the strict queuing fix

* Fix scale down, use callback

* Cleanup

* Address comments

* Comment, nit

* Fix docstring
Authored by Simon Mo on 2019-11-04 11:29:21 -08:00, committed via GitHub
parent 8485304e83
commit c23eae5998
5 changed files with 83 additions and 77 deletions
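In short, this change moves router-worker communication from pull to push. Before, each replica ran `_ray_serve_main_loop`, blocking on a work-token ObjectID that the router fulfilled with `put_object`, and tail-recursively rescheduled itself. Now the replica registers its own actor handle with the router via `dequeue_request`, and the router pushes matched requests by calling `_ray_serve_call` on that handle. A minimal runnable sketch of the new protocol; `Router` and `Replica` are simplified stand-ins for `CentralizedQueues` and `RayServeMixin`, not code from this diff:

import ray
from collections import deque

@ray.remote
class Router:
    def __init__(self):
        self.requests = deque()
        self.replicas = deque()

    def enqueue_request(self, request):
        self.requests.append(request)
        self.flush()

    def dequeue_request(self, replica_handle):
        # The replica hands over its own actor handle (a "work intent").
        self.replicas.append(replica_handle)
        self.flush()

    def flush(self):
        # Push: call the replica directly instead of fulfilling a token.
        while self.requests and self.replicas:
            replica = self.replicas.popleft()
            replica.handle_request.remote(self.requests.popleft())

@ray.remote
class Replica:
    def setup(self, router_handle, self_handle):
        self.router = router_handle
        self.self_handle = self_handle

    def fetch(self):
        # Register for one unit of work; the router will call back.
        self.router.dequeue_request.remote(self.self_handle)

    def handle_request(self, request):
        print("processed", request)
        self.fetch()  # re-register once the work is done

ray.init()
router = Router.remote()
replica = Replica.remote()
ray.get(replica.setup.remote(router, replica))
replica.fetch.remote()
router.enqueue_request.remote("hello")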


@@ -164,7 +164,7 @@ def _start_replica(backend_tag):
     ray.get(
         runner_handle._ray_serve_setup.remote(
             backend_tag, global_state.init_or_get_router(), runner_handle))
-    runner_handle._ray_serve_main_loop.remote()
+    runner_handle._ray_serve_fetch.remote()

     # Register the worker in config tables as well as metric monitor
     global_state.backend_table.add_replica(backend_tag, replica_tag)
@@ -181,11 +181,20 @@ def _remove_replica(backend_tag):
     replica_tag = global_state.backend_table.remove_replica(backend_tag)
     [replica_handle] = ray.get(
         global_state.actor_nursery_handle.get_handle.remote(replica_tag))
-    global_state.init_or_get_metric_monitor().remove_target.remote(
-        replica_handle)
+    # Remove the replica from the metric monitor.
+    ray.get(global_state.init_or_get_metric_monitor().remove_target.remote(
+        replica_handle))
+
+    # Remove the replica from the actor nursery.
+    ray.get(
+        global_state.actor_nursery_handle.remove_handle.remote(replica_tag))
+
+    # Remove the replica from the router.
+    # This will also destroy the actor handle.
+    ray.get(global_state.init_or_get_router()
+            .remove_and_destroy_replica.remote(backend_tag, replica_handle))


 @_ensure_connected
 def scale(backend_tag, num_replicas):
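A note on the `ray.get` wrappers introduced above: the three removal steps hit three different actors (metric monitor, actor nursery, router), and bare `.remote()` calls are only ordered per-actor, not across actors. Wrapping each call in `ray.get` turns it into a barrier, so the replica is fully deregistered before the router finally calls `__ray_terminate__` on it. A minimal runnable illustration of the pattern (the `Step` actor is hypothetical, purely to show the barrier):

import ray

@ray.remote
class Step:
    def run(self, name):
        # Stand-in for one deregistration step.
        return name

ray.init()
a, b, c = Step.remote(), Step.remote(), Step.remote()
# Each ray.get blocks until the task has run, so the three steps
# execute strictly in order even though they live on different actors.
ray.get(a.run.remote("remove from metric monitor"))
ray.get(b.run.remote("remove from actor nursery"))
ray.get(c.run.remote("remove from router, then terminate"))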


@@ -23,11 +23,8 @@ class Query:


 class WorkIntent:
-    def __init__(self, work_object_id=None):
-        if work_object_id is None:
-            self.work_object_id = ray.ObjectID.from_random()
-        else:
-            self.work_object_id = work_object_id
+    def __init__(self, replica_handle):
+        self.replica_handle = replica_handle


 class CentralizedQueues:
@@ -45,15 +42,15 @@ class CentralizedQueues:
             "service-name", request_args, request_kwargs, request_context)
     # nothing happens, request is queued.
     # returns result ObjectID, which will contain the final result
-    >>> queue.dequeue_request('backend-1')
+    >>> queue.dequeue_request('backend-1', replica_handle)
     # nothing happens, work intention is queued.
     # return work ObjectID, which will contain the future request payload
     >>> queue.link('service-name', 'backend-1')
-    # here the enqueue_requester is matched with worker, request
-    # data is put into work ObjectID, and the worker processes the request
+    # here the enqueue_requester is matched with replica, request
+    # data is put into work ObjectID, and the replica processes the request
     # and stores the result into result ObjectID

-    Traffic policy splits the traffic among different workers
+    Traffic policy splits the traffic among different replicas
     probabilistically:

     1. When all backends are ready to receive traffic, we will randomly
@@ -98,11 +95,23 @@ class CentralizedQueues:
         self.flush()
         return query.result_object_id.binary()

-    def dequeue_request(self, backend):
-        intention = WorkIntent()
+    def dequeue_request(self, backend, replica_handle):
+        intention = WorkIntent(replica_handle)
         self.workers[backend].append(intention)
         self.flush()
-        return intention.work_object_id.binary()
+
+    def remove_and_destroy_replica(self, backend, replica_handle):
+        # NOTE: this function scales as O(#replicas for the backend)
+        new_queue = deque()
+        target_id = replica_handle._actor_id
+        for work_intent in self.workers[backend]:
+            if work_intent.replica_handle._actor_id != target_id:
+                new_queue.append(work_intent)
+        self.workers[backend] = new_queue
+        replica_handle.__ray_terminate__.remote()

     def link(self, service, backend):
         logger.debug("Link %s with %s", service, backend)
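The NOTE in `remove_and_destroy_replica` flags its cost: removal rebuilds the backend's entire intent queue, keeping every entry whose actor id differs from the removed handle's. The same filter-and-rebuild pattern in isolation (plain Python, no Ray), to make the linear scan explicit; `remove_intents_for` and `id_of` are illustrative names, not part of this diff:

from collections import deque

def remove_intents_for(queue, target_id, id_of):
    # O(len(queue)): keep every entry not owned by the removed replica.
    return deque(item for item in queue if id_of(item) != target_id)

pending = deque([("work", "replica-1"), ("work", "replica-2"),
                 ("work", "replica-1")])
pending = remove_intents_for(pending, "replica-1", lambda item: item[1])
assert list(pending) == [("work", "replica-2")]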
@@ -159,8 +168,7 @@ class CentralizedQueues:
                 buffer_queue.popleft(),
                 work_queue.popleft(),
             )
-            ray.worker.global_worker.put_object(
-                request, work.work_object_id)
+            work.replica_handle._ray_serve_call.remote(request)


 @ray.remote
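For context, the hunk above shows only the tail of the router's matching loop. Below is a hedged reconstruction of the surrounding `flush` logic, inferred from the visible fragment; the function name, `while` condition, and tuple assignment are assumptions, and only the two `popleft()` calls and the final push are lines from this diff:

from collections import deque

def flush_one_backend(buffer_queue: deque, work_queue: deque):
    # Pair each buffered request with a registered work intent and
    # push the request straight to that replica's actor handle.
    while len(buffer_queue) and len(work_queue):
        request, work = (
            buffer_queue.popleft(),
            work_queue.popleft(),
        )
        work.replica_handle._ray_serve_call.remote(request)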


@@ -50,7 +50,7 @@ class RayServeMixin:
     _ray_serve_self_handle = None
     _ray_serve_router_handle = None
     _ray_serve_setup_completed = False
-    _ray_serve_dequeue_requestr_name = None
+    _ray_serve_dequeue_requester_name = None

     # Work token can be unfulfilled from last iteration.
     # This cache will be used to determine whether or not we should
@@ -66,7 +66,7 @@ class RayServeMixin:
         latency_lst = self._serve_metric_latency_list[:]
         self._serve_metric_latency_list = []

-        my_name = self._ray_serve_dequeue_requestr_name
+        my_name = self._ray_serve_dequeue_requester_name

         return {
             "{}_error_counter".format(my_name): {
@@ -80,32 +80,20 @@ class RayServeMixin:
         }

     def _ray_serve_setup(self, my_name, router_handle, my_handle):
-        self._ray_serve_dequeue_requestr_name = my_name
+        self._ray_serve_dequeue_requester_name = my_name
         self._ray_serve_router_handle = router_handle
         self._ray_serve_self_handle = my_handle
         self._ray_serve_setup_completed = True

-    def _ray_serve_main_loop(self):
+    def _ray_serve_fetch(self):
         assert self._ray_serve_setup_completed
-        # Only retrieve the next task if we have completed previous task.
-        if self._ray_serve_cached_work_token is None:
-            work_token = ray.get(
-                self._ray_serve_router_handle.dequeue_request.remote(
-                    self._ray_serve_dequeue_requestr_name))
-        else:
-            work_token = self._ray_serve_cached_work_token
+        self._ray_serve_router_handle.dequeue_request.remote(
+            self._ray_serve_dequeue_requester_name,
+            self._ray_serve_self_handle)

-        work_token_id = ray.ObjectID(work_token)
-        ready, not_ready = ray.wait(
-            [work_token_id], num_returns=1, timeout=0.5)
-        if len(ready) == 1:
-            work_item = ray.get(work_token_id)
-            self._ray_serve_cached_work_token = None
-        else:
-            self._ray_serve_cached_work_token = work_token
-            self._ray_serve_self_handle._ray_serve_main_loop.remote()
-            return
+    def _ray_serve_call(self, request):
+        work_item = request
         if work_item.request_context == TaskContext.Web:
             serve_context.web = True
@@ -132,11 +120,8 @@ class RayServeMixin:
         self._serve_metric_latency_list.append(time.time() - start_timestamp)
         serve_context.web = False

-        # The worker finished one unit of work.
-        # It will now tail recursively schedule the main_loop again.
-        # TODO(simon): remove tail recursion, ask router to callback instead
-        self._ray_serve_self_handle._ray_serve_main_loop.remote()
+        self._ray_serve_fetch()


 class TaskRunnerBackend(TaskRunner, RayServeMixin):


@@ -1,14 +1,32 @@
+import pytest
+
 import ray
 from ray.experimental.serve.queues import CentralizedQueues


-def test_single_prod_cons_queue(serve_instance):
+@pytest.fixture(scope="session")
+def task_runner_mock_actor():
+    @ray.remote
+    class TaskRunnerMock:
+        def __init__(self):
+            self.result = None
+
+        def _ray_serve_call(self, request_item):
+            self.result = request_item
+
+        def get_recent_call(self):
+            return self.result
+
+    actor = TaskRunnerMock.remote()
+    yield actor
+
+
+def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
     q = CentralizedQueues()
     q.link("svc", "backend")

     result_object_id = q.enqueue_request("svc", 1, "kwargs", None)
-    work_object_id = q.dequeue_request("backend")
-    got_work = ray.get(ray.ObjectID(work_object_id))
+    q.dequeue_request("backend", task_runner_mock_actor)
+    got_work = ray.get(task_runner_mock_actor.get_recent_call.remote())
     assert got_work.request_args == 1
     assert got_work.request_kwargs == "kwargs"
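These tests lean on a Ray ordering guarantee: tasks submitted to one actor from the same caller execute in submission order. The driver's `flush` (run inside `dequeue_request`) pushes `_ray_serve_call` to the mock actor first, so by the time `get_recent_call` runs, the request has already been stored. The guarantee in isolation, as a minimal runnable sketch (the `Box` actor is hypothetical):

import ray

@ray.remote
class Box:
    def __init__(self):
        self.value = None

    def put(self, value):
        self.value = value

    def get(self):
        return self.value

ray.init()
box = Box.remote()
box.put.remote(42)  # fire-and-forget, submitted first
# Same caller, same actor: get() is guaranteed to run after put().
assert ray.get(box.get.remote()) == 42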
@@ -16,27 +34,27 @@ def test_single_prod_cons_queue(serve_instance):
     assert ray.get(ray.ObjectID(result_object_id)) == 2


-def test_alter_backend(serve_instance):
+def test_alter_backend(serve_instance, task_runner_mock_actor):
     q = CentralizedQueues()

     q.set_traffic("svc", {"backend-1": 1})
     result_object_id = q.enqueue_request("svc", 1, "kwargs", None)
-    work_object_id = q.dequeue_request("backend-1")
-    got_work = ray.get(ray.ObjectID(work_object_id))
+    q.dequeue_request("backend-1", task_runner_mock_actor)
+    got_work = ray.get(task_runner_mock_actor.get_recent_call.remote())
     assert got_work.request_args == 1
     ray.worker.global_worker.put_object(2, got_work.result_object_id)
     assert ray.get(ray.ObjectID(result_object_id)) == 2

     q.set_traffic("svc", {"backend-2": 1})
     result_object_id = q.enqueue_request("svc", 1, "kwargs", None)
-    work_object_id = q.dequeue_request("backend-2")
-    got_work = ray.get(ray.ObjectID(work_object_id))
+    q.dequeue_request("backend-2", task_runner_mock_actor)
+    got_work = ray.get(task_runner_mock_actor.get_recent_call.remote())
     assert got_work.request_args == 1
     ray.worker.global_worker.put_object(2, got_work.result_object_id)
     assert ray.get(ray.ObjectID(result_object_id)) == 2


-def test_split_traffic(serve_instance):
+def test_split_traffic(serve_instance, task_runner_mock_actor):
     q = CentralizedQueues()

     q.set_traffic("svc", {"backend-1": 0.5, "backend-2": 0.5})
@@ -44,31 +62,17 @@ def test_split_traffic(serve_instance):
     # single queue is 0.5^20 ~ 1e-6
     for _ in range(20):
         q.enqueue_request("svc", 1, "kwargs", None)
-    work_object_id_1 = q.dequeue_request("backend-1")
-    work_object_id_2 = q.dequeue_request("backend-2")
+    q.dequeue_request("backend-1", task_runner_mock_actor)
+    result_one = ray.get(task_runner_mock_actor.get_recent_call.remote())
+    q.dequeue_request("backend-2", task_runner_mock_actor)
+    result_two = ray.get(task_runner_mock_actor.get_recent_call.remote())

-    got_work = ray.get(
-        [ray.ObjectID(work_object_id_1),
-         ray.ObjectID(work_object_id_2)])
+    got_work = [result_one, result_two]
     assert [g.request_args for g in got_work] == [1, 1]


-def test_probabilities(serve_instance):
+def test_queue_remove_replicas(serve_instance, task_runner_mock_actor):
     q = CentralizedQueues()
-    [q.enqueue_request("svc", 1, "kwargs", None) for i in range(100)]
-
-    work_object_id_1_s = [
-        ray.ObjectID(q.dequeue_request("backend-1")) for i in range(100)
-    ]
-    work_object_id_2_s = [
-        ray.ObjectID(q.dequeue_request("backend-2")) for i in range(100)
-    ]
-
-    q.set_traffic("svc", {"backend-1": 0.1, "backend-2": 0.9})
-
-    backend_1_ready_object_ids, _ = ray.wait(
-        work_object_id_1_s, num_returns=100, timeout=0.0)
-    backend_2_ready_object_ids, _ = ray.wait(
-        work_object_id_2_s, num_returns=100, timeout=0.0)
-    assert len(backend_1_ready_object_ids) < len(backend_2_ready_object_ids)
+    q.dequeue_request("backend", task_runner_mock_actor)
+    q.remove_and_destroy_replica("backend", task_runner_mock_actor)
+    assert len(q.workers["backend"]) == 0


@@ -32,7 +32,7 @@ def test_runner_actor(serve_instance):
     runner = TaskRunnerActor.remote(echo)

     runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
-    runner._ray_serve_main_loop.remote()
+    runner._ray_serve_fetch.remote()
     q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -67,7 +67,7 @@ def test_ray_serve_mixin(serve_instance):
     runner = CustomActor.remote(3)

     runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
-    runner._ray_serve_main_loop.remote()
+    runner._ray_serve_fetch.remote()
     q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -95,7 +95,7 @@ def test_task_runner_check_context(serve_instance):
     runner = TaskRunnerActor.remote(echo)

     runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
-    runner._ray_serve_main_loop.remote()
+    runner._ray_serve_fetch.remote()
     q.link.remote(PRODUCER_NAME, CONSUMER_NAME)

     result_token = ray.ObjectID(