[serve] Refactor SystemState into EndpointState and BackendState (#13018)

This commit is contained in:
Edward Oakes 2020-12-21 20:39:13 -06:00 committed by GitHub
parent d5604eaba3
commit b52cce6632
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -93,33 +93,16 @@ class BackendInfo(BaseModel):
arbitrary_types_allowed = True
@dataclass
class SystemState:
backends: Dict[BackendTag, BackendInfo] = field(default_factory=dict)
traffic_policies: Dict[EndpointTag, TrafficPolicy] = field(
default_factory=dict)
routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = field(
default_factory=dict)
class EndpointState:
def __init__(self, checkpoint: bytes = None):
self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
backend_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
traffic_goal_ids: Dict[EndpointTag, GoalId] = field(default_factory=dict)
route_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
if checkpoint is not None:
self.routes, self.traffic_policies = pickle.loads(checkpoint)
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
return {
tag: info.backend_config
for tag, info in self.backends.items()
}
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
return self.backends.get(backend_tag)
def add_backend(self,
backend_tag: BackendTag,
backend_info: BackendInfo,
goal_id: GoalId = 0) -> None:
self.backends[backend_tag] = backend_info
self.backend_goal_ids = goal_id
def checkpoint(self):
return pickle.dumps((self.routes, self.traffic_policies))
def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
endpoints = {}
@ -141,6 +124,32 @@ class SystemState:
return endpoints
class BackendState:
def __init__(self, checkpoint: bytes = None):
self.backends: Dict[BackendTag, BackendInfo] = dict()
if checkpoint is not None:
self.backends = pickle.loads(checkpoint)
def checkpoint(self):
return pickle.dumps(self.backends)
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
return {
tag: info.backend_config
for tag, info in self.backends.items()
}
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
return self.backends.get(backend_tag)
def add_backend(self,
backend_tag: BackendTag,
backend_info: BackendInfo,
goal_id: GoalId = 0) -> None:
self.backends[backend_tag] = backend_info
@dataclass
class ActorStateReconciler:
controller_name: str = field(init=True)
@ -192,7 +201,7 @@ class ActorStateReconciler:
for replica_dict in self.backend_replicas.values()
]))
async def _start_backend_replica(self, current_state: SystemState,
async def _start_backend_replica(self, backend_state: BackendState,
backend_tag: BackendTag,
replica_tag: ReplicaTag) -> ActorHandle:
"""Start a replica and return its actor handle.
@ -210,7 +219,7 @@ class ActorStateReconciler:
except ValueError:
logger.debug("Starting replica '{}' for backend '{}'.".format(
replica_tag, backend_tag))
backend_info = current_state.get_backend(backend_tag)
backend_info = backend_state.get_backend(backend_tag)
replica_handle = ray.remote(backend_info.worker_class).options(
name=replica_name,
@ -284,12 +293,12 @@ class ActorStateReconciler:
self.backend_replicas_to_stop[backend_tag].append(replica_tag)
async def _enqueue_pending_scale_changes_loop(self,
current_state: SystemState):
backend_state: BackendState):
for backend_tag, replicas_to_create in self.backend_replicas_to_start.\
items():
for replica_tag in replicas_to_create:
replica_handle = await self._start_backend_replica(
current_state, backend_tag, replica_tag)
backend_state, backend_tag, replica_tag)
ready_future = replica_handle.ready.remote().as_future()
self.currently_starting_replicas[ready_future] = (
backend_tag, replica_tag, replica_handle)
@ -456,19 +465,19 @@ class ActorStateReconciler:
replica_tag] = ray.get_actor(replica_name)
async def _recover_from_checkpoint(
self, current_state: SystemState, controller: "ServeController"
self, backend_state: BackendState, controller: "ServeController"
) -> Dict[BackendTag, BasicAutoscalingPolicy]:
self._recover_actor_handles()
autoscaling_policies = dict()
for backend, info in current_state.backends.items():
for backend, info in backend_state.backends.items():
metadata = info.backend_config.internal_metadata
if metadata.autoscaling_config is not None:
autoscaling_policies[backend] = BasicAutoscalingPolicy(
backend, metadata.autoscaling_config)
# Start/stop any pending backend replicas.
await self._enqueue_pending_scale_changes_loop(current_state)
await self._enqueue_pending_scale_changes_loop(backend_state)
await self.backend_control_loop()
return autoscaling_policies
@ -482,8 +491,8 @@ class FutureResult:
@dataclass
class Checkpoint:
goal_state: SystemState
current_state: SystemState
endpoint_state_checkpoint: bytes
backend_state_checkpoint: bytes
reconciler: ActorStateReconciler
# TODO(ilr) Rename reconciler to PendingState
inflight_reqs: Dict[uuid4, FutureResult]
@ -523,13 +532,6 @@ class ServeController:
detached: bool = False):
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=controller_name)
# Current State
self.current_state = SystemState()
# Goal State
# TODO(ilr) This is currently *unused* until the refactor of the serve
# controller.
self.goal_state = SystemState()
# ActorStateReconciler
self.actor_reconciler = ActorStateReconciler(controller_name, detached)
# backend -> AutoscalingPolicy
@ -556,10 +558,17 @@ class ServeController:
self.inflight_results: Dict[UUID, asyncio.Event] = dict()
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
checkpoint = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint is None:
checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint_bytes is None:
logger.debug("No checkpoint found")
self.backend_state = BackendState()
self.endpoint_state = EndpointState()
else:
checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.backend_state = BackendState(
checkpoint=checkpoint.backend_state_checkpoint)
self.endpoint_state = EndpointState(
checkpoint=checkpoint.endpoint_state_checkpoint)
await self._recover_from_checkpoint(checkpoint)
# NOTE(simon): Currently we do all-to-all broadcast. This means
@ -618,17 +627,17 @@ class ServeController:
def notify_traffic_policies_changed(self):
self.long_poll_host.notify_changed(
LongPollKey.TRAFFIC_POLICIES,
self.current_state.traffic_policies,
self.endpoint_state.traffic_policies,
)
def notify_backend_configs_changed(self):
self.long_poll_host.notify_changed(
LongPollKey.BACKEND_CONFIGS,
self.current_state.get_backend_configs())
self.backend_state.get_backend_configs())
def notify_route_table_changed(self):
self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
self.current_state.routes)
self.endpoint_state.routes)
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
"""Proxy long pull client's listen request.
@ -652,19 +661,19 @@ class ServeController:
start = time.time()
checkpoint = pickle.dumps(
Checkpoint(self.goal_state, self.current_state,
self.actor_reconciler,
Checkpoint(self.endpoint_state.checkpoint(),
self.backend_state.checkpoint(), self.actor_reconciler,
self._serializable_inflight_results))
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))
logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start))
if random.random(
) < _CRASH_AFTER_CHECKPOINT_PROBABILITY and self.detached:
logger.warning("Intentionally crashing after checkpoint")
os._exit(0)
async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None:
async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Recover the instance state from the provided checkpoint.
This should be called in the constructor to ensure that the internal
@ -679,12 +688,9 @@ class ServeController:
start = time.time()
logger.info("Recovering from checkpoint")
restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.current_state = restored_checkpoint.current_state
self.actor_reconciler = checkpoint.reconciler
self.actor_reconciler = restored_checkpoint.reconciler
self._serializable_inflight_results = restored_checkpoint.inflight_reqs
self._serializable_inflight_results = checkpoint.inflight_reqs
for uuid, fut_result in self._serializable_inflight_results.items():
self._create_event_with_result(fut_result.requested_goal, uuid)
@ -704,7 +710,7 @@ class ServeController:
async def finish_recover_from_checkpoint():
assert self.write_lock.locked()
self.autoscaling_policies = await self.actor_reconciler.\
_recover_from_checkpoint(self.current_state, self)
_recover_from_checkpoint(self.backend_state, self)
self.write_lock.release()
logger.info(
"Recovered from checkpoint in {:.3f}s".format(time.time() -
@ -714,7 +720,7 @@ class ServeController:
asyncio.get_event_loop().create_task(finish_recover_from_checkpoint())
async def do_autoscale(self) -> None:
for backend, info in self.current_state.backends.items():
for backend, info in self.backend_state.backends.items():
if backend not in self.autoscaling_policies:
continue
@ -726,9 +732,6 @@ class ServeController:
async def reconcile_current_and_goal_backends(self):
pass
# backends_to_delete = set(
# self.current_state.backends.keys()).difference(
# self.goal_state.backends.keys())
async def run_control_loop(self) -> None:
while True:
@ -750,15 +753,15 @@ class ServeController:
def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
"""Returns a dictionary of backend tag to backend config."""
return self.current_state.get_backend_configs()
return self.backend_state.get_backend_configs()
def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
"""Returns a dictionary of backend tag to backend config."""
return self.current_state.get_endpoints()
return self.endpoint_state.get_endpoints()
async def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID:
if endpoint_name not in self.current_state.get_endpoints():
if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
" that is not registered.".format(endpoint_name))
@ -766,13 +769,13 @@ class ServeController:
dict), "Traffic policy must be a dictionary."
for backend in traffic_dict:
if self.current_state.get_backend(backend) is None:
if self.backend_state.get_backend(backend) is None:
raise ValueError(
"Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend))
traffic_policy = TrafficPolicy(traffic_dict)
self.current_state.traffic_policies[endpoint_name] = traffic_policy
self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy
return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy
@ -795,20 +798,21 @@ class ServeController:
proportion: float) -> UUID:
"""Shadow traffic from the endpoint to the backend."""
async with self.write_lock:
if endpoint_name not in self.current_state.get_endpoints():
if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to shadow traffic from an "
"endpoint '{}' that is not registered."
.format(endpoint_name))
if self.current_state.get_backend(backend_tag) is None:
if self.backend_state.get_backend(backend_tag) is None:
raise ValueError(
"Attempted to shadow traffic to a backend '{}' that "
"is not registered.".format(backend_tag))
self.current_state.traffic_policies[endpoint_name].set_shadow(
self.endpoint_state.traffic_policies[endpoint_name].set_shadow(
backend_tag, proportion)
traffic_policy = self.current_state.traffic_policies[endpoint_name]
traffic_policy = self.endpoint_state.traffic_policies[
endpoint_name]
return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy
@ -839,10 +843,10 @@ class ServeController:
# TODO(edoakes): move this to client side.
err_prefix = "Cannot create endpoint."
if route in self.current_state.routes:
if route in self.endpoint_state.routes:
# Ensures this method is idempotent
if self.current_state.routes[route] == (endpoint, methods):
if self.endpoint_state.routes[route] == (endpoint, methods):
return
else:
@ -850,7 +854,7 @@ class ServeController:
"{} Route '{}' is already registered.".format(
err_prefix, route))
if endpoint in self.current_state.get_endpoints():
if endpoint in self.endpoint_state.get_endpoints():
raise ValueError(
"{} Endpoint '{}' is already registered.".format(
err_prefix, endpoint))
@ -859,7 +863,7 @@ class ServeController:
"Registering route '{}' to endpoint '{}' with methods '{}'.".
format(route, endpoint, methods))
self.current_state.routes[route] = (endpoint, methods)
self.endpoint_state.routes[route] = (endpoint, methods)
# NOTE(edoakes): checkpoint is written in self._set_traffic.
return_uuid = await self._set_traffic(endpoint, traffic_dict)
@ -876,7 +880,7 @@ class ServeController:
# This method must be idempotent. We should validate that the
# specified endpoint exists on the client.
for route, (route_endpoint,
_) in self.current_state.routes.items():
_) in self.endpoint_state.routes.items():
if route_endpoint == endpoint:
route_to_delete = route
break
@ -885,11 +889,11 @@ class ServeController:
return
# Remove the routing entry.
del self.current_state.routes[route_to_delete]
del self.endpoint_state.routes[route_to_delete]
# Remove the traffic policy entry if it exists.
if endpoint in self.current_state.traffic_policies:
del self.current_state.traffic_policies[endpoint]
if endpoint in self.endpoint_state.traffic_policies:
del self.endpoint_state.traffic_policies[endpoint]
return_uuid = self._create_event_with_result({
route_to_delete: None,
@ -908,7 +912,7 @@ class ServeController:
"""Register a new backend under the specified tag."""
async with self.write_lock:
# Ensures this method is idempotent.
backend_info = self.current_state.get_backend(backend_tag)
backend_info = self.backend_state.get_backend(backend_tag)
if backend_info is not None:
if (backend_info.backend_config == backend_config
and backend_info.replica_config == replica_config):
@ -923,7 +927,7 @@ class ServeController:
worker_class=backend_replica,
backend_config=backend_config,
replica_config=replica_config)
self.current_state.add_backend(backend_tag, backend_info)
self.backend_state.add_backend(backend_tag, backend_info)
metadata = backend_config.internal_metadata
if metadata.autoscaling_config is not None:
self.autoscaling_policies[
@ -933,10 +937,10 @@ class ServeController:
try:
# This call should be to run control loop
self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag,
self.backend_state.backends, backend_tag,
backend_config.num_replicas)
except RayServeException as e:
del self.current_state.backends[backend_tag]
del self.backend_state.backends[backend_tag]
raise e
return_uuid = self._create_event_with_result({
@ -947,7 +951,7 @@ class ServeController:
# crash while making the change.
self._checkpoint()
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.current_state)
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
@ -961,11 +965,11 @@ class ServeController:
async with self.write_lock:
# This method must be idempotent. We should validate that the
# specified backend exists on the client.
if self.current_state.get_backend(backend_tag) is None:
if self.backend_state.get_backend(backend_tag) is None:
return
# Check that the specified backend isn't used by any endpoints.
for endpoint, traffic_policy in self.current_state.\
for endpoint, traffic_policy in self.endpoint_state.\
traffic_policies.items():
if (backend_tag in traffic_policy.traffic_dict
or backend_tag in traffic_policy.shadow_dict):
@ -975,17 +979,15 @@ class ServeController:
"again.".format(backend_tag, endpoint))
# Scale its replicas down to 0. This will also remove the backend
# from self.current_state.backends and
# from self.backend_state.backends and
# self.actor_reconciler.backend_replicas.
self.goal_state.backends[backend_tag] = None
# This should be a call to the control loop
self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag, 0)
self.backend_state.backends, backend_tag, 0)
# Remove the backend's metadata.
del self.current_state.backends[backend_tag]
del self.backend_state.backends[backend_tag]
if backend_tag in self.autoscaling_policies:
del self.autoscaling_policies[backend_tag]
@ -998,7 +1000,7 @@ class ServeController:
# after pushing the update.
self._checkpoint()
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.current_state)
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
@ -1008,24 +1010,24 @@ class ServeController:
config_options: BackendConfig) -> UUID:
"""Set the config for the specified backend."""
async with self.write_lock:
assert (self.current_state.get_backend(backend_tag)
assert (self.backend_state.get_backend(backend_tag)
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(config_options, BackendConfig)
stored_backend_config = self.current_state.get_backend(
stored_backend_config = self.backend_state.get_backend(
backend_tag).backend_config
backend_config = stored_backend_config.copy(
update=config_options.dict(exclude_unset=True))
backend_config._validate_complete()
self.current_state.get_backend(
self.backend_state.get_backend(
backend_tag).backend_config = backend_config
backend_info = self.current_state.get_backend(backend_tag)
backend_info = self.backend_state.get_backend(backend_tag)
# Scale the replicas with the new configuration.
# This should be to run the control loop
self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag,
self.backend_state.backends, backend_tag,
backend_config.num_replicas)
return_uuid = self._create_event_with_result({
@ -1040,7 +1042,7 @@ class ServeController:
# (particularly for setting max_batch_size).
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.current_state)
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
@ -1049,9 +1051,9 @@ class ServeController:
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
"""Get the current config for the specified backend."""
assert (self.current_state.get_backend(backend_tag)
assert (self.backend_state.get_backend(backend_tag)
), "Backend {} is not registered.".format(backend_tag)
return self.current_state.get_backend(backend_tag).backend_config
return self.backend_state.get_backend(backend_tag).backend_config
def get_http_config(self):
"""Return the HTTP proxy configuration."""