[autoscaler] Fix update/terminate race condition (#15019)

Co-authored-by: AmeerHajAli <ameerh@berkeley.edu>
Authored by Dmitri Gekhtman on 2021-04-02 14:57:02 -04:00, committed by GitHub
parent 3578d4e9d8
commit 42565d5bbe
4 changed files with 247 additions and 78 deletions
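
In brief, the race fixed here: a worker can be terminated (for example, for idling after a missed heartbeat) while its updater is still running, and the autoscaler would then try to terminate that node a second time once the failed update was processed. The snippet below is a minimal, self-contained sketch of the pattern the fix uses — filter failed nodes against the provider's current non-terminated set before scheduling termination. It is an illustration only, not the StandardAutoscaler code; the helper name and arguments are made up.

    from typing import List

    def nodes_safe_to_terminate(failed_nodes: List[str],
                                non_terminated_nodes: List[str]) -> List[str]:
        """Return only the failed nodes that still exist in the provider.

        Nodes already terminated (e.g. for idling after a missed heartbeat
        while their update was still running) are skipped instead of being
        terminated a second time.
        """
        nodes_to_terminate = []
        for node_id in failed_nodes:
            if node_id in non_terminated_nodes:
                # Node still exists: terminate it for the failed update.
                nodes_to_terminate.append(node_id)
            # Otherwise the node was already terminated during its update:
            # nothing left to do for it.
        return nodes_to_terminate

    # Example: node "0" was terminated mid-update, node "1" just failed its
    # update, so only node "1" should be terminated now.
    assert nodes_safe_to_terminate(["0", "1"], ["1", "2"]) == ["1"]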


@@ -250,16 +250,17 @@ class StandardAutoscaler:
         for node_type, count in to_launch.items():
             self.launch_new_node(count, node_type=node_type)

+        if to_launch:
             nodes = self.workers()

         # Process any completed updates
-        completed = []
+        completed_nodes = []
         for node_id, updater in self.updaters.items():
             if not updater.is_alive():
-                completed.append(node_id)
-        if completed:
-            nodes_to_terminate: List[NodeID] = []
-            for node_id in completed:
+                completed_nodes.append(node_id)
+        if completed_nodes:
+            failed_nodes = []
+            for node_id in completed_nodes:
                 if self.updaters[node_id].exitcode == 0:
                     self.num_successful_updates[node_id] += 1
                     # Mark the node as active to prevent the node recovery
@@ -267,19 +268,34 @@ class StandardAutoscaler:
                     self.load_metrics.mark_active(
                         self.provider.internal_ip(node_id))
                 else:
-                    logger.error(f"StandardAutoscaler: {node_id}: Terminating "
-                                 "failed to setup/initialize node.")
+                    failed_nodes.append(node_id)
+                    self.num_failed_updates[node_id] += 1
+                    self.node_tracker.untrack(node_id)
+                del self.updaters[node_id]
+
+            if failed_nodes:
+                # Some nodes in failed_nodes may have been terminated
+                # during an update (for being idle after missing a heartbeat).
+                # Only terminate currently non terminated nodes.
+                non_terminated_nodes = self.workers()
+                nodes_to_terminate: List[NodeID] = []
+                for node_id in failed_nodes:
+                    if node_id in non_terminated_nodes:
+                        nodes_to_terminate.append(node_id)
+                        logger.error(f"StandardAutoscaler: {node_id}:"
+                                     " Terminating. Failed to setup/initialize"
+                                     " node.")
                         self.event_summarizer.add(
                             "Removing {} nodes of type " +
                             self._get_node_type(node_id) + " (launch failed).",
                             quantity=1,
                             aggregate=operator.add)
-                    nodes_to_terminate.append(node_id)
-                    self.num_failed_updates[node_id] += 1
-                del self.updaters[node_id]
+                    else:
+                        logger.warning(f"StandardAutoscaler: {node_id}:"
+                                       " Failed to update node."
+                                       " Node has already been terminated.")
                 if nodes_to_terminate:
                     self.provider.terminate_nodes(nodes_to_terminate)
                     nodes = self.workers()

         # Update nodes with out-of-date files.
@@ -602,7 +618,7 @@ class StandardAutoscaler:
         if TAG_RAY_USER_NODE_TYPE in node_tags:
             return node_tags[TAG_RAY_USER_NODE_TYPE]
         else:
-            return "unknown"
+            return "unknown_node_type"

     def _get_node_type_specific_fields(self, node_id: str,
                                        fields_key: str) -> Any:


@@ -152,7 +152,14 @@ class KubernetesNodeProvider(NodeProvider):
     def terminate_node(self, node_id):
         logger.info(log_prefix + "calling delete_namespaced_pod")
+        try:
             core_api().delete_namespaced_pod(node_id, self.namespace)
+        except ApiException as e:
+            if e.status == 404:
+                logger.warning(log_prefix + f"Tried to delete pod {node_id},"
+                               " but the pod was not found (404).")
+            else:
+                raise
         try:
             core_api().delete_namespaced_service(node_id, self.namespace)
         except ApiException:


@@ -11,6 +11,7 @@ from unittest.mock import Mock
 import yaml
 import copy
 from jsonschema.exceptions import ValidationError
+from typing import Dict, Callable

 import ray
 from ray.autoscaler._private.util import prepare_config, validate_config
@@ -51,22 +52,39 @@ class MockNode:

 class MockProcessRunner:
-    def __init__(self, fail_cmds=None):
+    def __init__(self, fail_cmds=None, cmd_to_callback=None, print_out=False):
         self.calls = []
+        self.cmd_to_callback = cmd_to_callback or {
+        }  # type: Dict[str, Callable]
+        self.print_out = print_out
         self.fail_cmds = fail_cmds or []
         self.call_response = {}
         self.ready_to_run = threading.Event()
         self.ready_to_run.set()

+        self.lock = threading.RLock()
+
     def check_call(self, cmd, *args, **kwargs):
+        with self.lock:
             self.ready_to_run.wait()
+            self.calls.append(cmd)
+            if self.print_out:
+                print(f">>>Process runner: Executing \n {str(cmd)}")
+            for token in self.cmd_to_callback:
+                if token in str(cmd):
+                    # Trigger a callback if token is in cmd.
+                    # Can be used to simulate background events during a node
+                    # update (e.g. node disconnected).
+                    callback = self.cmd_to_callback[token]
+                    callback()
             for token in self.fail_cmds:
                 if token in str(cmd):
                     raise CalledProcessError(1, token,
                                              "Failing command on purpose")
-        self.calls.append(cmd)

     def check_output(self, cmd):
+        with self.lock:
             self.check_call(cmd)
             return_string = "command-output"
             key_to_shrink = None
@@ -84,6 +102,7 @@ class MockProcessRunner:
         return return_string.encode()

     def assert_has_call(self, ip, pattern=None, exact=None):
+        with self.lock:
             assert pattern or exact, \
                 "Must specify either a pattern or exact match."
             out = ""
@@ -111,6 +130,7 @@
                     f"\n\nFull output: {self.command_history()}")

     def assert_not_has_call(self, ip, pattern):
+        with self.lock:
             out = ""
             for cmd in self.command_history():
                 if ip in cmd:
@@ -123,12 +143,15 @@
                 return True

     def clear_history(self):
+        with self.lock:
             self.calls = []

     def command_history(self):
+        with self.lock:
             return [" ".join(cmd) for cmd in self.calls]

     def respond_to_call(self, pattern, response_list):
+        with self.lock:
             self.call_response[pattern] = response_list
@@ -356,13 +379,13 @@ class AutoscalingTest(unittest.TestCase):
         shutil.rmtree(self.tmpdir)
         ray.shutdown()

-    def waitFor(self, condition, num_retries=50):
+    def waitFor(self, condition, num_retries=50, fail_msg=None):
         for _ in range(num_retries):
             if condition():
                 return
             time.sleep(.1)
-        raise RayTestTimeoutException(
-            "Timed out waiting for {}".format(condition))
+        fail_msg = fail_msg or "Timed out waiting for {}".format(condition)
+        raise RayTestTimeoutException(fail_msg)

     def waitForNodes(self, expected, comparison=None, tag_filters={}):
         MAX_ITER = 50
@@ -371,7 +394,7 @@ class AutoscalingTest(unittest.TestCase):
             if comparison is None:
                 comparison = self.assertEqual
             try:
-                comparison(n, expected)
+                comparison(n, expected, msg="Unexpected node quantity.")
                 return
             except Exception:
                 if i == MAX_ITER - 1:
@@ -2218,6 +2241,118 @@ MemAvailable: 33000000 kB
         assert node == 1

+    def testNodeTerminatedDuringUpdate(self):
+        """
+        Tests autoscaler handling a node getting terminated during an update
+        triggered by the node missing a heartbeat.
+
+        Extension of testRecoverUnhealthyWorkers.
+
+        In this test, two nodes miss a heartbeat.
+        One of them (node 0) is terminated during its recovery update.
+        The other (node 1) just fails its update.
+
+        When processing completed updates, the autoscaler terminates node 1
+        but does not try to terminate node 0 again.
+        """
+        cluster_config = copy.deepcopy(MOCK_DEFAULT_CONFIG)
+        cluster_config["available_node_types"]["ray.worker.default"][
+            "min_workers"] = 2
+        cluster_config["worker_start_ray_commands"] = ["ray_start_cmd"]
+
+        # Don't need the extra node type or a docker config.
+        cluster_config["head_node_type"] = ["ray.worker.default"]
+        del cluster_config["available_node_types"]["ray.head.default"]
+        del cluster_config["docker"]
+
+        config_path = self.write_config(cluster_config)
+
+        self.provider = MockProvider()
+        runner = MockProcessRunner()
+        lm = LoadMetrics()
+        autoscaler = StandardAutoscaler(
+            config_path,
+            lm,
+            max_failures=0,
+            process_runner=runner,
+            update_interval_s=0)
+
+        # Scale up to two up-to-date workers.
+        autoscaler.update()
+        self.waitForNodes(2)
+        self.provider.finish_starting_nodes()
+        autoscaler.update()
+        self.waitForNodes(
+            2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
+
+        # Mark both nodes as unhealthy.
+        for _ in range(5):
+            if autoscaler.updaters:
+                time.sleep(0.05)
+                autoscaler.update()
+        assert not autoscaler.updaters
+        lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
+        lm.last_heartbeat_time_by_ip["172.0.0.1"] = 0
+
+        # Set up process runner to terminate worker 0 during missed heartbeat
+        # recovery and also cause the updater to fail.
+        def terminate_worker_zero():
+            self.provider.terminate_node(0)
+
+        autoscaler.process_runner = MockProcessRunner(
+            fail_cmds=["ray_start_cmd"],
+            cmd_to_callback={"ray_start_cmd": terminate_worker_zero})
+
+        num_calls = len(autoscaler.process_runner.calls)
+        autoscaler.update()
+        # Wait for updaters spawned by last autoscaler update to finish.
+        self.waitFor(
+            lambda: all(not updater.is_alive()
+                        for updater in autoscaler.updaters.values()),
+            num_retries=500,
+            fail_msg="Last round of updaters didn't complete on time."
+        )
+        # Check that updaters processed some commands in the last autoscaler
+        # update.
+        assert len(autoscaler.process_runner.calls) > num_calls,\
+            "Did not get additional process runner calls on last autoscaler"\
+            " update."
+
+        # Missed heartbeat triggered recovery for both nodes.
+        events = autoscaler.event_summarizer.summary()
+        assert (
+            "Restarting 2 nodes of type "
+            "ray.worker.default (lost contact with raylet)." in events), events
+
+        # Node 0 was terminated during the last update.
+        # Node 1's updater failed, but node 1 won't be terminated until the
+        # next autoscaler update.
+        assert 0 not in autoscaler.workers(), "Node zero still non-terminated."
+        assert not self.provider.is_terminated(1),\
+            "Node one terminated prematurely."
+
+        autoscaler.update()
+        # Failed updates are now processed.
+        assert autoscaler.num_failed_updates[0] == 1,\
+            "Node zero update failure not registered"
+        assert autoscaler.num_failed_updates[1] == 1,\
+            "Node one update failure not registered"
+        # Completed-update-processing logic should have terminated node 1.
+        assert self.provider.is_terminated(1), "Node 1 not terminated on time."
+
+        events = autoscaler.event_summarizer.summary()
+        # Just one node (node_id 1) was terminated in the last update.
+        # Validates that we didn't try to double-terminate node 0.
+        assert ("Removing 1 nodes of type "
+                "ray.worker.default (launch failed)." in events), events
+        # To be more explicit:
+        assert ("Removing 2 nodes of type "
+                "ray.worker.default (launch failed)." not in events), events
+
+        # Should get two new nodes after the next update.
+        autoscaler.update()
+        self.waitForNodes(2)
+        assert set(autoscaler.workers()) == {2, 3},\
+            "Unexpected node_ids"
+

 if __name__ == "__main__":
     import sys


@@ -12,6 +12,9 @@ import kubernetes
 import pytest
 import yaml

+from ray.autoscaler._private.kubernetes.node_provider import\
+    KubernetesNodeProvider
+
 IMAGE_ENV = "KUBERNETES_OPERATOR_TEST_IMAGE"
 IMAGE = os.getenv(IMAGE_ENV, "rayproject/ray:nightly")
@@ -97,6 +100,14 @@ def get_operator_config_path(file_name):
 class KubernetesOperatorTest(unittest.TestCase):
     def test_examples(self):
+
+        # Validate terminate_node error handling
+        provider = KubernetesNodeProvider({
+            "namespace": NAMESPACE
+        }, "default_cluster_name")
+        # 404 caught, no error
+        provider.terminate_node("no-such-node")
+
         with tempfile.NamedTemporaryFile("w+") as example_cluster_file, \
                 tempfile.NamedTemporaryFile("w+") as example_cluster2_file,\
                 tempfile.NamedTemporaryFile("w+") as operator_file,\