[serve] Add atomic delete (#23195)

This commit is contained in:
shrekris-anyscale 2022-03-16 14:13:10 -07:00 committed by GitHub
parent 2bcbe41d54
commit 84b3de6825
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 18 deletions

View file

@ -20,6 +20,7 @@ from typing import (
Type,
Union,
List,
Iterable,
overload,
)
@ -393,9 +394,11 @@ class Client:
self.log_deployment_ready(name, version, url, tags[i])
@_ensure_connected
def delete_deployment(self, name: str) -> None:
    """Delete one deployment by name.

    Sends the delete request to the controller, blocks on the remote call
    with ``ray.get``, and then calls ``self._wait_for_deployment_deleted``
    for the same name.

    Args:
        name: Name of the deployment to delete.
    """
    deletion_ref = self._controller.delete_deployment.remote(name)
    ray.get(deletion_ref)
    self._wait_for_deployment_deleted(name)
def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None:
    """Delete several deployments with a single controller call.

    Args:
        names: Names of the deployments to delete. Any iterable is
            accepted; it is copied into a list before use.
        blocking: When true, call ``self._wait_for_deployment_deleted``
            for every name after the controller call has returned. When
            false, return as soon as the controller call has returned.
    """
    # ``names`` is consumed twice (sent to the controller, then iterated for
    # the waits). Snapshot it so one-shot iterables such as generators are
    # neither exhausted before the wait loop nor shipped to the remote call
    # in a form that cannot be serialized.
    names = list(names)
    ray.get(self._controller.delete_deployments.remote(names))
    if blocking:
        for name in names:
            self._wait_for_deployment_deleted(name)
@_ensure_connected
def get_deployment_info(self, name: str) -> Tuple[DeploymentInfo, str]:
@ -1229,7 +1232,7 @@ class Deployment:
@PublicAPI
def delete(self):
    """Delete this deployment.

    Returns:
        Whatever the global client's ``delete_deployments`` returns for a
        one-element list holding this deployment's name.
    """
    # A single return through the batch API: two back-to-back returns would
    # leave the second one unreachable.
    return internal_get_global_client().delete_deployments([self._name])
@PublicAPI
def get_handle(

View file

@ -135,8 +135,10 @@ class Application:
except KeyboardInterrupt:
logger.info("Got SIGINT (KeyboardInterrupt). Removing deployments.")
for deployment in self._deployments.values():
deployment.delete()
deployment_names = [d.name for d in self._deployments.values()]
internal_get_global_client().delete_deployments(
deployment_names, blocking=True
)
if len(serve.list_deployments()) == 0:
logger.info("No deployments left. Shutting down Serve.")
serve.shutdown()

View file

@ -3,7 +3,7 @@ import json
import time
from collections import defaultdict
import os
from typing import Dict, List, Optional, Tuple, Any
from typing import Dict, Iterable, List, Optional, Tuple, Any
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
from copy import copy
@ -367,6 +367,10 @@ class ServeController:
self.endpoint_state.delete_endpoint(name)
return self.deployment_state_manager.delete_deployment(name)
def delete_deployments(self, names: Iterable[str]) -> None:
    """Delete every deployment listed in ``names``.

    Each name is passed to ``self.delete_deployment`` in iteration order;
    whatever those calls return is discarded.

    Args:
        names: Names of the deployments to delete.
    """
    for deployment_name in names:
        self.delete_deployment(deployment_name)
def get_deployment_info(self, name: str) -> Tuple[DeploymentInfo, str]:
"""Get the current information about a deployment.

View file

@ -40,8 +40,7 @@ def _shared_serve_instance():
def serve_instance(_shared_serve_instance):
yield _shared_serve_instance
# Clear all state between tests to avoid naming collisions.
for deployment in serve.list_deployments().values():
deployment.delete()
_shared_serve_instance.delete_deployments(serve.list_deployments().keys())
# Clear the ServeHandle cache between tests to avoid them piling up.
_shared_serve_instance.handle_cache.clear()
# Clear deployment generation shared state between tests

View file

@ -1,5 +1,4 @@
import asyncio
import time
import os
import requests
@ -9,6 +8,7 @@ import starlette.responses
import ray
from ray import serve
from ray._private.test_utils import SignalActor, wait_for_condition
from ray.serve.api import internal_get_global_client
def test_e2e(serve_instance):
@ -239,14 +239,48 @@ def test_delete_deployment(serve_instance):
function2.deploy()
for _ in range(10):
try:
assert requests.get("http://127.0.0.1:8000/delete").text == "olleh"
break
except AssertionError:
time.sleep(0.5) # Wait for the change to propagate.
else:
assert requests.get("http://127.0.0.1:8000/delete").text == "olleh"
wait_for_condition(
lambda: requests.get("http://127.0.0.1:8000/delete").text == "olleh", timeout=6
)
# Group-delete test: deploys two function deployments (f and g), checks that
# both answer over HTTP, then deletes them together through the client's
# delete_deployments and checks that both routes answer 404.
# Parametrized so it runs once with blocking=False and once with blocking=True.
# NOTE(review): indentation is not visible in this view, so the nesting of the
# two `for` loops below cannot be read off the text. The "redeploying after
# deletion" comment suggests the delete loop sits inside the outer loop;
# confirm against the real file.
@pytest.mark.parametrize("blocking", [False, True])
def test_delete_deployment_group(serve_instance, blocking):
@serve.deployment(num_replicas=1)
def f(*args):
return "got f"
@serve.deployment(num_replicas=2)
def g(*args):
return "got g"
# Check redeploying after deletion
for _ in range(2):
f.deploy()
g.deploy()
# Both routes must return their function's string before anything is deleted.
wait_for_condition(
lambda: requests.get("http://127.0.0.1:8000/f").text == "got f", timeout=5
)
wait_for_condition(
lambda: requests.get("http://127.0.0.1:8000/g").text == "got g", timeout=5
)
# Check idempotence
for _ in range(2):
# The same two names are deleted on every pass; `blocking` is forwarded
# unchanged from the test parameter.
internal_get_global_client().delete_deployments(
["f", "g"], blocking=blocking
)
# Both routes are expected to answer 404 after the delete; presumably
# wait_for_condition retries the lambda until it is true or the timeout
# expires (verify in ray._private.test_utils).
wait_for_condition(
lambda: requests.get("http://127.0.0.1:8000/f").status_code == 404,
timeout=5,
)
wait_for_condition(
lambda: requests.get("http://127.0.0.1:8000/g").status_code == 404,
timeout=5,
)
def test_starlette_request(serve_instance):