[serve] Refactor SystemState into EndpointState and BackendState (#13018)

This commit is contained in:
Edward Oakes 2020-12-21 20:39:13 -06:00 committed by GitHub
parent d5604eaba3
commit b52cce6632
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -93,33 +93,16 @@ class BackendInfo(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@dataclass class EndpointState:
class SystemState: def __init__(self, checkpoint: bytes = None):
backends: Dict[BackendTag, BackendInfo] = field(default_factory=dict) self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
traffic_policies: Dict[EndpointTag, TrafficPolicy] = field( self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
default_factory=dict)
routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = field(
default_factory=dict)
backend_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict) if checkpoint is not None:
traffic_goal_ids: Dict[EndpointTag, GoalId] = field(default_factory=dict) self.routes, self.traffic_policies = pickle.loads(checkpoint)
route_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: def checkpoint(self):
return { return pickle.dumps((self.routes, self.traffic_policies))
tag: info.backend_config
for tag, info in self.backends.items()
}
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
return self.backends.get(backend_tag)
def add_backend(self,
backend_tag: BackendTag,
backend_info: BackendInfo,
goal_id: GoalId = 0) -> None:
self.backends[backend_tag] = backend_info
self.backend_goal_ids = goal_id
def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
endpoints = {} endpoints = {}
@ -141,6 +124,32 @@ class SystemState:
return endpoints return endpoints
class BackendState:
def __init__(self, checkpoint: bytes = None):
self.backends: Dict[BackendTag, BackendInfo] = dict()
if checkpoint is not None:
self.backends = pickle.loads(checkpoint)
def checkpoint(self):
return pickle.dumps(self.backends)
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
return {
tag: info.backend_config
for tag, info in self.backends.items()
}
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
return self.backends.get(backend_tag)
def add_backend(self,
backend_tag: BackendTag,
backend_info: BackendInfo,
goal_id: GoalId = 0) -> None:
self.backends[backend_tag] = backend_info
@dataclass @dataclass
class ActorStateReconciler: class ActorStateReconciler:
controller_name: str = field(init=True) controller_name: str = field(init=True)
@ -192,7 +201,7 @@ class ActorStateReconciler:
for replica_dict in self.backend_replicas.values() for replica_dict in self.backend_replicas.values()
])) ]))
async def _start_backend_replica(self, current_state: SystemState, async def _start_backend_replica(self, backend_state: BackendState,
backend_tag: BackendTag, backend_tag: BackendTag,
replica_tag: ReplicaTag) -> ActorHandle: replica_tag: ReplicaTag) -> ActorHandle:
"""Start a replica and return its actor handle. """Start a replica and return its actor handle.
@ -210,7 +219,7 @@ class ActorStateReconciler:
except ValueError: except ValueError:
logger.debug("Starting replica '{}' for backend '{}'.".format( logger.debug("Starting replica '{}' for backend '{}'.".format(
replica_tag, backend_tag)) replica_tag, backend_tag))
backend_info = current_state.get_backend(backend_tag) backend_info = backend_state.get_backend(backend_tag)
replica_handle = ray.remote(backend_info.worker_class).options( replica_handle = ray.remote(backend_info.worker_class).options(
name=replica_name, name=replica_name,
@ -284,12 +293,12 @@ class ActorStateReconciler:
self.backend_replicas_to_stop[backend_tag].append(replica_tag) self.backend_replicas_to_stop[backend_tag].append(replica_tag)
async def _enqueue_pending_scale_changes_loop(self, async def _enqueue_pending_scale_changes_loop(self,
current_state: SystemState): backend_state: BackendState):
for backend_tag, replicas_to_create in self.backend_replicas_to_start.\ for backend_tag, replicas_to_create in self.backend_replicas_to_start.\
items(): items():
for replica_tag in replicas_to_create: for replica_tag in replicas_to_create:
replica_handle = await self._start_backend_replica( replica_handle = await self._start_backend_replica(
current_state, backend_tag, replica_tag) backend_state, backend_tag, replica_tag)
ready_future = replica_handle.ready.remote().as_future() ready_future = replica_handle.ready.remote().as_future()
self.currently_starting_replicas[ready_future] = ( self.currently_starting_replicas[ready_future] = (
backend_tag, replica_tag, replica_handle) backend_tag, replica_tag, replica_handle)
@ -456,19 +465,19 @@ class ActorStateReconciler:
replica_tag] = ray.get_actor(replica_name) replica_tag] = ray.get_actor(replica_name)
async def _recover_from_checkpoint( async def _recover_from_checkpoint(
self, current_state: SystemState, controller: "ServeController" self, backend_state: BackendState, controller: "ServeController"
) -> Dict[BackendTag, BasicAutoscalingPolicy]: ) -> Dict[BackendTag, BasicAutoscalingPolicy]:
self._recover_actor_handles() self._recover_actor_handles()
autoscaling_policies = dict() autoscaling_policies = dict()
for backend, info in current_state.backends.items(): for backend, info in backend_state.backends.items():
metadata = info.backend_config.internal_metadata metadata = info.backend_config.internal_metadata
if metadata.autoscaling_config is not None: if metadata.autoscaling_config is not None:
autoscaling_policies[backend] = BasicAutoscalingPolicy( autoscaling_policies[backend] = BasicAutoscalingPolicy(
backend, metadata.autoscaling_config) backend, metadata.autoscaling_config)
# Start/stop any pending backend replicas. # Start/stop any pending backend replicas.
await self._enqueue_pending_scale_changes_loop(current_state) await self._enqueue_pending_scale_changes_loop(backend_state)
await self.backend_control_loop() await self.backend_control_loop()
return autoscaling_policies return autoscaling_policies
@ -482,8 +491,8 @@ class FutureResult:
@dataclass @dataclass
class Checkpoint: class Checkpoint:
goal_state: SystemState endpoint_state_checkpoint: bytes
current_state: SystemState backend_state_checkpoint: bytes
reconciler: ActorStateReconciler reconciler: ActorStateReconciler
# TODO(ilr) Rename reconciler to PendingState # TODO(ilr) Rename reconciler to PendingState
inflight_reqs: Dict[uuid4, FutureResult] inflight_reqs: Dict[uuid4, FutureResult]
@ -523,13 +532,6 @@ class ServeController:
detached: bool = False): detached: bool = False):
# Used to read/write checkpoints. # Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=controller_name) self.kv_store = RayInternalKVStore(namespace=controller_name)
# Current State
self.current_state = SystemState()
# Goal State
# TODO(ilr) This is currently *unused* until the refactor of the serve
# controller.
self.goal_state = SystemState()
# ActorStateReconciler
self.actor_reconciler = ActorStateReconciler(controller_name, detached) self.actor_reconciler = ActorStateReconciler(controller_name, detached)
# backend -> AutoscalingPolicy # backend -> AutoscalingPolicy
@ -556,10 +558,17 @@ class ServeController:
self.inflight_results: Dict[UUID, asyncio.Event] = dict() self.inflight_results: Dict[UUID, asyncio.Event] = dict()
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict() self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
checkpoint = self.kv_store.get(CHECKPOINT_KEY) checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint is None: if checkpoint_bytes is None:
logger.debug("No checkpoint found") logger.debug("No checkpoint found")
self.backend_state = BackendState()
self.endpoint_state = EndpointState()
else: else:
checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.backend_state = BackendState(
checkpoint=checkpoint.backend_state_checkpoint)
self.endpoint_state = EndpointState(
checkpoint=checkpoint.endpoint_state_checkpoint)
await self._recover_from_checkpoint(checkpoint) await self._recover_from_checkpoint(checkpoint)
# NOTE(simon): Currently we do all-to-all broadcast. This means # NOTE(simon): Currently we do all-to-all broadcast. This means
@ -618,17 +627,17 @@ class ServeController:
def notify_traffic_policies_changed(self): def notify_traffic_policies_changed(self):
self.long_poll_host.notify_changed( self.long_poll_host.notify_changed(
LongPollKey.TRAFFIC_POLICIES, LongPollKey.TRAFFIC_POLICIES,
self.current_state.traffic_policies, self.endpoint_state.traffic_policies,
) )
def notify_backend_configs_changed(self): def notify_backend_configs_changed(self):
self.long_poll_host.notify_changed( self.long_poll_host.notify_changed(
LongPollKey.BACKEND_CONFIGS, LongPollKey.BACKEND_CONFIGS,
self.current_state.get_backend_configs()) self.backend_state.get_backend_configs())
def notify_route_table_changed(self): def notify_route_table_changed(self):
self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE, self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
self.current_state.routes) self.endpoint_state.routes)
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]): async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
"""Proxy long pull client's listen request. """Proxy long pull client's listen request.
@ -652,19 +661,19 @@ class ServeController:
start = time.time() start = time.time()
checkpoint = pickle.dumps( checkpoint = pickle.dumps(
Checkpoint(self.goal_state, self.current_state, Checkpoint(self.endpoint_state.checkpoint(),
self.actor_reconciler, self.backend_state.checkpoint(), self.actor_reconciler,
self._serializable_inflight_results)) self._serializable_inflight_results))
self.kv_store.put(CHECKPOINT_KEY, checkpoint) self.kv_store.put(CHECKPOINT_KEY, checkpoint)
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start)) logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start))
if random.random( if random.random(
) < _CRASH_AFTER_CHECKPOINT_PROBABILITY and self.detached: ) < _CRASH_AFTER_CHECKPOINT_PROBABILITY and self.detached:
logger.warning("Intentionally crashing after checkpoint") logger.warning("Intentionally crashing after checkpoint")
os._exit(0) os._exit(0)
async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None: async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Recover the instance state from the provided checkpoint. """Recover the instance state from the provided checkpoint.
This should be called in the constructor to ensure that the internal This should be called in the constructor to ensure that the internal
@ -679,12 +688,9 @@ class ServeController:
start = time.time() start = time.time()
logger.info("Recovering from checkpoint") logger.info("Recovering from checkpoint")
restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) self.actor_reconciler = checkpoint.reconciler
self.current_state = restored_checkpoint.current_state
self.actor_reconciler = restored_checkpoint.reconciler self._serializable_inflight_results = checkpoint.inflight_reqs
self._serializable_inflight_results = restored_checkpoint.inflight_reqs
for uuid, fut_result in self._serializable_inflight_results.items(): for uuid, fut_result in self._serializable_inflight_results.items():
self._create_event_with_result(fut_result.requested_goal, uuid) self._create_event_with_result(fut_result.requested_goal, uuid)
@ -704,7 +710,7 @@ class ServeController:
async def finish_recover_from_checkpoint(): async def finish_recover_from_checkpoint():
assert self.write_lock.locked() assert self.write_lock.locked()
self.autoscaling_policies = await self.actor_reconciler.\ self.autoscaling_policies = await self.actor_reconciler.\
_recover_from_checkpoint(self.current_state, self) _recover_from_checkpoint(self.backend_state, self)
self.write_lock.release() self.write_lock.release()
logger.info( logger.info(
"Recovered from checkpoint in {:.3f}s".format(time.time() - "Recovered from checkpoint in {:.3f}s".format(time.time() -
@ -714,7 +720,7 @@ class ServeController:
asyncio.get_event_loop().create_task(finish_recover_from_checkpoint()) asyncio.get_event_loop().create_task(finish_recover_from_checkpoint())
async def do_autoscale(self) -> None: async def do_autoscale(self) -> None:
for backend, info in self.current_state.backends.items(): for backend, info in self.backend_state.backends.items():
if backend not in self.autoscaling_policies: if backend not in self.autoscaling_policies:
continue continue
@ -726,9 +732,6 @@ class ServeController:
async def reconcile_current_and_goal_backends(self): async def reconcile_current_and_goal_backends(self):
pass pass
# backends_to_delete = set(
# self.current_state.backends.keys()).difference(
# self.goal_state.backends.keys())
async def run_control_loop(self) -> None: async def run_control_loop(self) -> None:
while True: while True:
@ -750,15 +753,15 @@ class ServeController:
def get_all_backends(self) -> Dict[BackendTag, BackendConfig]: def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
"""Returns a dictionary of backend tag to backend config.""" """Returns a dictionary of backend tag to backend config."""
return self.current_state.get_backend_configs() return self.backend_state.get_backend_configs()
def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]: def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
"""Returns a dictionary of backend tag to backend config.""" """Returns a dictionary of backend tag to backend config."""
return self.current_state.get_endpoints() return self.endpoint_state.get_endpoints()
async def _set_traffic(self, endpoint_name: str, async def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID: traffic_dict: Dict[str, float]) -> UUID:
if endpoint_name not in self.current_state.get_endpoints(): if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to assign traffic for an endpoint '{}'" raise ValueError("Attempted to assign traffic for an endpoint '{}'"
" that is not registered.".format(endpoint_name)) " that is not registered.".format(endpoint_name))
@ -766,13 +769,13 @@ class ServeController:
dict), "Traffic policy must be a dictionary." dict), "Traffic policy must be a dictionary."
for backend in traffic_dict: for backend in traffic_dict:
if self.current_state.get_backend(backend) is None: if self.backend_state.get_backend(backend) is None:
raise ValueError( raise ValueError(
"Attempted to assign traffic to a backend '{}' that " "Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend)) "is not registered.".format(backend))
traffic_policy = TrafficPolicy(traffic_dict) traffic_policy = TrafficPolicy(traffic_dict)
self.current_state.traffic_policies[endpoint_name] = traffic_policy self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy
return_uuid = self._create_event_with_result({ return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy endpoint_name: traffic_policy
@ -795,20 +798,21 @@ class ServeController:
proportion: float) -> UUID: proportion: float) -> UUID:
"""Shadow traffic from the endpoint to the backend.""" """Shadow traffic from the endpoint to the backend."""
async with self.write_lock: async with self.write_lock:
if endpoint_name not in self.current_state.get_endpoints(): if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to shadow traffic from an " raise ValueError("Attempted to shadow traffic from an "
"endpoint '{}' that is not registered." "endpoint '{}' that is not registered."
.format(endpoint_name)) .format(endpoint_name))
if self.current_state.get_backend(backend_tag) is None: if self.backend_state.get_backend(backend_tag) is None:
raise ValueError( raise ValueError(
"Attempted to shadow traffic to a backend '{}' that " "Attempted to shadow traffic to a backend '{}' that "
"is not registered.".format(backend_tag)) "is not registered.".format(backend_tag))
self.current_state.traffic_policies[endpoint_name].set_shadow( self.endpoint_state.traffic_policies[endpoint_name].set_shadow(
backend_tag, proportion) backend_tag, proportion)
traffic_policy = self.current_state.traffic_policies[endpoint_name] traffic_policy = self.endpoint_state.traffic_policies[
endpoint_name]
return_uuid = self._create_event_with_result({ return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy endpoint_name: traffic_policy
@ -839,10 +843,10 @@ class ServeController:
# TODO(edoakes): move this to client side. # TODO(edoakes): move this to client side.
err_prefix = "Cannot create endpoint." err_prefix = "Cannot create endpoint."
if route in self.current_state.routes: if route in self.endpoint_state.routes:
# Ensures this method is idempotent # Ensures this method is idempotent
if self.current_state.routes[route] == (endpoint, methods): if self.endpoint_state.routes[route] == (endpoint, methods):
return return
else: else:
@ -850,7 +854,7 @@ class ServeController:
"{} Route '{}' is already registered.".format( "{} Route '{}' is already registered.".format(
err_prefix, route)) err_prefix, route))
if endpoint in self.current_state.get_endpoints(): if endpoint in self.endpoint_state.get_endpoints():
raise ValueError( raise ValueError(
"{} Endpoint '{}' is already registered.".format( "{} Endpoint '{}' is already registered.".format(
err_prefix, endpoint)) err_prefix, endpoint))
@ -859,7 +863,7 @@ class ServeController:
"Registering route '{}' to endpoint '{}' with methods '{}'.". "Registering route '{}' to endpoint '{}' with methods '{}'.".
format(route, endpoint, methods)) format(route, endpoint, methods))
self.current_state.routes[route] = (endpoint, methods) self.endpoint_state.routes[route] = (endpoint, methods)
# NOTE(edoakes): checkpoint is written in self._set_traffic. # NOTE(edoakes): checkpoint is written in self._set_traffic.
return_uuid = await self._set_traffic(endpoint, traffic_dict) return_uuid = await self._set_traffic(endpoint, traffic_dict)
@ -876,7 +880,7 @@ class ServeController:
# This method must be idempotent. We should validate that the # This method must be idempotent. We should validate that the
# specified endpoint exists on the client. # specified endpoint exists on the client.
for route, (route_endpoint, for route, (route_endpoint,
_) in self.current_state.routes.items(): _) in self.endpoint_state.routes.items():
if route_endpoint == endpoint: if route_endpoint == endpoint:
route_to_delete = route route_to_delete = route
break break
@ -885,11 +889,11 @@ class ServeController:
return return
# Remove the routing entry. # Remove the routing entry.
del self.current_state.routes[route_to_delete] del self.endpoint_state.routes[route_to_delete]
# Remove the traffic policy entry if it exists. # Remove the traffic policy entry if it exists.
if endpoint in self.current_state.traffic_policies: if endpoint in self.endpoint_state.traffic_policies:
del self.current_state.traffic_policies[endpoint] del self.endpoint_state.traffic_policies[endpoint]
return_uuid = self._create_event_with_result({ return_uuid = self._create_event_with_result({
route_to_delete: None, route_to_delete: None,
@ -908,7 +912,7 @@ class ServeController:
"""Register a new backend under the specified tag.""" """Register a new backend under the specified tag."""
async with self.write_lock: async with self.write_lock:
# Ensures this method is idempotent. # Ensures this method is idempotent.
backend_info = self.current_state.get_backend(backend_tag) backend_info = self.backend_state.get_backend(backend_tag)
if backend_info is not None: if backend_info is not None:
if (backend_info.backend_config == backend_config if (backend_info.backend_config == backend_config
and backend_info.replica_config == replica_config): and backend_info.replica_config == replica_config):
@ -923,7 +927,7 @@ class ServeController:
worker_class=backend_replica, worker_class=backend_replica,
backend_config=backend_config, backend_config=backend_config,
replica_config=replica_config) replica_config=replica_config)
self.current_state.add_backend(backend_tag, backend_info) self.backend_state.add_backend(backend_tag, backend_info)
metadata = backend_config.internal_metadata metadata = backend_config.internal_metadata
if metadata.autoscaling_config is not None: if metadata.autoscaling_config is not None:
self.autoscaling_policies[ self.autoscaling_policies[
@ -933,10 +937,10 @@ class ServeController:
try: try:
# This call should be to run control loop # This call should be to run control loop
self.actor_reconciler._scale_backend_replicas( self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag, self.backend_state.backends, backend_tag,
backend_config.num_replicas) backend_config.num_replicas)
except RayServeException as e: except RayServeException as e:
del self.current_state.backends[backend_tag] del self.backend_state.backends[backend_tag]
raise e raise e
return_uuid = self._create_event_with_result({ return_uuid = self._create_event_with_result({
@ -947,7 +951,7 @@ class ServeController:
# crash while making the change. # crash while making the change.
self._checkpoint() self._checkpoint()
await self.actor_reconciler._enqueue_pending_scale_changes_loop( await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.current_state) self.backend_state)
await self.actor_reconciler.backend_control_loop() await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed() self.notify_replica_handles_changed()
@ -961,11 +965,11 @@ class ServeController:
async with self.write_lock: async with self.write_lock:
# This method must be idempotent. We should validate that the # This method must be idempotent. We should validate that the
# specified backend exists on the client. # specified backend exists on the client.
if self.current_state.get_backend(backend_tag) is None: if self.backend_state.get_backend(backend_tag) is None:
return return
# Check that the specified backend isn't used by any endpoints. # Check that the specified backend isn't used by any endpoints.
for endpoint, traffic_policy in self.current_state.\ for endpoint, traffic_policy in self.endpoint_state.\
traffic_policies.items(): traffic_policies.items():
if (backend_tag in traffic_policy.traffic_dict if (backend_tag in traffic_policy.traffic_dict
or backend_tag in traffic_policy.shadow_dict): or backend_tag in traffic_policy.shadow_dict):
@ -975,17 +979,15 @@ class ServeController:
"again.".format(backend_tag, endpoint)) "again.".format(backend_tag, endpoint))
# Scale its replicas down to 0. This will also remove the backend # Scale its replicas down to 0. This will also remove the backend
# from self.current_state.backends and # from self.backend_state.backends and
# self.actor_reconciler.backend_replicas. # self.actor_reconciler.backend_replicas.
self.goal_state.backends[backend_tag] = None
# This should be a call to the control loop # This should be a call to the control loop
self.actor_reconciler._scale_backend_replicas( self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag, 0) self.backend_state.backends, backend_tag, 0)
# Remove the backend's metadata. # Remove the backend's metadata.
del self.current_state.backends[backend_tag] del self.backend_state.backends[backend_tag]
if backend_tag in self.autoscaling_policies: if backend_tag in self.autoscaling_policies:
del self.autoscaling_policies[backend_tag] del self.autoscaling_policies[backend_tag]
@ -998,7 +1000,7 @@ class ServeController:
# after pushing the update. # after pushing the update.
self._checkpoint() self._checkpoint()
await self.actor_reconciler._enqueue_pending_scale_changes_loop( await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.current_state) self.backend_state)
await self.actor_reconciler.backend_control_loop() await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed() self.notify_replica_handles_changed()
@ -1008,24 +1010,24 @@ class ServeController:
config_options: BackendConfig) -> UUID: config_options: BackendConfig) -> UUID:
"""Set the config for the specified backend.""" """Set the config for the specified backend."""
async with self.write_lock: async with self.write_lock:
assert (self.current_state.get_backend(backend_tag) assert (self.backend_state.get_backend(backend_tag)
), "Backend {} is not registered.".format(backend_tag) ), "Backend {} is not registered.".format(backend_tag)
assert isinstance(config_options, BackendConfig) assert isinstance(config_options, BackendConfig)
stored_backend_config = self.current_state.get_backend( stored_backend_config = self.backend_state.get_backend(
backend_tag).backend_config backend_tag).backend_config
backend_config = stored_backend_config.copy( backend_config = stored_backend_config.copy(
update=config_options.dict(exclude_unset=True)) update=config_options.dict(exclude_unset=True))
backend_config._validate_complete() backend_config._validate_complete()
self.current_state.get_backend( self.backend_state.get_backend(
backend_tag).backend_config = backend_config backend_tag).backend_config = backend_config
backend_info = self.current_state.get_backend(backend_tag) backend_info = self.backend_state.get_backend(backend_tag)
# Scale the replicas with the new configuration. # Scale the replicas with the new configuration.
# This should be to run the control loop # This should be to run the control loop
self.actor_reconciler._scale_backend_replicas( self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag, self.backend_state.backends, backend_tag,
backend_config.num_replicas) backend_config.num_replicas)
return_uuid = self._create_event_with_result({ return_uuid = self._create_event_with_result({
@ -1040,7 +1042,7 @@ class ServeController:
# (particularly for setting max_batch_size). # (particularly for setting max_batch_size).
await self.actor_reconciler._enqueue_pending_scale_changes_loop( await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.current_state) self.backend_state)
await self.actor_reconciler.backend_control_loop() await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed() self.notify_replica_handles_changed()
@ -1049,9 +1051,9 @@ class ServeController:
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig: def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
"""Get the current config for the specified backend.""" """Get the current config for the specified backend."""
assert (self.current_state.get_backend(backend_tag) assert (self.backend_state.get_backend(backend_tag)
), "Backend {} is not registered.".format(backend_tag) ), "Backend {} is not registered.".format(backend_tag)
return self.current_state.get_backend(backend_tag).backend_config return self.backend_state.get_backend(backend_tag).backend_config
def get_http_config(self): def get_http_config(self):
"""Return the HTTP proxy configuration.""" """Return the HTTP proxy configuration."""