[autoscaler][hotfix] Update node list after terminating unhealthy nodes (#17992)

* Update nodes; update test.

* consistency

* lint
This commit is contained in:
Dmitri Gekhtman 2021-08-22 18:22:10 -04:00 committed by GitHub
parent 5ca28b1cc8
commit 13d5d0f9ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 5 deletions

View file

@ -396,7 +396,10 @@ class StandardAutoscaler:
if self.disable_node_updaters: if self.disable_node_updaters:
# If updaters are unavailable, terminate unhealthy nodes. # If updaters are unavailable, terminate unhealthy nodes.
self.terminate_unhealthy_nodes(nodes, now) nodes_to_terminate = self.get_unhealthy_nodes(nodes, now)
if nodes_to_terminate:
self._terminate_nodes_and_cleanup(nodes_to_terminate)
nodes = self.workers()
else: else:
# Attempt to recover unhealthy nodes # Attempt to recover unhealthy nodes
for node_id in nodes: for node_id in nodes:
@ -716,8 +719,10 @@ class StandardAutoscaler:
return True return True
return False return False
def terminate_unhealthy_nodes(self, nodes: List[NodeID], now: float): def get_unhealthy_nodes(self, nodes: List[NodeID],
"""Terminate nodes for which we haven't received a heartbeat on time. now: float) -> List[NodeID]:
"""Determine nodes for which we haven't received a heartbeat on time.
These nodes are subsequently terminated.
Used when node updaters are not available for recovery. Used when node updaters are not available for recovery.
""" """
@ -748,8 +753,7 @@ class StandardAutoscaler:
aggregate=operator.add) aggregate=operator.add)
nodes_to_terminate.append(node_id) nodes_to_terminate.append(node_id)
if nodes_to_terminate: return nodes_to_terminate
self._terminate_nodes_and_cleanup(nodes_to_terminate)
def recover_if_needed(self, node_id, now): def recover_if_needed(self, node_id, now):
if not self.can_update(node_id): if not self.can_update(node_id):

View file

@ -204,6 +204,10 @@ class MockProvider(NodeProvider):
return self.mock_nodes[node_id].state in ["stopped", "terminated"] return self.mock_nodes[node_id].state in ["stopped", "terminated"]
def node_tags(self, node_id): def node_tags(self, node_id):
# Don't assume that node providers can retrieve tags from
# terminated nodes.
if self.is_terminated(node_id):
raise Exception(f"The node with id {node_id} has been terminated!")
with self.lock: with self.lock:
return self.mock_nodes[node_id].tags return self.mock_nodes[node_id].tags