[Serve] Don't recover from current state checkpoint (#19998)

This commit is contained in:
Simon Mo 2021-11-12 09:02:27 -08:00 committed by GitHub
parent ce8504b0b2
commit b6bd4fd5f3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 87 additions and 80 deletions

View file

@ -1,9 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Any, Dict, Optional
from uuid import UUID from uuid import UUID
import ray import ray
from ray.actor import ActorClass, ActorHandle from ray.actor import ActorHandle
from ray.serve.config import DeploymentConfig, ReplicaConfig from ray.serve.config import DeploymentConfig, ReplicaConfig
from ray.serve.autoscaling_policy import AutoscalingPolicy from ray.serve.autoscaling_policy import AutoscalingPolicy
@ -25,7 +25,8 @@ class DeploymentInfo:
deployment_config: DeploymentConfig, deployment_config: DeploymentConfig,
replica_config: ReplicaConfig, replica_config: ReplicaConfig,
start_time_ms: int, start_time_ms: int,
actor_def: Optional[ActorClass] = None, actor_name: Optional[str] = None,
serialized_deployment_def: Optional[bytes] = None,
version: Optional[str] = None, version: Optional[str] = None,
deployer_job_id: "Optional[ray._raylet.JobID]" = None, deployer_job_id: "Optional[ray._raylet.JobID]" = None,
end_time_ms: Optional[int] = None, end_time_ms: Optional[int] = None,
@ -34,13 +35,38 @@ class DeploymentInfo:
self.replica_config = replica_config self.replica_config = replica_config
# The time when .deploy() was first called for this deployment. # The time when .deploy() was first called for this deployment.
self.start_time_ms = start_time_ms self.start_time_ms = start_time_ms
self.actor_def = actor_def self.actor_name = actor_name
self.serialized_deployment_def = serialized_deployment_def
self.version = version self.version = version
self.deployer_job_id = deployer_job_id self.deployer_job_id = deployer_job_id
# The time when this deployment was deleted. # The time when this deployment was deleted.
self.end_time_ms = end_time_ms self.end_time_ms = end_time_ms
self.autoscaling_policy = autoscaling_policy self.autoscaling_policy = autoscaling_policy
# ephemeral state
self._cached_actor_def = None
def __getstate__(self) -> Dict[Any, Any]:
    """Return the instance state to pickle, minus the cached actor class.

    ``_cached_actor_def`` is ephemeral state that the ``actor_def`` property
    rebuilds on demand, so it is dropped from the serialized dictionary.

    Returns:
        A shallow copy of ``self.__dict__`` without ``_cached_actor_def``.
    """
    clean_dict = self.__dict__.copy()
    # pop() with a default tolerates an instance whose cache slot was never
    # populated, where ``del`` would raise KeyError.
    clean_dict.pop("_cached_actor_def", None)
    return clean_dict
def __setstate__(self, d: Dict[Any, Any]) -> None:
    """Restore instance state from ``d`` and reset the actor-class cache.

    Args:
        d: The state dictionary to load onto this instance.
    """
    # Copy the entries instead of adopting ``d`` as the instance dict, so the
    # caller's mapping is not mutated by the cache reset below.
    self.__dict__.update(d)
    # The cache is never serialized; start empty so it is rebuilt lazily.
    self._cached_actor_def = None
@property
def actor_def(self):
    """Return the cached ``ray.remote`` replica wrapper, building it once.

    On first access the value is created by passing ``actor_name`` and
    ``serialized_deployment_def`` to ``create_replica_wrapper`` and wrapping
    the result with ``ray.remote``; later accesses return the cached object.

    Raises:
        AssertionError: If the cache is empty and ``actor_name`` or
            ``serialized_deployment_def`` is ``None``.
    """
    if self._cached_actor_def is None:
        # Delayed import as replica depends on this file. It is only needed
        # when the wrapper actually has to be built.
        from ray.serve.replica import create_replica_wrapper

        assert self.actor_name is not None, (
            "actor_name must be set to build actor_def")
        assert self.serialized_deployment_def is not None, (
            "serialized_deployment_def must be set to build actor_def")
        self._cached_actor_def = ray.remote(
            create_replica_wrapper(self.actor_name,
                                   self.serialized_deployment_def))
    return self._cached_actor_def
@dataclass @dataclass
class ReplicaName: class ReplicaName:

View file

@ -24,7 +24,6 @@ from ray.serve.config import DeploymentConfig, HTTPOptions, ReplicaConfig
from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY
from ray.serve.endpoint_state import EndpointState from ray.serve.endpoint_state import EndpointState
from ray.serve.http_state import HTTPState from ray.serve.http_state import HTTPState
from ray.serve.replica import create_replica_wrapper
from ray.serve.storage.checkpoint_path import make_kv_store from ray.serve.storage.checkpoint_path import make_kv_store
from ray.serve.long_poll import LongPollHost from ray.serve.long_poll import LongPollHost
from ray.serve.storage.kv_store import RayInternalKVStore from ray.serve.storage.kv_store import RayInternalKVStore
@ -320,9 +319,8 @@ class ServeController:
autoscaling_policy = None autoscaling_policy = None
deployment_info = DeploymentInfo( deployment_info = DeploymentInfo(
actor_def=ray.remote( actor_name=name,
create_replica_wrapper( serialized_deployment_def=replica_config.serialized_deployment_def,
name, replica_config.serialized_deployment_def)),
version=version, version=version,
deployment_config=deployment_config, deployment_config=deployment_config,
replica_config=replica_config, replica_config=replica_config,

View file

@ -1,5 +1,6 @@
import math import math
import json import json
import pickle
import time import time
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from enum import Enum from enum import Enum
@ -7,7 +8,7 @@ import os
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import ray import ray
from ray import cloudpickle, ObjectRef from ray import ObjectRef
from ray.actor import ActorHandle from ray.actor import ActorHandle
from ray.serve.async_goal_manager import AsyncGoalManager from ray.serve.async_goal_manager import AsyncGoalManager
from ray.serve.common import (DeploymentInfo, Duration, GoalId, ReplicaTag, from ray.serve.common import (DeploymentInfo, Duration, GoalId, ReplicaTag,
@ -110,17 +111,6 @@ class ActorReplicaWrapper:
# Populated in self.stop(). # Populated in self.stop().
self._graceful_shutdown_ref: ObjectRef = None self._graceful_shutdown_ref: ObjectRef = None
def __get_state__(self) -> Dict[Any, Any]:
clean_dict = self.__dict__.copy()
del clean_dict["_ready_obj_ref"]
del clean_dict["_graceful_shutdown_ref"]
return clean_dict
def __set_state__(self, d: Dict[Any, Any]) -> None:
self.__dict__ = d
self._ready_obj_ref = None
self._graceful_shutdown_ref = None
@property @property
def replica_tag(self) -> str: def replica_tag(self) -> str:
return self._replica_tag return self._replica_tag
@ -372,12 +362,6 @@ class DeploymentReplica(VersionedReplica):
self._start_time = None self._start_time = None
self._prev_slow_startup_warning_time = None self._prev_slow_startup_warning_time = None
def __get_state__(self) -> Dict[Any, Any]:
return self.__dict__.copy()
def __set_state__(self, d: Dict[Any, Any]) -> None:
self.__dict__ = d
def get_running_replica_info(self) -> RunningReplicaInfo: def get_running_replica_info(self) -> RunningReplicaInfo:
return RunningReplicaInfo( return RunningReplicaInfo(
deployment_name=self._deployment_name, deployment_name=self._deployment_name,
@ -664,18 +648,8 @@ class DeploymentState:
""" """
return (self._target_info, self._target_replicas, self._target_version) return (self._target_info, self._target_replicas, self._target_version)
def get_current_state_checkpoint_data(self):
"""
Return deployment's current state specific to the ray cluster it's
running in. Might be lost or re-constructed upon ray cluster failure.
"""
return (self._rollback_info, self._curr_goal,
self._prev_startup_warning,
self._replica_constructor_retry_counter, self._replicas)
def get_checkpoint_data(self): def get_checkpoint_data(self):
return (self.get_target_state_checkpoint_data(), return self.get_target_state_checkpoint_data()
self.get_current_state_checkpoint_data())
def recover_target_state_from_checkpoint(self, target_state_checkpoint): def recover_target_state_from_checkpoint(self, target_state_checkpoint):
logger.info("Recovering target state for deployment " logger.info("Recovering target state for deployment "
@ -683,18 +657,6 @@ class DeploymentState:
(self._target_info, self._target_replicas, (self._target_info, self._target_replicas,
self._target_version) = target_state_checkpoint self._target_version) = target_state_checkpoint
def recover_current_state_from_checkpoint(self, current_state_checkpoint):
logger.info("Recovering current state for deployment "
f"{self._name} from checkpoint..")
(self._rollback_info, self._curr_goal, self._prev_startup_warning,
self._replica_constructor_retry_counter,
self._replicas) = current_state_checkpoint
if self._curr_goal is not None:
self._goal_manager.create_goal(self._curr_goal)
self._notify_running_replicas_changed()
def recover_current_state_from_replica_actor_names( def recover_current_state_from_replica_actor_names(
self, replica_actor_names: List[str]): self, replica_actor_names: List[str]):
assert ( assert (
@ -1288,23 +1250,19 @@ class DeploymentStateManager:
checkpoint = self._kv_store.get(CHECKPOINT_KEY) checkpoint = self._kv_store.get(CHECKPOINT_KEY)
if checkpoint is not None: if checkpoint is not None:
(deployment_state_info, (deployment_state_info,
self._deleted_deployment_metadata) = cloudpickle.loads(checkpoint) self._deleted_deployment_metadata) = pickle.loads(checkpoint)
for deployment_tag, checkpoint_data in deployment_state_info.items( for deployment_tag, checkpoint_data in deployment_state_info.items(
): ):
deployment_state = self._create_deployment_state( deployment_state = self._create_deployment_state(
deployment_tag) deployment_tag)
(target_state_checkpoint,
current_state_checkpoint) = checkpoint_data
target_state_checkpoint = checkpoint_data
deployment_state.recover_target_state_from_checkpoint( deployment_state.recover_target_state_from_checkpoint(
target_state_checkpoint) target_state_checkpoint)
if len(deployment_to_current_replicas[deployment_tag]) > 0: if len(deployment_to_current_replicas[deployment_tag]) > 0:
deployment_state.recover_current_state_from_replica_actor_names( # noqa: E501 deployment_state.recover_current_state_from_replica_actor_names( # noqa: E501
deployment_to_current_replicas[deployment_tag]) deployment_to_current_replicas[deployment_tag])
else:
deployment_state.recover_current_state_from_checkpoint(
current_state_checkpoint)
self._deployment_states[deployment_tag] = deployment_state self._deployment_states[deployment_tag] = deployment_state
def shutdown(self) -> List[GoalId]: def shutdown(self) -> List[GoalId]:
@ -1342,7 +1300,10 @@ class DeploymentStateManager:
} }
self._kv_store.put( self._kv_store.put(
CHECKPOINT_KEY, CHECKPOINT_KEY,
cloudpickle.dumps((deployment_state_info, # NOTE(simon): Make sure to use pickle so we don't save any ray
# object that relies on external state (e.g. GCS). For code objects,
# we explicitly use cloudpickle to serialize them.
pickle.dumps((deployment_state_info,
self._deleted_deployment_metadata))) self._deleted_deployment_metadata)))
def get_running_replica_infos( def get_running_replica_infos(

View file

@ -154,7 +154,6 @@ def deployment_info(version: Optional[str] = None,
user_config: Optional[Any] = None, user_config: Optional[Any] = None,
**config_opts) -> Tuple[DeploymentInfo, DeploymentVersion]: **config_opts) -> Tuple[DeploymentInfo, DeploymentVersion]:
info = DeploymentInfo( info = DeploymentInfo(
actor_def=None,
version=version, version=version,
start_time_ms=0, start_time_ms=0,
deployment_config=DeploymentConfig( deployment_config=DeploymentConfig(

View file

@ -520,28 +520,50 @@ def test_local_store_recovery(ray_shutdown):
def hello(_): def hello(_):
return "hello" return "hello"
def check(): # https://github.com/ray-project/ray/issues/19987
@serve.deployment
def world(_):
return "world"
def check(name):
try: try:
resp = requests.get("http://localhost:8000/hello") resp = requests.get(f"http://localhost:8000/{name}")
assert resp.text == "hello" assert resp.text == name
return True return True
except Exception: except Exception:
return False return False
# https://github.com/ray-project/ray/issues/20159
# https://github.com/ray-project/ray/issues/20158
def clean_up_leaked_processes():
    """Kill every process whose command line contains ``ray::``.

    This is best effort: any error raised while inspecting or killing a
    process is ignored and the scan moves on to the next one.
    """
    import psutil

    for candidate in psutil.process_iter():
        try:
            command = " ".join(candidate.cmdline())
            if "ray::" not in command:
                continue
            print(f"Kill {candidate} {command}")
            candidate.kill()
        except Exception:
            # Deliberately swallowed so one failing process does not stop
            # the cleanup of the others.
            pass
def crash(): def crash():
subprocess.call(["ray", "stop", "--force"]) subprocess.call(["ray", "stop", "--force"])
clean_up_leaked_processes()
ray.shutdown() ray.shutdown()
serve.shutdown() serve.shutdown()
serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}") serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}")
hello.deploy() hello.deploy()
assert check() world.deploy()
assert check("hello")
assert check("world")
crash() crash()
# Simulate a crash # Simulate a crash
serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}") serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}")
wait_for_condition(check) wait_for_condition(lambda: check("hello"))
# wait_for_condition(lambda: check("world"))
crash() crash()

View file

@ -12,6 +12,7 @@ import time
import requests import requests
import uuid import uuid
import os import os
from pathlib import Path
from serve_test_cluster_utils import setup_local_single_node_cluster from serve_test_cluster_utils import setup_local_single_node_cluster
@ -22,7 +23,7 @@ from ray import serve
from ray.serve.utils import logger from ray.serve.utils import logger
# Deployment configs # Deployment configs
DEFAULT_NUM_REPLICAS = 4 DEFAULT_NUM_REPLICAS = 2
DEFAULT_MAX_BATCH_SIZE = 16 DEFAULT_MAX_BATCH_SIZE = 16
@ -49,7 +50,10 @@ def main():
# IS_SMOKE_TEST is set by args of releaser's e2e.py # IS_SMOKE_TEST is set by args of releaser's e2e.py
smoke_test = os.environ.get("IS_SMOKE_TEST", "1") smoke_test = os.environ.get("IS_SMOKE_TEST", "1")
if smoke_test == "1": if smoke_test == "1":
checkpoint_path = "file://checkpoint.db" path = Path("checkpoint.db")
checkpoint_path = f"file://{path}"
if path.exists():
path.unlink()
else: else:
checkpoint_path = "s3://serve-nightly-tests/fault-tolerant-test-checkpoint" # noqa: E501 checkpoint_path = "s3://serve-nightly-tests/fault-tolerant-test-checkpoint" # noqa: E501
@ -57,20 +61,16 @@ def main():
1, checkpoint_path=checkpoint_path, namespace=namespace) 1, checkpoint_path=checkpoint_path, namespace=namespace)
# Deploy for the first time # Deploy for the first time
@serve.deployment(name="echo", num_replicas=DEFAULT_NUM_REPLICAS) @serve.deployment(num_replicas=DEFAULT_NUM_REPLICAS)
class Echo: def hello():
def __init__(self): return serve.get_replica_context().deployment
return True
def __call__(self, request): for name in ["hello", "world"]:
return "hii" hello.options(name=name).deploy()
Echo.deploy()
# Ensure endpoint is working
for _ in range(5): for _ in range(5):
response = request_with_retries("/echo/", timeout=3) response = request_with_retries(f"/{name}/", timeout=3)
assert response.text == "hii" assert response.text == name
logger.info("Initial deployment successful with working endpoint.") logger.info("Initial deployment successful with working endpoint.")
@ -87,9 +87,10 @@ def main():
setup_local_single_node_cluster( setup_local_single_node_cluster(
1, checkpoint_path=checkpoint_path, namespace=namespace) 1, checkpoint_path=checkpoint_path, namespace=namespace)
for name in ["hello", "world"]:
for _ in range(5): for _ in range(5):
response = request_with_retries("/echo/", timeout=3) response = request_with_retries(f"/{name}/", timeout=3)
assert response.text == "hii" assert response.text == name
logger.info("Deployment recovery from s3 checkpoint is successful " logger.info("Deployment recovery from s3 checkpoint is successful "
"with working endpoint.") "with working endpoint.")