[Serve] Support scale replica down to 0 (#24892)
Commit b024a9543e (parent e4ceae19ef)
18 changed files with 271 additions and 70 deletions
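The user-visible effect of this change is that a Serve deployment configured for autoscaling may now set "min_replicas": 0 and be scaled all the way down when idle. A minimal sketch of that usage, mirroring the _autoscaling_config dictionaries used in the tests later in this diff (assumes the Ray Serve private _autoscaling_config API of this era and a running serve instance):

from ray import serve

@serve.deployment(
    _autoscaling_config={
        "metrics_interval_s": 0.1,
        "min_replicas": 0,        # 0 is now a valid lower bound
        "max_replicas": 2,
        "look_back_period_s": 0.2,
        "downscale_delay_s": 0.2,
        "upscale_delay_s": 0.2,
    },
    version="v1",
)
class A:
    def __call__(self):
        return "hello"

# serve.start(); A.deploy()  # with no traffic, the controller can now drop A to 0 replicas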
@@ -447,6 +447,11 @@ def deployment(
        Deployment
    """

    # Num of replicas should not be 0.
    # TODO(Sihan) seperate num_replicas attribute from internal and api
    if num_replicas == 0:
        raise ValueError("num_replicas is expected to larger than 0")

    if num_replicas is not None and _autoscaling_config is not None:
        raise ValueError(
            "Manually setting num_replicas is not allowed when "
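Scaling to zero is only reachable through the autoscaler; manually asking for num_replicas=0 is rejected with a plain ValueError at the decorator/options layer. A short sketch of that behavior, mirroring the test_input_validation hunk near the end of this diff (requires ray[serve] installed):

import pytest
from ray import serve

@serve.deployment
class Base:
    pass

with pytest.raises(ValueError):

    @serve.deployment(num_replicas=0)
    class ZeroNumReplicas:
        pass

with pytest.raises(ValueError):
    Base.options(num_replicas=0)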
@@ -1,17 +1,25 @@
import threading
import bisect
from collections import defaultdict
import logging
from threading import Event
from typing import Type
import time
from typing import Callable, DefaultDict, Dict, List, Optional
from dataclasses import dataclass, field

import ray
from ray.serve.constants import SERVE_LOGGER_NAME


logger = logging.getLogger(SERVE_LOGGER_NAME)


def start_metrics_pusher(
    interval_s: float,
    collection_callback: Callable[[], Dict[str, float]],
    controller_handle,
    metrics_process_func: Callable[[Dict[str, float], float], ray.ObjectRef],
    stop_event: Type[Event] = None,
):
    """Start a background thread to push metrics to controller.
@@ -19,44 +27,62 @@ def start_metrics_pusher(
    consistently metrics delivery. Python GIL will ensure that this thread gets
    fair timeshare to execute and run.

    Stop_event is passed in only when a RayServeHandle calls this function to
    push metrics for scale-to-zero. stop_event is set either when the handle
    is garbage collected or when the Serve application shuts down.

    Args:
        interval_s(float): the push interval.
        collection_callback: a callable that returns the metric data points to
            be sent to the the controller. The collection callback should take
            no argument and returns a dictionary of str_key -> float_value.
        controller_handle: actor handle to Serve controller.
        metrics_process_func: actor handle function.
        stop_event: the backgroupd thread will be closed when this event is set
    Returns:
        timer: The background thread created by this function to push
            metrics to the controller
    """

    def send_once():
        data = collection_callback()
        # TODO(simon): maybe wait for ack or handle controller failure?
        return controller_handle.record_autoscaling_metrics.remote(
            data=data, send_timestamp=time.time()
        )

    def send_forever():
        # TODO(simon): maybe wait for ack or handle controller failure?
        return metrics_process_func(data=data, send_timestamp=time.time())

    def send_forever(stop_event):
        last_ref: Optional[ray.ObjectRef] = None
        last_send_succeeded: bool = True

        while True:
            start = time.time()
            if stop_event and stop_event.is_set():
                return

            if ray.is_initialized():
                try:
                    if last_ref:
                        ready_refs, _ = ray.wait([last_ref], timeout=0)
                        last_send_succeeded = len(ready_refs) == 1
                    if last_send_succeeded:
                        last_ref = send_once()
                except Exception as e:
                    logger.warning(
                        "Autoscaling metrics pusher thread "
                        "is failing to send metrics to the controller "
                        f": {e}"
                    )

            duration_s = time.time() - start
            remaining_time = interval_s - duration_s
            if remaining_time > 0:
                time.sleep(remaining_time)

    timer = threading.Thread(target=send_forever)
    timer = threading.Thread(target=send_forever, args=[stop_event])
    # Making this a daemon thread so it doesn't leak upon shutdown, and it
    # doesn't need to block the replica's shutdown.
    timer.setDaemon(True)
    timer.start()
    return timer


@dataclass(order=True)
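The pusher is a daemon thread that keeps a fixed cadence, only sends again once the previous send has been acknowledged, and exits when the stop event is set. A simplified, framework-free sketch of that pattern; start_pusher and fake_process are hypothetical names, and a concurrent.futures.Future stands in for the ray.ObjectRef returned by the controller call:

import threading
import time
from concurrent.futures import Future
from typing import Callable, Dict, Optional


def start_pusher(
    interval_s: float,
    collection_callback: Callable[[], Dict[str, float]],
    process_func: Callable[[Dict[str, float], float], Future],
    stop_event: Optional[threading.Event] = None,
) -> threading.Thread:
    def send_forever():
        last_ref: Optional[Future] = None
        while True:
            start = time.time()
            if stop_event and stop_event.is_set():
                return
            # Only push again once the previous push has been acknowledged.
            if last_ref is None or last_ref.done():
                last_ref = process_func(collection_callback(), time.time())
            # Sleep whatever is left of the interval so the cadence stays fixed.
            remaining = interval_s - (time.time() - start)
            if remaining > 0:
                time.sleep(remaining)

    timer = threading.Thread(target=send_forever, daemon=True)  # never blocks shutdown
    timer.start()
    return timer


received = []

def fake_process(data, send_timestamp):
    received.append(data)
    f = Future()
    f.set_result(None)  # immediately "acknowledged"
    return f

stop = threading.Event()
t = start_pusher(0.01, lambda: {"deployment_A": 1.0}, fake_process, stop)
time.sleep(0.05)
stop.set()
t.join()
assert received  # at least one push happened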
@@ -85,6 +111,19 @@ class InMemoryMetricsStore:
            # Using in-sort to insert while maintaining sorted ordering.
            bisect.insort(a=self.data[name], x=TimeStampedValue(timestamp, value))

    def _get_datapoints(self, key: str, window_start_timestamp_s: float) -> List[float]:
        """Get all data points given key after window_start_timestamp_s"""

        datapoints = self.data[key]

        idx = bisect.bisect(
            a=datapoints,
            x=TimeStampedValue(
                timestamp=window_start_timestamp_s, value=0  # dummy value
            ),
        )
        return datapoints[idx:]

    def window_average(
        self, key: str, window_start_timestamp_s: float, do_compact: bool = True
    ) -> Optional[float]:
@@ -102,15 +141,7 @@
            The average of all the datapoints for the key on and after time
            window_start_timestamp_s, or None if there are no such points.
        """
        datapoints = self.data[key]

        idx = bisect.bisect(
            a=datapoints,
            x=TimeStampedValue(
                timestamp=window_start_timestamp_s, value=0  # dummy value
            ),
        )
        points_after_idx = datapoints[idx:]
        points_after_idx = self._get_datapoints(key, window_start_timestamp_s)

        if do_compact:
            self.data[key] = points_after_idx
@@ -118,3 +149,19 @@
        if len(points_after_idx) == 0:
            return
        return sum(point.value for point in points_after_idx) / len(points_after_idx)

    def max(self, key: str, window_start_timestamp_s: float):
        """Perform a max operation for metric `key`.

        Args:
            key(str): the metric name.
            window_start_timestamp_s(float): the unix epoch timestamp for the
                start of the window. The computed average will use all datapoints
                from this timestamp until now.
        Returns:
            Max value of the data points for the key on and after time
            window_start_timestamp_s, or None if there are no such points.
        """
        points_after_idx = self._get_datapoints(key, window_start_timestamp_s)

        return max((point.value for point in points_after_idx), default=None)
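The store keeps each metric series sorted by timestamp so that window queries (average, and the new max used for handle queue lengths) reduce to a bisect plus a slice. A condensed, standalone sketch of the same idea; TinyMetricsStore is a simplified stand-in, not the Serve implementation itself:

import bisect
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass(order=True)
class TimeStampedValue:
    timestamp: float
    value: float = field(compare=False)  # ordering uses the timestamp only


class TinyMetricsStore:
    def __init__(self):
        self.data: Dict[str, List[TimeStampedValue]] = {}

    def add_metrics_point(self, data: Dict[str, float], timestamp: float):
        for name, value in data.items():
            # insort keeps each series sorted even if points arrive out of order.
            bisect.insort(self.data.setdefault(name, []), TimeStampedValue(timestamp, value))

    def _points_after(self, key: str, start: float) -> List[TimeStampedValue]:
        points = self.data.get(key, [])
        idx = bisect.bisect(points, TimeStampedValue(start, 0.0))
        return points[idx:]

    def window_average(self, key: str, start: float) -> Optional[float]:
        pts = self._points_after(key, start)
        return sum(p.value for p in pts) / len(pts) if pts else None

    def max(self, key: str, start: float) -> Optional[float]:
        return max((p.value for p in self._points_after(key, start)), default=None)


store = TinyMetricsStore()
store.add_metrics_point({"m1": 1.0}, timestamp=1)
store.add_metrics_point({"m1": 2.0}, timestamp=2)
assert store.window_average("m1", start=0) == 1.5
assert store.max("m1", start=0) == 2.0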
@@ -68,7 +68,10 @@ class AutoscalingPolicy:

    @abstractmethod
    def get_decision_num_replicas(
        self, current_num_ongoing_requests: List[float], curr_target_num_replicas: int
        self,
        curr_target_num_replicas: int,
        current_num_ongoing_requests: List[float],
        current_handle_queued_queries: float,
    ) -> int:
        """Make a decision to scale replicas.

@@ -77,6 +80,9 @@
                ongoing requests for each replica.
            curr_target_num_replicas: The number of replicas that the
                deployment is currently trying to scale to.
            current_handle_queued_queries : The number of handle queued queries,
                if there are multiple handles, the max number of queries at
                a single handle should be passed in

        Returns:
            int: The new number of replicas to scale to.
@@ -119,9 +125,16 @@ class BasicAutoscalingPolicy(AutoscalingPolicy):
        self.decision_counter = 0

    def get_decision_num_replicas(
        self, current_num_ongoing_requests: List[float], curr_target_num_replicas: int
        self,
        curr_target_num_replicas: int,
        current_num_ongoing_requests: List[float],
        current_handle_queued_queries: float,
    ) -> int:

        if len(current_num_ongoing_requests) == 0:
            # When 0 replica and queries queued, scale up the replicas
            if current_handle_queued_queries > 0:
                return max(1, curr_target_num_replicas)
            return curr_target_num_replicas

        decision_num_replicas = curr_target_num_replicas
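The new zero-replica branch is the heart of scale-from-zero: when no replicas are running there are no per-replica metrics, so the policy falls back to the handle queue length to decide whether to wake the deployment up. A hedged, self-contained sketch of that decision; decide_num_replicas and target_requests_per_replica are hypothetical simplifications and omit the real policy's smoothing and upscale/downscale delays:

from typing import List


def decide_num_replicas(
    curr_target_num_replicas: int,
    current_num_ongoing_requests: List[float],
    current_handle_queued_queries: float,
    target_requests_per_replica: float = 1.0,
) -> int:
    # No running replicas: replica-level metrics are empty, so use the handle
    # queue length to decide whether to wake the deployment up.
    if not current_num_ongoing_requests:
        if current_handle_queued_queries > 0:
            return max(1, curr_target_num_replicas)
        return curr_target_num_replicas

    # Otherwise scale roughly proportionally to the observed load.
    total = sum(current_num_ongoing_requests)
    return max(0, round(total / target_requests_per_replica))


assert decide_num_replicas(0, [], current_handle_queued_queries=1) == 1  # wake up from zero
assert decide_num_replicas(0, [], current_handle_queued_queries=0) == 0  # stay at zero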
@@ -120,6 +120,12 @@ class ServeControllerClient:
        Shuts down all processes and deletes all state associated with the
        instance.
        """

        # Shut down handles
        for k in list(self.handle_cache):
            self.handle_cache[k].stop_metrics_pusher()
            del self.handle_cache[k]

        if ray.is_initialized() and not self._shutdown:
            ray.get(self._controller.shutdown.remote())
            self._wait_for_deployments_shutdown()
@@ -108,7 +108,7 @@ class DeploymentConfig(BaseModel):
        replica's health check before marking it unhealthy.
    """

    num_replicas: PositiveInt = 1
    num_replicas: NonNegativeInt = 1
    max_concurrent_queries: Optional[int] = None
    user_config: Any = None

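Switching the field from PositiveInt to NonNegativeInt is what lets a deployment config carry zero replicas while still rejecting negatives at validation time. A minimal pydantic sketch of the difference (OldStyle/NewStyle are illustrative models, not the actual DeploymentConfig):

import pytest
from pydantic import BaseModel, NonNegativeInt, PositiveInt, ValidationError


class OldStyle(BaseModel):
    num_replicas: PositiveInt = 1


class NewStyle(BaseModel):
    num_replicas: NonNegativeInt = 1


with pytest.raises(ValidationError):
    OldStyle(num_replicas=0)  # 0 used to be rejected at the pydantic layer

assert NewStyle(num_replicas=0).num_replicas == 0  # now allowed

with pytest.raises(ValidationError):
    NewStyle(num_replicas=-1)  # negatives are still rejected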
@@ -96,6 +96,9 @@ ANONYMOUS_NAMESPACE_PATTERN = re.compile(
    "[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89aAbB][a-f0-9]{3}-[a-f0-9]{12}"
)

# Handle metric push interval. (This interval will affect the cold start time period)
HANDLE_METRIC_PUSH_INTERVAL_S = 10


class ServeHandleType(str, Enum):
    SYNC = "SYNC"
@@ -121,6 +121,7 @@ class ServeController:

        # TODO(simon): move autoscaling related stuff into a manager.
        self.autoscaling_metrics_store = InMemoryMetricsStore()
        self.handle_metrics_store = InMemoryMetricsStore()

        asyncio.get_event_loop().create_task(self.run_control_loop())

@@ -131,6 +132,9 @@
    def record_autoscaling_metrics(self, data: Dict[str, float], send_timestamp: float):
        self.autoscaling_metrics_store.add_metrics_point(data, send_timestamp)

    def record_handle_metrics(self, data: Dict[str, float], send_timestamp: float):
        self.handle_metrics_store.add_metrics_point(data, send_timestamp)

    def _dump_autoscaling_metrics_for_testing(self):
        return self.autoscaling_metrics_store.data

@@ -201,15 +205,25 @@
            if num_ongoing_requests is not None:
                current_num_ongoing_requests.append(num_ongoing_requests)

            if len(current_num_ongoing_requests) == 0:
                continue
            current_handle_queued_queries = self.handle_metrics_store.max(
                deployment_name,
                time.time() - autoscaling_policy.config.look_back_period_s,
            )

            if current_handle_queued_queries is None:
                current_handle_queued_queries = 0

            new_deployment_config = deployment_config.copy()

            decision_num_replicas = autoscaling_policy.get_decision_num_replicas(
                current_num_ongoing_requests=current_num_ongoing_requests,
                curr_target_num_replicas=deployment_config.num_replicas,
                current_num_ongoing_requests=current_num_ongoing_requests,
                current_handle_queued_queries=current_handle_queued_queries,
            )

            if decision_num_replicas == deployment_config.num_replicas:
                continue

            new_deployment_config.num_replicas = decision_num_replicas

            new_deployment_info = copy(deployment_info)
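The controller now folds a second signal into each autoscaling decision: the largest handle-side queue length reported within the look-back window, defaulting to zero when no handle has reported yet. A hedged sketch of that lookup; StubHandleMetricsStore and queued_queries_for are hypothetical stand-ins for the real InMemoryMetricsStore and controller code:

import time
from typing import Dict, List, Optional, Tuple


class StubHandleMetricsStore:
    """Just enough of a metrics store for the sketch."""

    def __init__(self):
        self._points: Dict[str, List[Tuple[float, float]]] = {}

    def add_metrics_point(self, data: Dict[str, float], timestamp: float):
        for key, value in data.items():
            self._points.setdefault(key, []).append((timestamp, value))

    def max(self, key: str, window_start_timestamp_s: float) -> Optional[float]:
        values = [v for (t, v) in self._points.get(key, []) if t >= window_start_timestamp_s]
        return max(values, default=None)


def queued_queries_for(store, deployment_name: str, look_back_period_s: float) -> float:
    # Mirrors the controller branch above: a missing report counts as zero queued queries.
    queued = store.max(deployment_name, time.time() - look_back_period_s)
    return 0 if queued is None else queued


store = StubHandleMetricsStore()
assert queued_queries_for(store, "A", look_back_period_s=30.0) == 0  # no reports yet
store.add_metrics_point({"A": 5}, timestamp=time.time())
assert queued_queries_for(store, "A", look_back_period_s=30.0) == 5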
@@ -295,6 +295,16 @@ class Deployment:
            unchanged from the existing deployment.
        """
        new_config = self._config.copy()

        if num_replicas is not None and _autoscaling_config is not None:
            raise ValueError(
                "Manually setting num_replicas is not allowed when "
                "_autoscaling_config is provided."
            )

        if num_replicas == 0:
            raise ValueError("num_replicas is expected to larger than 0")

        if num_replicas is not None:
            new_config.num_replicas = num_replicas
        if user_config is not None:
@@ -943,6 +943,7 @@ class DeploymentState:
        self._curr_status_info: DeploymentStatusInfo = DeploymentStatusInfo(
            self._name, DeploymentStatus.UPDATING
        )
        self._deleting = False

    def get_target_state_checkpoint_data(self):
        """
@@ -1044,6 +1045,7 @@

        else:
            self._target_replicas = 0
        self._deleting = True

        self._curr_status_info = DeploymentStatusInfo(
            self._name, DeploymentStatus.UPDATING
@@ -1331,11 +1333,11 @@
            == 0
        ):
            # Check for deleting.
            if target_replica_count == 0 and all_running_replica_cnt == 0:
            if self._deleting and all_running_replica_cnt == 0:
                return True

            # Check for a non-zero number of deployments.
            elif target_replica_count == running_at_target_version_replica_cnt:
            if target_replica_count == running_at_target_version_replica_cnt:
                self._curr_status_info = DeploymentStatusInfo(
                    self._name, DeploymentStatus.HEALTHY
                )
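With min_replicas=0, a target of zero replicas no longer implies deletion, so the status check above keys off the explicit _deleting flag rather than the replica count. A hedged condensation of that distinction; the function and return values are hypothetical, not the real DeploymentState API:

def deployment_finished_transition(
    deleting: bool,
    target_replica_count: int,
    all_running_replica_cnt: int,
    running_at_target_version_replica_cnt: int,
) -> str:
    if deleting and all_running_replica_cnt == 0:
        return "deleted"
    if target_replica_count == running_at_target_version_replica_cnt:
        # A deployment autoscaled down to 0 replicas lands here (0 == 0) and is
        # reported healthy rather than being torn down.
        return "healthy"
    return "updating"


assert deployment_finished_transition(True, 0, 0, 0) == "deleted"
assert deployment_finished_transition(False, 0, 0, 0) == "healthy"  # scaled to zero, still alive
assert deployment_finished_transition(False, 2, 1, 1) == "updating"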
@@ -16,6 +16,8 @@ from ray.serve.utils import (
    get_random_letters,
    DEFAULT,
)
from ray.serve.autoscaling_metrics import start_metrics_pusher
from ray.serve.constants import HANDLE_METRIC_PUSH_INTERVAL_S
from ray.serve.router import Router, RequestMetadata
from ray.util import metrics

@@ -104,6 +106,17 @@ class RayServeHandle:

        self.router: Router = _router or self._make_router()

        self._stop_event = threading.Event()
        self._pusher = start_metrics_pusher(
            interval_s=HANDLE_METRIC_PUSH_INTERVAL_S,
            collection_callback=self._collect_handle_queue_metrics,
            metrics_process_func=self.controller_handle.record_handle_metrics.remote,
            stop_event=self._stop_event,
        )

    def _collect_handle_queue_metrics(self) -> Dict[str, int]:
        return {self.deployment_name: self.router.get_num_queued_queries()}

    def _make_router(self) -> Router:
        return Router(
            self.controller_handle,
@@ -111,6 +124,10 @@ class RayServeHandle:
            event_loop=asyncio.get_event_loop(),
        )

    def stop_metrics_pusher(self):
        self._stop_event.set()
        self._pusher.join()

    @property
    def is_polling(self) -> bool:
        """Whether this handle is actively polling for replica updates."""
@@ -201,6 +218,9 @@ class RayServeHandle:
    def __getattr__(self, name):
        return self.options(method_name=name)

    def __del__(self):
        self.stop_metrics_pusher()


class RayServeSyncHandle(RayServeHandle):
    @property
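Together, __del__ and the client's shutdown loop over handle_cache (shown earlier) guarantee that each handle's pusher thread is stopped and joined, so a garbage-collected handle does not keep reporting queue lengths. A framework-free sketch of that owner/daemon-thread lifecycle; HandleLike is a hypothetical stand-in that assumes the pusher loop checks the event on every tick:

import threading
import time


class HandleLike:
    def __init__(self, interval_s: float = 0.05):
        self._stop_event = threading.Event()
        self._pusher = threading.Thread(target=self._push_forever, args=(interval_s,))
        self._pusher.daemon = True  # never block process exit
        self._pusher.start()

    def _push_forever(self, interval_s: float):
        while not self._stop_event.is_set():
            # ... collect and send metrics here ...
            time.sleep(interval_s)

    def stop_metrics_pusher(self):
        self._stop_event.set()
        self._pusher.join()

    def __del__(self):
        # Stopping in __del__ mirrors the handle change above; explicit shutdown
        # (the client loop over handle_cache) remains the reliable path.
        self.stop_metrics_pusher()


h = HandleLike()
time.sleep(0.1)
h.stop_metrics_pusher()  # thread exits promptly once the event is set
assert not h._pusher.is_alive()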
@@ -317,11 +317,12 @@ class RayServeReplica:
        self._shutdown_wait_loop_s = deployment_config.graceful_shutdown_wait_loop_s

        if deployment_config.autoscaling_config:
            process_remote_func = controller_handle.record_autoscaling_metrics.remote
            config = deployment_config.autoscaling_config
            start_metrics_pusher(
                interval_s=config.metrics_interval_s,
                collection_callback=self._collect_autoscaling_metrics,
                controller_handle=controller_handle,
                metrics_process_func=process_remote_func,
            )

        # NOTE(edoakes): we used to recommend that users use the "ray" logger
@@ -205,6 +205,9 @@ class Router:
            call_in_event_loop=event_loop,
        )

    def get_num_queued_queries(self):
        return self._replica_set.num_queued_queries

    async def assign_request(
        self,
        request_meta: RequestMetadata,
@@ -311,6 +311,11 @@ def test_delete_deployment_group(serve_instance, blocking):
        timeout=5,
    )

    wait_for_condition(
        lambda: len(serve_instance.list_deployments()) == 0,
        timeout=5,
    )


def test_starlette_request(serve_instance):
    @serve.deployment(name="api")
@@ -12,6 +12,7 @@ class TestInMemoryMetricsStore:
        s.add_metrics_point({"m1": 1}, timestamp=1)
        s.add_metrics_point({"m1": 2}, timestamp=2)
        assert s.window_average("m1", window_start_timestamp_s=0) == 1.5
        assert s.max("m1", window_start_timestamp_s=0) == 2

    def test_out_of_order_insert(self):
        s = InMemoryMetricsStore()
@@ -21,10 +22,12 @@
        s.add_metrics_point({"m1": 2}, timestamp=2)
        s.add_metrics_point({"m1": 4}, timestamp=4)
        assert s.window_average("m1", window_start_timestamp_s=0) == 3
        assert s.max("m1", window_start_timestamp_s=0) == 5

    def test_window_start_timestamp(self):
        s = InMemoryMetricsStore()
        assert s.window_average("m1", window_start_timestamp_s=0) is None
        assert s.max("m1", window_start_timestamp_s=0) is None

        s.add_metrics_point({"m1": 1}, timestamp=2)
        assert s.window_average("m1", window_start_timestamp_s=0) == 1
@@ -51,7 +54,8 @@
        s.add_metrics_point({"m1": 1, "m2": -1}, timestamp=1)
        s.add_metrics_point({"m1": 2, "m2": -2}, timestamp=2)
        assert s.window_average("m1", window_start_timestamp_s=0) == 1.5
        assert s.window_average("m2", window_start_timestamp_s=0) == -1.5
        assert s.max("m1", window_start_timestamp_s=0) == 2
        assert s.max("m2", window_start_timestamp_s=0) == -1


def test_e2e(serve_instance):
@@ -170,7 +170,8 @@ def get_deployment_start_time(controller: ServeController, deployment: Deployment
    return deployment_info.start_time_ms


def test_e2e_basic_scale_up_down(serve_instance):
@pytest.mark.parametrize("min_replicas", [0, 1])
def test_e2e_basic_scale_up_down(min_replicas, serve_instance):
    """Send 100 requests and check that we autoscale up, and then back down."""

    signal = SignalActor.remote()
@@ -178,7 +179,7 @@ def test_e2e_basic_scale_up_down(serve_instance):
    @serve.deployment(
        _autoscaling_config={
            "metrics_interval_s": 0.1,
            "min_replicas": 1,
            "min_replicas": min_replicas,
            "max_replicas": 2,
            "look_back_period_s": 0.2,
            "downscale_delay_s": 0,
@@ -206,7 +207,7 @@ def test_e2e_basic_scale_up_down(serve_instance):
    signal.send.remote()

    # As the queue is drained, we should scale back down.
    wait_for_condition(lambda: get_num_running_replicas(controller, A) <= 1)
    wait_for_condition(lambda: get_num_running_replicas(controller, A) <= min_replicas)

    # Make sure start time did not change for the deployment
    assert get_deployment_start_time(controller, A) == start_time
@@ -245,7 +246,7 @@ def test_upscale_downscale_delay():
    downscale_delay_s = 600.0

    config = AutoscalingConfig(
        min_replicas=1,
        min_replicas=0,
        max_replicas=2,
        target_num_ongoing_requests_per_replica=1,
        upscale_delay_s=30.0,
@@ -259,15 +260,27 @@

    overload_requests = [100]

    # Scale up when there are 0 replicas and current_handle_queued_queries > 0
    new_num_replicas = policy.get_decision_num_replicas(
        current_num_ongoing_requests=[],
        curr_target_num_replicas=0,
        current_handle_queued_queries=1,
    )
    assert new_num_replicas == 1

    # We should scale up only after enough consecutive scale-up decisions.
    for i in range(upscale_wait_periods):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=overload_requests, curr_target_num_replicas=1
            current_num_ongoing_requests=overload_requests,
            curr_target_num_replicas=1,
            current_handle_queued_queries=0,
        )
        assert new_num_replicas == 1, i

    new_num_replicas = policy.get_decision_num_replicas(
        current_num_ongoing_requests=overload_requests, curr_target_num_replicas=1
        current_num_ongoing_requests=overload_requests,
        curr_target_num_replicas=1,
        current_handle_queued_queries=0,
    )
    assert new_num_replicas == 2

@@ -276,64 +289,84 @@
    # We should scale down only after enough consecutive scale-down decisions.
    for i in range(downscale_wait_periods):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=no_requests, curr_target_num_replicas=2
            current_num_ongoing_requests=no_requests,
            curr_target_num_replicas=2,
            current_handle_queued_queries=0,
        )
        assert new_num_replicas == 2, i

    new_num_replicas = policy.get_decision_num_replicas(
        current_num_ongoing_requests=no_requests, curr_target_num_replicas=2
        current_num_ongoing_requests=no_requests,
        curr_target_num_replicas=2,
        current_handle_queued_queries=0,
    )
    assert new_num_replicas == 1
    assert new_num_replicas == 0

    # Get some scale-up decisions, but not enough to trigger a scale up.
    for i in range(int(upscale_wait_periods / 2)):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=overload_requests, curr_target_num_replicas=1
            current_num_ongoing_requests=overload_requests,
            curr_target_num_replicas=1,
            current_handle_queued_queries=0,
        )
        assert new_num_replicas == 1, i

    # Interrupt with a scale-down decision.
    policy.get_decision_num_replicas(
        current_num_ongoing_requests=[0], curr_target_num_replicas=1
        current_num_ongoing_requests=[0],
        curr_target_num_replicas=1,
        current_handle_queued_queries=0,
    )

    # The counter should be reset, so it should require `upscale_wait_periods`
    # more periods before we actually scale up.
    for i in range(upscale_wait_periods):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=overload_requests, curr_target_num_replicas=1
            current_num_ongoing_requests=overload_requests,
            curr_target_num_replicas=1,
            current_handle_queued_queries=0,
        )
        assert new_num_replicas == 1, i

    new_num_replicas = policy.get_decision_num_replicas(
        current_num_ongoing_requests=overload_requests, curr_target_num_replicas=1
        current_num_ongoing_requests=overload_requests,
        curr_target_num_replicas=1,
        current_handle_queued_queries=0,
    )
    assert new_num_replicas == 2

    # Get some scale-down decisions, but not enough to trigger a scale down.
    for i in range(int(downscale_wait_periods / 2)):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=no_requests, curr_target_num_replicas=2
            current_num_ongoing_requests=no_requests,
            curr_target_num_replicas=2,
            current_handle_queued_queries=0,
        )
        assert new_num_replicas == 2, i

    # Interrupt with a scale-up decision.
    policy.get_decision_num_replicas(
        current_num_ongoing_requests=[100, 100], curr_target_num_replicas=2
        current_num_ongoing_requests=[100, 100],
        curr_target_num_replicas=2,
        current_handle_queued_queries=0,
    )

    # The counter should be reset so it should require `downscale_wait_periods`
    # more periods before we actually scale down.
    for i in range(downscale_wait_periods):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=no_requests, curr_target_num_replicas=2
            current_num_ongoing_requests=no_requests,
            curr_target_num_replicas=2,
            current_handle_queued_queries=0,
        )
        assert new_num_replicas == 2, i

    new_num_replicas = policy.get_decision_num_replicas(
        current_num_ongoing_requests=no_requests, curr_target_num_replicas=2
        current_num_ongoing_requests=no_requests,
        curr_target_num_replicas=2,
        current_handle_queued_queries=0,
    )
    assert new_num_replicas == 1
    assert new_num_replicas == 0


def test_replicas_delayed_startup():
@@ -348,21 +381,21 @@

    policy = BasicAutoscalingPolicy(config)

    new_num_replicas = policy.get_decision_num_replicas([100], 1)
    new_num_replicas = policy.get_decision_num_replicas(1, [100], 0)
    assert new_num_replicas == 100

    # New target is 100, but no new replicas finished spinning up during this
    # timestep.
    new_num_replicas = policy.get_decision_num_replicas([100], 100)
    new_num_replicas = policy.get_decision_num_replicas(100, [100], 0)
    assert new_num_replicas == 100

    # Two new replicas spun up during this timestep.
    new_num_replicas = policy.get_decision_num_replicas([100, 20, 3], 100)
    new_num_replicas = policy.get_decision_num_replicas(100, [100, 20, 3], 0)
    assert new_num_replicas == 123

    # A lot of queries got drained and a lot of replicas started up, but
    # new_num_replicas should not decrease, because of the downscale delay.
    new_num_replicas = policy.get_decision_num_replicas([6, 2, 1, 1], 123)
    new_num_replicas = policy.get_decision_num_replicas(123, [6, 2, 1, 1], 0)
    assert new_num_replicas == 123

@@ -396,6 +429,7 @@ def test_fluctuating_ongoing_requests(delay_s):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=overload_requests,
            curr_target_num_replicas=1,
            current_handle_queued_queries=0,
        )
        if delay_s > 0:
            assert new_num_replicas == 1, trial
@@ -405,6 +439,7 @@ def test_fluctuating_ongoing_requests(delay_s):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=underload_requests,
            curr_target_num_replicas=2,
            current_handle_queued_queries=0,
        )
        if delay_s > 0:
            assert new_num_replicas == 2, trial
@@ -434,7 +469,9 @@ def test_imbalanced_replicas(ongoing_requests):
        == config.target_num_ongoing_requests_per_replica
    ):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=ongoing_requests, curr_target_num_replicas=4
            current_num_ongoing_requests=ongoing_requests,
            curr_target_num_replicas=4,
            current_handle_queued_queries=0,
        )
        assert new_num_replicas == 4

@@ -445,7 +482,9 @@ def test_imbalanced_replicas(ongoing_requests):
        < config.target_num_ongoing_requests_per_replica
    ):
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=ongoing_requests, curr_target_num_replicas=4
            current_num_ongoing_requests=ongoing_requests,
            curr_target_num_replicas=4,
            current_handle_queued_queries=0,
        )

        if (
@@ -463,7 +502,9 @@
    # is higher than target_num_ongoing_requests_per_replica
    else:
        new_num_replicas = policy.get_decision_num_replicas(
            current_num_ongoing_requests=ongoing_requests, curr_target_num_replicas=4
            current_num_ongoing_requests=ongoing_requests,
            curr_target_num_replicas=4,
            current_handle_queued_queries=0,
        )
        assert new_num_replicas == 5

@@ -485,7 +526,9 @@ def test_single_replica_receives_all_requests(ongoing_requests):
    policy = BasicAutoscalingPolicy(config)

    new_num_replicas = policy.get_decision_num_replicas(
        current_num_ongoing_requests=ongoing_requests, curr_target_num_replicas=4
        current_num_ongoing_requests=ongoing_requests,
        curr_target_num_replicas=4,
        current_handle_queued_queries=0,
    )
    assert new_num_replicas == sum(ongoing_requests) / target_requests

@@ -559,7 +602,7 @@ def test_e2e_intermediate_downscaling(serve_instance):
    @serve.deployment(
        _autoscaling_config={
            "metrics_interval_s": 0.1,
            "min_replicas": 1,
            "min_replicas": 0,
            "max_replicas": 20,
            "look_back_period_s": 0.2,
            "downscale_delay_s": 0.2,
@@ -598,7 +641,7 @@ def test_e2e_intermediate_downscaling(serve_instance):

    signal.send.remote()
    # As the queue is drained, we should scale back down.
    wait_for_condition(lambda: get_num_running_replicas(controller, A) <= 1, timeout=30)
    wait_for_condition(lambda: get_num_running_replicas(controller, A) < 1, timeout=30)

    # Make sure start time did not change for the deployment
    assert get_deployment_start_time(controller, A) == start_time
@@ -614,7 +657,7 @@ def test_e2e_update_autoscaling_deployment(serve_instance):
    @serve.deployment(
        _autoscaling_config={
            "metrics_interval_s": 0.1,
            "min_replicas": 1,
            "min_replicas": 0,
            "max_replicas": 10,
            "look_back_period_s": 0.2,
            "downscale_delay_s": 0.2,
@@ -636,7 +679,7 @@ def test_e2e_update_autoscaling_deployment(serve_instance):
    controller = serve_instance._controller
    start_time = get_deployment_start_time(controller, A)

    assert get_num_running_replicas(controller, A) == 1
    assert get_num_running_replicas(controller, A) == 0

    handle = A.get_handle()
    [handle.remote() for _ in range(400)]
@@ -683,6 +726,29 @@ def test_e2e_update_autoscaling_deployment(serve_instance):
    # Make sure start time did not change for the deployment
    assert get_deployment_start_time(controller, A) == start_time

    # scale down to 0
    A.options(
        _autoscaling_config={
            "metrics_interval_s": 0.1,
            "min_replicas": 0,
            "max_replicas": 20,
            "look_back_period_s": 0.2,
            "downscale_delay_s": 0.2,
            "upscale_delay_s": 0.2,
        },
        version="v1",
    ).deploy()
    print("Redeployed A.")

    wait_for_condition(lambda: get_num_running_replicas(controller, A) < 1)
    assert get_num_running_replicas(controller, A) == 0

    # scale up
    [handle.remote() for _ in range(400)]
    wait_for_condition(lambda: get_num_running_replicas(controller, A) > 0)
    signal.send.remote()
    wait_for_condition(lambda: get_num_running_replicas(controller, A) < 1)


@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_e2e_raise_min_replicas(serve_instance):
@@ -691,7 +757,7 @@ def test_e2e_raise_min_replicas(serve_instance):
    @serve.deployment(
        _autoscaling_config={
            "metrics_interval_s": 0.1,
            "min_replicas": 1,
            "min_replicas": 0,
            "max_replicas": 10,
            "look_back_period_s": 0.2,
            "downscale_delay_s": 0.2,
@@ -713,13 +779,15 @@ def test_e2e_raise_min_replicas(serve_instance):
    controller = serve_instance._controller
    start_time = get_deployment_start_time(controller, A)

    assert get_num_running_replicas(controller, A) == 0

    handle = A.get_handle()
    [handle.remote() for _ in range(1)]
    print("Issued one request.")

    time.sleep(2)
    assert get_num_running_replicas(controller, A) == 1
    print("Stayed at 1 replica.")
    print("Scale up to 1 replica.")

    first_deployment_replicas = get_running_replica_tags(controller, A)

@@ -833,13 +833,13 @@ def test_input_validation():
    with pytest.raises(ValidationError):
        Base.options(num_replicas="hi")

    with pytest.raises(ValidationError):
    with pytest.raises(ValueError):

        @serve.deployment(num_replicas=0)
        class ZeroNumReplicas:
            pass

    with pytest.raises(ValidationError):
    with pytest.raises(ValueError):
        Base.options(num_replicas=0)

    with pytest.raises(ValidationError):
@@ -51,14 +51,14 @@ from typing import Optional
logger = logging.getLogger(__file__)

# Experiment configs
DEFAULT_SMOKE_TEST_MIN_NUM_REPLICA = 1
DEFAULT_SMOKE_TEST_MIN_NUM_REPLICA = 0
DEFAULT_SMOKE_TEST_MAX_NUM_REPLICA = 8
DEFAULT_SMOKE_TEST_NUM_DEPLOYMENTS = 4  # 2 replicas each

# TODO:(jiaodong) We should investigate and change this back to 1k
# for now, we won't get valid latency numbers from wrk at 1k replica
# likely due to request timeout.
DEFAULT_FULL_TEST_MIN_NUM_REPLICA = 1
DEFAULT_FULL_TEST_MIN_NUM_REPLICA = 0
DEFAULT_FULL_TEST_MAX_NUM_REPLICA = 1000
# TODO(simon): we should change this back to 100. But due to long poll issue
# we temporarily downscoped this test.
@@ -50,9 +50,9 @@ from typing import Optional
logger = logging.getLogger(__file__)

# Experiment configs
DEFAULT_SMOKE_TEST_MIN_NUM_REPLICA = 1
DEFAULT_SMOKE_TEST_MIN_NUM_REPLICA = 0
DEFAULT_SMOKE_TEST_MAX_NUM_REPLICA = 4
DEFAULT_FULL_TEST_MIN_NUM_REPLICA = 1
DEFAULT_FULL_TEST_MIN_NUM_REPLICA = 0
DEFAULT_FULL_TEST_MAX_NUM_REPLICA = 1000

# Deployment configs