[Serve] Don't recover from current state checkpoint (#19998)
commit b6bd4fd5f3
parent ce8504b0b2
6 changed files with 87 additions and 80 deletions
```diff
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Dict, Optional
 from uuid import UUID
 
-from ray.actor import ActorClass, ActorHandle
+import ray
+from ray.actor import ActorHandle
 from ray.serve.config import DeploymentConfig, ReplicaConfig
 from ray.serve.autoscaling_policy import AutoscalingPolicy
@@ -25,7 +25,8 @@ class DeploymentInfo:
                  deployment_config: DeploymentConfig,
                  replica_config: ReplicaConfig,
                  start_time_ms: int,
-                 actor_def: Optional[ActorClass] = None,
+                 actor_name: Optional[str] = None,
+                 serialized_deployment_def: Optional[bytes] = None,
                  version: Optional[str] = None,
                  deployer_job_id: "Optional[ray._raylet.JobID]" = None,
                  end_time_ms: Optional[int] = None,
@@ -34,13 +35,38 @@ class DeploymentInfo:
         self.replica_config = replica_config
         # The time when .deploy() was first called for this deployment.
         self.start_time_ms = start_time_ms
-        self.actor_def = actor_def
+        self.actor_name = actor_name
+        self.serialized_deployment_def = serialized_deployment_def
         self.version = version
         self.deployer_job_id = deployer_job_id
         # The time when this deployment was deleted.
         self.end_time_ms = end_time_ms
         self.autoscaling_policy = autoscaling_policy
 
+        # Ephemeral state, dropped from pickled (checkpointed) copies.
+        self._cached_actor_def = None
+
+    def __getstate__(self) -> Dict[Any, Any]:
+        clean_dict = self.__dict__.copy()
+        del clean_dict["_cached_actor_def"]
+        return clean_dict
+
+    def __setstate__(self, d: Dict[Any, Any]) -> None:
+        self.__dict__ = d
+        self._cached_actor_def = None
+
+    @property
+    def actor_def(self):
+        # Delayed import as replica depends on this file.
+        from ray.serve.replica import create_replica_wrapper
+        if self._cached_actor_def is None:
+            assert self.actor_name is not None
+            assert self.serialized_deployment_def is not None
+            self._cached_actor_def = ray.remote(
+                create_replica_wrapper(self.actor_name,
+                                       self.serialized_deployment_def))
+        return self._cached_actor_def
+
+
 @dataclass
 class ReplicaName:
```
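The heart of the `DeploymentInfo` change: instead of carrying an eagerly built `actor_def` (which drags code objects and live Ray state into every checkpoint), the actor class is excluded from pickled state and rebuilt lazily from `actor_name` and `serialized_deployment_def`. A minimal standalone sketch of that pattern, with purely illustrative names (no Serve APIs involved):

```python
import pickle


class LazyHandle:
    """Illustrative stand-in for DeploymentInfo and its cached actor_def."""

    def __init__(self, target: str):
        self.target = target
        # Ephemeral, derived state: must never land in a checkpoint.
        self._cached_handle = None

    def __getstate__(self):
        clean_dict = self.__dict__.copy()
        del clean_dict["_cached_handle"]  # strip derived state before pickling
        return clean_dict

    def __setstate__(self, d):
        self.__dict__ = d
        self._cached_handle = None  # rebuilt lazily on first access

    @property
    def handle(self):
        if self._cached_handle is None:
            # Stand-in for ray.remote(create_replica_wrapper(...)).
            self._cached_handle = f"handle-for-{self.target}"
        return self._cached_handle


restored = pickle.loads(pickle.dumps(LazyHandle("echo")))
assert restored._cached_handle is None        # cache did not survive the pickle
assert restored.handle == "handle-for-echo"   # and is rebuilt on demand
```

This is what lets the checkpoint hold plain-`pickle`-able data while the actor class is reconstructed on first use after recovery.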
```diff
@@ -24,7 +24,6 @@ from ray.serve.config import DeploymentConfig, HTTPOptions, ReplicaConfig
 from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY
 from ray.serve.endpoint_state import EndpointState
 from ray.serve.http_state import HTTPState
-from ray.serve.replica import create_replica_wrapper
 from ray.serve.storage.checkpoint_path import make_kv_store
 from ray.serve.long_poll import LongPollHost
 from ray.serve.storage.kv_store import RayInternalKVStore
@@ -320,9 +319,8 @@ class ServeController:
             autoscaling_policy = None
 
         deployment_info = DeploymentInfo(
-            actor_def=ray.remote(
-                create_replica_wrapper(
-                    name, replica_config.serialized_deployment_def)),
+            actor_name=name,
+            serialized_deployment_def=replica_config.serialized_deployment_def,
             version=version,
            deployment_config=deployment_config,
            replica_config=replica_config,
```
```diff
@@ -1,5 +1,6 @@
 import math
 import json
+import pickle
 import time
 from collections import defaultdict, OrderedDict
 from enum import Enum
@@ -7,7 +8,7 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import ray
-from ray import cloudpickle, ObjectRef
+from ray import ObjectRef
 from ray.actor import ActorHandle
 from ray.serve.async_goal_manager import AsyncGoalManager
 from ray.serve.common import (DeploymentInfo, Duration, GoalId, ReplicaTag,
```
```diff
@@ -110,17 +111,6 @@ class ActorReplicaWrapper:
         # Populated in self.stop().
         self._graceful_shutdown_ref: ObjectRef = None
 
-    def __get_state__(self) -> Dict[Any, Any]:
-        clean_dict = self.__dict__.copy()
-        del clean_dict["_ready_obj_ref"]
-        del clean_dict["_graceful_shutdown_ref"]
-        return clean_dict
-
-    def __set_state__(self, d: Dict[Any, Any]) -> None:
-        self.__dict__ = d
-        self._ready_obj_ref = None
-        self._graceful_shutdown_ref = None
-
     @property
     def replica_tag(self) -> str:
         return self._replica_tag
@@ -372,12 +362,6 @@ class DeploymentReplica(VersionedReplica):
         self._start_time = None
         self._prev_slow_startup_warning_time = None
 
-    def __get_state__(self) -> Dict[Any, Any]:
-        return self.__dict__.copy()
-
-    def __set_state__(self, d: Dict[Any, Any]) -> None:
-        self.__dict__ = d
-
     def get_running_replica_info(self) -> RunningReplicaInfo:
         return RunningReplicaInfo(
             deployment_name=self._deployment_name,
```
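Worth noting in passing: the deleted hooks were spelled `__get_state__`/`__set_state__`, names the pickle protocol never looks up (it calls `__getstate__`/`__setstate__`), so they were dead code even before replica state stopped being checkpointed. A quick standalone check:

```python
import pickle


class Misspelled:
    def __init__(self):
        self.hook_ran = False

    def __get_state__(self):  # wrong name: the pickle protocol ignores it
        self.hook_ran = True
        return self.__dict__


obj = Misspelled()
pickle.dumps(obj)
assert obj.hook_ran is False  # the misspelled hook was never invoked
```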
```diff
@@ -664,18 +648,8 @@ class DeploymentState:
         """
         return (self._target_info, self._target_replicas, self._target_version)
 
-    def get_current_state_checkpoint_data(self):
-        """
-        Return deployment's current state specific to the ray cluster it's
-        running in. Might be lost or re-constructed upon ray cluster failure.
-        """
-        return (self._rollback_info, self._curr_goal,
-                self._prev_startup_warning,
-                self._replica_constructor_retry_counter, self._replicas)
-
     def get_checkpoint_data(self):
-        return (self.get_target_state_checkpoint_data(),
-                self.get_current_state_checkpoint_data())
+        return self.get_target_state_checkpoint_data()
 
     def recover_target_state_from_checkpoint(self, target_state_checkpoint):
         logger.info("Recovering target state for deployment "
```
```diff
@@ -683,18 +657,6 @@ class DeploymentState:
         (self._target_info, self._target_replicas,
          self._target_version) = target_state_checkpoint
 
-    def recover_current_state_from_checkpoint(self, current_state_checkpoint):
-        logger.info("Recovering current state for deployment "
-                    f"{self._name} from checkpoint..")
-        (self._rollback_info, self._curr_goal, self._prev_startup_warning,
-         self._replica_constructor_retry_counter,
-         self._replicas) = current_state_checkpoint
-
-        if self._curr_goal is not None:
-            self._goal_manager.create_goal(self._curr_goal)
-
-        self._notify_running_replicas_changed()
-
     def recover_current_state_from_replica_actor_names(
             self, replica_actor_names: List[str]):
         assert (
```
```diff
@@ -1288,23 +1250,19 @@ class DeploymentStateManager:
         checkpoint = self._kv_store.get(CHECKPOINT_KEY)
         if checkpoint is not None:
             (deployment_state_info,
-             self._deleted_deployment_metadata) = cloudpickle.loads(checkpoint)
+             self._deleted_deployment_metadata) = pickle.loads(checkpoint)
 
             for deployment_tag, checkpoint_data in deployment_state_info.items(
             ):
                 deployment_state = self._create_deployment_state(
                     deployment_tag)
-                (target_state_checkpoint,
-                 current_state_checkpoint) = checkpoint_data
-
+                target_state_checkpoint = checkpoint_data
                 deployment_state.recover_target_state_from_checkpoint(
                     target_state_checkpoint)
                 if len(deployment_to_current_replicas[deployment_tag]) > 0:
                     deployment_state.recover_current_state_from_replica_actor_names(  # noqa: E501
                         deployment_to_current_replicas[deployment_tag])
-                else:
-                    deployment_state.recover_current_state_from_checkpoint(
-                        current_state_checkpoint)
                 self._deployment_states[deployment_tag] = deployment_state
 
     def shutdown(self) -> List[GoalId]:
```
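With the `else:` branch gone, recovery has a single source of current state: live replica actors found by name. If none survived, the control loop simply re-converges toward the target instead of trusting a stale checkpointed snapshot. A compact sketch of the resulting flow, using hypothetical names and a deliberately simplified checkpoint shape:

```python
from typing import Dict, List, Tuple

# Illustrative target-state shape only; the real checkpoint tuple carries
# the deployment's DeploymentInfo, target replica count, and target version.
TargetState = Tuple[str, int]  # (version, target_num_replicas)


def recover(checkpoint: Dict[str, TargetState],
            live_replicas: Dict[str, List[str]]) -> Dict[str, dict]:
    """Rebuild per-deployment state from a target-only checkpoint."""
    states = {}
    for name, target in checkpoint.items():
        # Current state comes only from surviving replica actors; with no
        # survivors, we start empty and let reconciliation do the rest.
        replicas = list(live_replicas.get(name, []))
        states[name] = {"target": target, "replicas": replicas}
    return states


states = recover({"echo": ("v1", 2)}, {"echo": ["echo#AbCd", "echo#EfGh"]})
assert states["echo"]["replicas"] == ["echo#AbCd", "echo#EfGh"]
assert recover({"echo": ("v1", 2)}, {})["echo"]["replicas"] == []
```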
```diff
@@ -1342,8 +1300,11 @@ class DeploymentStateManager:
         }
         self._kv_store.put(
             CHECKPOINT_KEY,
-            cloudpickle.dumps((deployment_state_info,
-                               self._deleted_deployment_metadata)))
+            # NOTE(simon): Make sure to use pickle so we don't save any ray
+            # object that relies on external state (e.g. GCS). For code
+            # objects, we explicitly use cloudpickle to serialize them.
+            pickle.dumps((deployment_state_info,
+                          self._deleted_deployment_metadata)))
 
     def get_running_replica_infos(
             self,
```
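The NOTE above is the design rationale for the whole serializer swap. Plain `pickle` refuses objects without a stable import path, so a checkpoint that accidentally captures a closure or live handle fails loudly at dump time; `cloudpickle` (the same library Ray vendors as `ray.cloudpickle`) is reserved for the one place code must be shipped by value. A standalone contrast (assumes `cloudpickle` is installed):

```python
import pickle

import cloudpickle  # pip install cloudpickle; Ray vendors this library

double = lambda x: 2 * x  # a code object with no stable import path

# cloudpickle serializes the function by value, and plain pickle loads it:
assert pickle.loads(cloudpickle.dumps(double))(21) == 42

# Plain pickle rejects the lambda outright, which is exactly the property
# the checkpoint path now relies on to keep Ray-state-laden objects out.
try:
    pickle.dumps(double)
except (pickle.PicklingError, AttributeError):
    print("plain pickle rejects the lambda, as expected")
```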
```diff
@@ -154,7 +154,6 @@ def deployment_info(version: Optional[str] = None,
                     user_config: Optional[Any] = None,
                     **config_opts) -> Tuple[DeploymentInfo, DeploymentVersion]:
     info = DeploymentInfo(
-        actor_def=None,
         version=version,
         start_time_ms=0,
         deployment_config=DeploymentConfig(
```
```diff
@@ -520,28 +520,50 @@ def test_local_store_recovery(ray_shutdown):
     def hello(_):
         return "hello"
 
-    def check():
+    # https://github.com/ray-project/ray/issues/19987
+    @serve.deployment
+    def world(_):
+        return "world"
+
+    def check(name):
         try:
-            resp = requests.get("http://localhost:8000/hello")
-            assert resp.text == "hello"
+            resp = requests.get(f"http://localhost:8000/{name}")
+            assert resp.text == name
             return True
         except Exception:
             return False
 
+    # https://github.com/ray-project/ray/issues/20159
+    # https://github.com/ray-project/ray/issues/20158
+    def clean_up_leaked_processes():
+        import psutil
+        for proc in psutil.process_iter():
+            try:
+                cmdline = " ".join(proc.cmdline())
+                if "ray::" in cmdline:
+                    print(f"Kill {proc} {cmdline}")
+                    proc.kill()
+            except Exception:
+                pass
+
     def crash():
         subprocess.call(["ray", "stop", "--force"])
+        clean_up_leaked_processes()
         ray.shutdown()
         serve.shutdown()
 
     serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}")
     hello.deploy()
-    assert check()
+    world.deploy()
+    assert check("hello")
+    assert check("world")
     crash()
 
     # Simulate a crash
 
     serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}")
-    wait_for_condition(check)
+    wait_for_condition(lambda: check("hello"))
+    # wait_for_condition(lambda: check("world"))
     crash()
```
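`wait_for_condition` comes from Ray's test utilities and simply polls a predicate until it returns `True` or a timeout elapses; the lambda wrapper is needed because `check` now takes an argument. A rough standalone equivalent (the real helper lives in `ray._private.test_utils`; parameter names below are illustrative):

```python
import time


def wait_for_condition(condition, timeout=10, retry_interval_ms=100):
    """Poll `condition` until it returns True, else raise after `timeout` s."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition():
            return
        time.sleep(retry_interval_ms / 1000)
    raise RuntimeError(f"Condition was not met within {timeout} seconds.")


wait_for_condition(lambda: True)  # returns immediately
```

The commented-out `world` check appears to track the multi-deployment recovery issue (#19987) linked next to the `world` deployment above.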
```diff
@@ -12,6 +12,7 @@ import time
 import requests
 import uuid
 import os
+from pathlib import Path
 
 from serve_test_cluster_utils import setup_local_single_node_cluster
@@ -22,7 +23,7 @@ from ray import serve
 from ray.serve.utils import logger
 
 # Deployment configs
-DEFAULT_NUM_REPLICAS = 4
+DEFAULT_NUM_REPLICAS = 2
 DEFAULT_MAX_BATCH_SIZE = 16
```
```diff
@@ -49,7 +50,10 @@ def main():
     # IS_SMOKE_TEST is set by args of releaser's e2e.py
     smoke_test = os.environ.get("IS_SMOKE_TEST", "1")
     if smoke_test == "1":
-        checkpoint_path = "file://checkpoint.db"
+        path = Path("checkpoint.db")
+        checkpoint_path = f"file://{path}"
+        if path.exists():
+            path.unlink()
     else:
         checkpoint_path = "s3://serve-nightly-tests/fault-tolerant-test-checkpoint"  # noqa: E501
```
```diff
@@ -57,20 +61,16 @@ def main():
         1, checkpoint_path=checkpoint_path, namespace=namespace)
 
     # Deploy for the first time
-    @serve.deployment(name="echo", num_replicas=DEFAULT_NUM_REPLICAS)
-    class Echo:
-        def __init__(self):
-            return True
-
-        def __call__(self, request):
-            return "hii"
-
-    Echo.deploy()
-
-    # Ensure endpoint is working
-    for _ in range(5):
-        response = request_with_retries("/echo/", timeout=3)
-        assert response.text == "hii"
+    @serve.deployment(num_replicas=DEFAULT_NUM_REPLICAS)
+    def hello():
+        return serve.get_replica_context().deployment
+
+    for name in ["hello", "world"]:
+        hello.options(name=name).deploy()
+
+        # Ensure endpoint is working
+        for _ in range(5):
+            response = request_with_retries(f"/{name}/", timeout=3)
+            assert response.text == name
 
     logger.info("Initial deployment successful with working endpoint.")
```
```diff
@@ -87,9 +87,10 @@ def main():
     setup_local_single_node_cluster(
         1, checkpoint_path=checkpoint_path, namespace=namespace)
 
-    for _ in range(5):
-        response = request_with_retries("/echo/", timeout=3)
-        assert response.text == "hii"
+    for name in ["hello", "world"]:
+        for _ in range(5):
+            response = request_with_retries(f"/{name}/", timeout=3)
+            assert response.text == name
 
     logger.info("Deployment recovery from s3 checkpoint is successful "
                 "with working endpoint.")
```
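Both halves of this release test hinge on `serve.get_replica_context().deployment`, which returns the name of the deployment a replica belongs to, so a single function body can back any number of endpoints and each answers with its own name. A sketch of the trick in isolation (assumes a local Ray cluster with Serve installed, using the same Ray 1.x-era APIs as this diff):

```python
import ray
from ray import serve

ray.init()
serve.start()


@serve.deployment
def whoami():
    # Each replica reports the deployment it belongs to, so the same body
    # can back both the /hello and /world endpoints below.
    return serve.get_replica_context().deployment


for name in ["hello", "world"]:
    whoami.options(name=name).deploy()

# GET http://127.0.0.1:8000/hello -> "hello"
# GET http://127.0.0.1:8000/world -> "world"
```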