[autoscaler] Fix update/terminate race condition (#15019)

Co-authored-by: AmeerHajAli <ameerh@berkeley.edu>
Dmitri Gekhtman 2021-04-02 14:57:02 -04:00 committed by GitHub
parent 3578d4e9d8
commit 42565d5bbe
4 changed files with 247 additions and 78 deletions


@@ -250,16 +250,17 @@ class StandardAutoscaler:
for node_type, count in to_launch.items():
self.launch_new_node(count, node_type=node_type)
nodes = self.workers()
if to_launch:
nodes = self.workers()
# Process any completed updates
completed = []
completed_nodes = []
for node_id, updater in self.updaters.items():
if not updater.is_alive():
completed.append(node_id)
if completed:
nodes_to_terminate: List[NodeID] = []
for node_id in completed:
completed_nodes.append(node_id)
if completed_nodes:
failed_nodes = []
for node_id in completed_nodes:
if self.updaters[node_id].exitcode == 0:
self.num_successful_updates[node_id] += 1
# Mark the node as active to prevent the node recovery
@@ -267,20 +268,35 @@ class StandardAutoscaler:
self.load_metrics.mark_active(
self.provider.internal_ip(node_id))
else:
logger.error(f"StandardAutoscaler: {node_id}: Terminating "
"failed to setup/initialize node.")
self.event_summarizer.add(
"Removing {} nodes of type " +
self._get_node_type(node_id) + " (launch failed).",
quantity=1,
aggregate=operator.add)
nodes_to_terminate.append(node_id)
failed_nodes.append(node_id)
self.num_failed_updates[node_id] += 1
self.node_tracker.untrack(node_id)
del self.updaters[node_id]
if nodes_to_terminate:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
if failed_nodes:
# Some nodes in failed_nodes may have been terminated
# during an update (for being idle after missing a heartbeat).
# Only terminate currently non-terminated nodes.
non_terminated_nodes = self.workers()
nodes_to_terminate: List[NodeID] = []
for node_id in failed_nodes:
if node_id in non_terminated_nodes:
nodes_to_terminate.append(node_id)
logger.error(f"StandardAutoscaler: {node_id}:"
" Terminating. Failed to setup/initialize"
" node.")
self.event_summarizer.add(
"Removing {} nodes of type " +
self._get_node_type(node_id) + " (launch failed).",
quantity=1,
aggregate=operator.add)
else:
logger.warning(f"StandardAutoscaler: {node_id}:"
" Failed to update node."
" Node has already been terminated.")
if nodes_to_terminate:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
# Update nodes with out-of-date files.
# TODO(edoakes): Spawning these threads directly seems to cause
@@ -602,7 +618,7 @@ class StandardAutoscaler:
if TAG_RAY_USER_NODE_TYPE in node_tags:
return node_tags[TAG_RAY_USER_NODE_TYPE]
else:
return "unknown"
return "unknown_node_type"
def _get_node_type_specific_fields(self, node_id: str,
fields_key: str) -> Any:
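
The guarded termination above is the heart of the race-condition fix: a node whose updater failed is only passed to terminate_nodes if the provider still lists it as non-terminated. A minimal standalone sketch of that guard, with a hypothetical helper name and integer node ids (not part of the diff):

from typing import List, Set

def select_nodes_to_terminate(failed_nodes: List[int],
                              non_terminated_nodes: Set[int]) -> List[int]:
    # Keep only the failed nodes the provider still knows about, so a node
    # already terminated mid-update is not terminated a second time.
    return [node_id for node_id in failed_nodes
            if node_id in non_terminated_nodes]

# Node 0 was already terminated during its update; node 1 was not.
assert select_nodes_to_terminate([0, 1], {1, 2}) == [1]

Routing all terminations through this filtered list is what prevents the autoscaler from double-terminating a node that another code path removed first.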


@@ -152,7 +152,14 @@ class KubernetesNodeProvider(NodeProvider):
def terminate_node(self, node_id):
logger.info(log_prefix + "calling delete_namespaced_pod")
core_api().delete_namespaced_pod(node_id, self.namespace)
try:
core_api().delete_namespaced_pod(node_id, self.namespace)
except ApiException as e:
if e.status == 404:
logger.warning(log_prefix + f"Tried to delete pod {node_id},"
" but the pod was not found (404).")
else:
raise
try:
core_api().delete_namespaced_service(node_id, self.namespace)
except ApiException:
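
The pod-deletion path now treats a 404 from the Kubernetes API as "already deleted" rather than an error, mirroring the service-deletion path below it. A standalone sketch of the same idempotent-delete pattern with the official kubernetes Python client (the helper name and kubeconfig loading are assumptions for illustration):

from kubernetes import client, config
from kubernetes.client.rest import ApiException

def delete_pod_if_exists(name: str, namespace: str) -> None:
    config.load_kube_config()  # use load_incluster_config() inside a pod
    try:
        client.CoreV1Api().delete_namespaced_pod(name, namespace)
    except ApiException as e:
        if e.status == 404:
            # Pod is already gone (e.g. deleted by another code path);
            # treat the delete as a successful no-op.
            pass
        else:
            raise

Swallowing only the 404 keeps repeated terminate calls safe while still surfacing genuine API failures.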


@@ -11,6 +11,7 @@ from unittest.mock import Mock
import yaml
import copy
from jsonschema.exceptions import ValidationError
from typing import Dict, Callable
import ray
from ray.autoscaler._private.util import prepare_config, validate_config
@@ -51,85 +52,107 @@ class MockNode:
class MockProcessRunner:
def __init__(self, fail_cmds=None):
def __init__(self, fail_cmds=None, cmd_to_callback=None, print_out=False):
self.calls = []
self.cmd_to_callback = cmd_to_callback or {
} # type: Dict[str, Callable]
self.print_out = print_out
self.fail_cmds = fail_cmds or []
self.call_response = {}
self.ready_to_run = threading.Event()
self.ready_to_run.set()
self.lock = threading.RLock()
def check_call(self, cmd, *args, **kwargs):
self.ready_to_run.wait()
for token in self.fail_cmds:
if token in str(cmd):
raise CalledProcessError(1, token,
"Failing command on purpose")
self.calls.append(cmd)
with self.lock:
self.ready_to_run.wait()
self.calls.append(cmd)
if self.print_out:
print(f">>>Process runner: Executing \n {str(cmd)}")
for token in self.cmd_to_callback:
if token in str(cmd):
# Trigger a callback if token is in cmd.
# Can be used to simulate background events during a node
# update (e.g. node disconnected).
callback = self.cmd_to_callback[token]
callback()
for token in self.fail_cmds:
if token in str(cmd):
raise CalledProcessError(1, token,
"Failing command on purpose")
def check_output(self, cmd):
self.check_call(cmd)
return_string = "command-output"
key_to_shrink = None
for pattern, response_list in self.call_response.items():
if pattern in str(cmd):
return_string = response_list[0]
key_to_shrink = pattern
break
if key_to_shrink:
self.call_response[key_to_shrink] = self.call_response[
key_to_shrink][1:]
if len(self.call_response[key_to_shrink]) == 0:
del self.call_response[key_to_shrink]
with self.lock:
self.check_call(cmd)
return_string = "command-output"
key_to_shrink = None
for pattern, response_list in self.call_response.items():
if pattern in str(cmd):
return_string = response_list[0]
key_to_shrink = pattern
break
if key_to_shrink:
self.call_response[key_to_shrink] = self.call_response[
key_to_shrink][1:]
if len(self.call_response[key_to_shrink]) == 0:
del self.call_response[key_to_shrink]
return return_string.encode()
return return_string.encode()
def assert_has_call(self, ip, pattern=None, exact=None):
    assert pattern or exact, \
        "Must specify either a pattern or exact match."
    out = ""
    if pattern is not None:
        for cmd in self.command_history():
            if ip in cmd:
                out += cmd
                out += "\n"
        if pattern in out:
            return True
        else:
            raise Exception(
                f"Did not find [{pattern}] in [{out}] for ip={ip}."
                f"\n\nFull output: {self.command_history()}")
    elif exact is not None:
        exact_cmd = " ".join(exact)
        for cmd in self.command_history():
            if ip in cmd:
                out += cmd
                out += "\n"
            if cmd == exact_cmd:
                return True
        raise Exception(
            f"Did not find [{exact_cmd}] in [{out}] for ip={ip}."
            f"\n\nFull output: {self.command_history()}")
    with self.lock:
        assert pattern or exact, \
            "Must specify either a pattern or exact match."
        out = ""
        if pattern is not None:
            for cmd in self.command_history():
                if ip in cmd:
                    out += cmd
                    out += "\n"
            if pattern in out:
                return True
            else:
                raise Exception(
                    f"Did not find [{pattern}] in [{out}] for ip={ip}."
                    f"\n\nFull output: {self.command_history()}")
        elif exact is not None:
            exact_cmd = " ".join(exact)
            for cmd in self.command_history():
                if ip in cmd:
                    out += cmd
                    out += "\n"
                if cmd == exact_cmd:
                    return True
            raise Exception(
                f"Did not find [{exact_cmd}] in [{out}] for ip={ip}."
                f"\n\nFull output: {self.command_history()}")
def assert_not_has_call(self, ip, pattern):
    out = ""
    for cmd in self.command_history():
        if ip in cmd:
            out += cmd
            out += "\n"
    if pattern in out:
        raise Exception("Found [{}] in [{}] for {}".format(
            pattern, out, ip))
    else:
        return True
    with self.lock:
        out = ""
        for cmd in self.command_history():
            if ip in cmd:
                out += cmd
                out += "\n"
        if pattern in out:
            raise Exception("Found [{}] in [{}] for {}".format(
                pattern, out, ip))
        else:
            return True
def clear_history(self):
self.calls = []
with self.lock:
self.calls = []
def command_history(self):
return [" ".join(cmd) for cmd in self.calls]
with self.lock:
return [" ".join(cmd) for cmd in self.calls]
def respond_to_call(self, pattern, response_list):
self.call_response[pattern] = response_list
with self.lock:
self.call_response[pattern] = response_list
class MockProvider(NodeProvider):
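
The cmd_to_callback hook and the RLock added to MockProcessRunner above let a test fire a side effect the moment an updater runs a matching command, which is how the new test below simulates a node being terminated in the middle of its update. A self-contained sketch of the mechanism (TinyRunner is a hypothetical stand-in, not the real mock):

import threading
from typing import Callable, Dict, List

class TinyRunner:
    def __init__(self, cmd_to_callback: Dict[str, Callable[[], None]]):
        self.cmd_to_callback = cmd_to_callback
        self.calls: List[str] = []
        # Updater threads share one runner, so shared state is lock-guarded.
        self.lock = threading.RLock()

    def check_call(self, cmd: str) -> None:
        with self.lock:
            self.calls.append(cmd)
            for token, callback in self.cmd_to_callback.items():
                if token in cmd:
                    # Simulate a background event, e.g. the node disappearing.
                    callback()

terminated = []
runner = TinyRunner({"ray_start_cmd": lambda: terminated.append(0)})
runner.check_call("setup && ray_start_cmd")
assert terminated == [0]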
@@ -356,13 +379,13 @@ class AutoscalingTest(unittest.TestCase):
shutil.rmtree(self.tmpdir)
ray.shutdown()
def waitFor(self, condition, num_retries=50):
def waitFor(self, condition, num_retries=50, fail_msg=None):
for _ in range(num_retries):
if condition():
return
time.sleep(.1)
raise RayTestTimeoutException(
"Timed out waiting for {}".format(condition))
fail_msg = fail_msg or "Timed out waiting for {}".format(condition)
raise RayTestTimeoutException(fail_msg)
def waitForNodes(self, expected, comparison=None, tag_filters={}):
MAX_ITER = 50
@@ -371,7 +394,7 @@ class AutoscalingTest(unittest.TestCase):
if comparison is None:
comparison = self.assertEqual
try:
comparison(n, expected)
comparison(n, expected, msg="Unexpected node quantity.")
return
except Exception:
if i == MAX_ITER - 1:
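
waitFor now takes an optional fail_msg so a timeout reports a readable message instead of the repr of the condition callable. A standalone sketch of the same polling pattern, using the builtin TimeoutError in place of RayTestTimeoutException:

import time

def wait_for(condition, num_retries=50, fail_msg=None):
    # Poll the condition up to num_retries times, 0.1 s apart.
    for _ in range(num_retries):
        if condition():
            return
        time.sleep(0.1)
    raise TimeoutError(fail_msg or "Timed out waiting for {}".format(condition))

wait_for(lambda: True, fail_msg="example condition never became true")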
@@ -2218,6 +2241,118 @@ MemAvailable: 33000000 kB
assert node == 1
def testNodeTerminatedDuringUpdate(self):
"""
Tests the autoscaler's handling of a node that gets terminated during an
update triggered by the node missing a heartbeat.
Extension of testRecoverUnhealthyWorkers.
In this test, two nodes miss a heartbeat.
One of them (node 0) is terminated during its recovery update.
The other (node 1) just fails its update.
When processing completed updates, the autoscaler terminates node 1
but does not try to terminate node 0 again.
"""
cluster_config = copy.deepcopy(MOCK_DEFAULT_CONFIG)
cluster_config["available_node_types"]["ray.worker.default"][
"min_workers"] = 2
cluster_config["worker_start_ray_commands"] = ["ray_start_cmd"]
# Don't need the extra node type or a docker config.
cluster_config["head_node_type"] = ["ray.worker.default"]
del cluster_config["available_node_types"]["ray.head.default"]
del cluster_config["docker"]
config_path = self.write_config(cluster_config)
self.provider = MockProvider()
runner = MockProcessRunner()
lm = LoadMetrics()
autoscaler = StandardAutoscaler(
config_path,
lm,
max_failures=0,
process_runner=runner,
update_interval_s=0)
# Scale up to two up-to-date workers
autoscaler.update()
self.waitForNodes(2)
self.provider.finish_starting_nodes()
autoscaler.update()
self.waitForNodes(
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
# Mark both nodes as unhealthy
for _ in range(5):
if autoscaler.updaters:
time.sleep(0.05)
autoscaler.update()
assert not autoscaler.updaters
lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
lm.last_heartbeat_time_by_ip["172.0.0.1"] = 0
# Set up process runner to terminate worker 0 during missed heartbeat
# recovery and also cause the updater to fail.
def terminate_worker_zero():
self.provider.terminate_node(0)
autoscaler.process_runner = MockProcessRunner(
fail_cmds=["ray_start_cmd"],
cmd_to_callback={"ray_start_cmd": terminate_worker_zero})
num_calls = len(autoscaler.process_runner.calls)
autoscaler.update()
# Wait for updaters spawned by last autoscaler update to finish.
self.waitFor(
lambda: all(not updater.is_alive()
for updater in autoscaler.updaters.values()),
num_retries=500,
fail_msg="Last round of updaters didn't complete on time."
)
# Check that updaters processed some commands in the last autoscaler
# update.
assert len(autoscaler.process_runner.calls) > num_calls,\
"Did not get additional process runner calls on last autoscaler"\
" update."
# Missed heartbeat triggered recovery for both nodes.
events = autoscaler.event_summarizer.summary()
assert (
"Restarting 2 nodes of type "
"ray.worker.default (lost contact with raylet)." in events), events
# Node 0 was terminated during the last update.
# Node 1's updater failed, but node 1 won't be terminated until the
# next autoscaler update.
assert 0 not in autoscaler.workers(), "Node zero still non-terminated."
assert not self.provider.is_terminated(1),\
"Node one terminated prematurely."
autoscaler.update()
# Failed updates are now processed.
assert autoscaler.num_failed_updates[0] == 1,\
"Node zero update failure not registered"
assert autoscaler.num_failed_updates[1] == 1,\
"Node one update failure not registered"
# Completed-update-processing logic should have terminated node 1.
assert self.provider.is_terminated(1), "Node 1 not terminated on time."
events = autoscaler.event_summarizer.summary()
# Just one node (node_id 1) terminated in the last update.
# Validates that we didn't try to double-terminate node 0.
assert ("Removing 1 nodes of type "
"ray.worker.default (launch failed)." in events), events
# To be more explicit,
assert ("Removing 2 nodes of type "
"ray.worker.default (launch failed)." not in events), events
# Should get two new nodes after the next update.
autoscaler.update()
self.waitForNodes(2)
assert set(autoscaler.workers()) == {2, 3},\
"Unexpected node_ids"
if __name__ == "__main__":
import sys


@@ -12,6 +12,9 @@ import kubernetes
import pytest
import yaml
from ray.autoscaler._private.kubernetes.node_provider import\
KubernetesNodeProvider
IMAGE_ENV = "KUBERNETES_OPERATOR_TEST_IMAGE"
IMAGE = os.getenv(IMAGE_ENV, "rayproject/ray:nightly")
@@ -97,6 +100,14 @@ def get_operator_config_path(file_name):
class KubernetesOperatorTest(unittest.TestCase):
def test_examples(self):
# Validate terminate_node error handling
provider = KubernetesNodeProvider({
"namespace": NAMESPACE
}, "default_cluster_name")
# 404 caught, no error
provider.terminate_node("no-such-node")
with tempfile.NamedTemporaryFile("w+") as example_cluster_file, \
tempfile.NamedTemporaryFile("w+") as example_cluster2_file,\
tempfile.NamedTemporaryFile("w+") as operator_file,\