Mirror of https://github.com/vale981/ray, synced 2025-03-07 02:51:39 -05:00
[autoscaler] Fix update/terminate race condition (#15019)

Co-authored-by: AmeerHajAli <ameerh@berkeley.edu>

parent 3578d4e9d8
commit 42565d5bbe

4 changed files with 247 additions and 78 deletions
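The race being fixed: when a node is terminated while its updater thread is still running (for example, terminated for idleness after missing a heartbeat), the updater exits non-zero and the autoscaler's cleanup pass would previously issue a second terminate for the same node id. The second hunk below closes the race by re-reading the provider's live node list and terminating only failed nodes that still exist. A minimal sketch of that guard, with illustrative names that are not from the Ray codebase:

# Minimal sketch of the guard added below (illustrative names only).
failed_nodes = ["node-0", "node-1"]  # updaters that exited non-zero
live_nodes = {"node-1"}              # node-0 was already terminated mid-update
nodes_to_terminate = [n for n in failed_nodes if n in live_nodes]
assert nodes_to_terminate == ["node-1"]  # node-0 is not double-terminated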
@@ -250,16 +250,17 @@ class StandardAutoscaler:
         for node_type, count in to_launch.items():
             self.launch_new_node(count, node_type=node_type)
 
-        nodes = self.workers()
+        if to_launch:
+            nodes = self.workers()
 
         # Process any completed updates
-        completed = []
+        completed_nodes = []
         for node_id, updater in self.updaters.items():
             if not updater.is_alive():
-                completed.append(node_id)
-        if completed:
-            nodes_to_terminate: List[NodeID] = []
-            for node_id in completed:
+                completed_nodes.append(node_id)
+        if completed_nodes:
+            failed_nodes = []
+            for node_id in completed_nodes:
                 if self.updaters[node_id].exitcode == 0:
                     self.num_successful_updates[node_id] += 1
                     # Mark the node as active to prevent the node recovery
@@ -267,20 +268,35 @@ class StandardAutoscaler:
                     self.load_metrics.mark_active(
                         self.provider.internal_ip(node_id))
                 else:
-                    logger.error(f"StandardAutoscaler: {node_id}: Terminating "
-                                 "failed to setup/initialize node.")
-                    self.event_summarizer.add(
-                        "Removing {} nodes of type " +
-                        self._get_node_type(node_id) + " (launch failed).",
-                        quantity=1,
-                        aggregate=operator.add)
-                    nodes_to_terminate.append(node_id)
+                    failed_nodes.append(node_id)
                     self.num_failed_updates[node_id] += 1
                 self.node_tracker.untrack(node_id)
                 del self.updaters[node_id]
-            if nodes_to_terminate:
-                self.provider.terminate_nodes(nodes_to_terminate)
-                nodes = self.workers()
+
+            if failed_nodes:
+                # Some nodes in failed_nodes may have been terminated
+                # during an update (for being idle after missing a heartbeat).
+                # Only terminate currently non-terminated nodes.
+                non_terminated_nodes = self.workers()
+                nodes_to_terminate: List[NodeID] = []
+                for node_id in failed_nodes:
+                    if node_id in non_terminated_nodes:
+                        nodes_to_terminate.append(node_id)
+                        logger.error(f"StandardAutoscaler: {node_id}:"
+                                     " Terminating. Failed to setup/initialize"
+                                     " node.")
+                        self.event_summarizer.add(
+                            "Removing {} nodes of type " +
+                            self._get_node_type(node_id) + " (launch failed).",
+                            quantity=1,
+                            aggregate=operator.add)
+                    else:
+                        logger.warning(f"StandardAutoscaler: {node_id}:"
+                                       " Failed to update node."
+                                       " Node has already been terminated.")
+                if nodes_to_terminate:
+                    self.provider.terminate_nodes(nodes_to_terminate)
+                nodes = self.workers()
 
         # Update nodes with out-of-date files.
         # TODO(edoakes): Spawning these threads directly seems to cause
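A note on the event_summarizer.add call above: quantity=1 with aggregate=operator.add means repeated events with the same template string are summed, so terminating N nodes of one type yields a single "Removing N nodes of type ..." summary line. A rough sketch of that aggregation, assuming a dict-backed summarizer (the real EventSummarizer is not part of this diff):

# Rough sketch of the aggregation semantics (not the real EventSummarizer).
import operator

events = {}

def add(template, quantity, aggregate):
    # Combine quantities of events that share a template string.
    events[template] = aggregate(events.get(template, 0), quantity)

add("Removing {} nodes of type ray.worker.default (launch failed).", 1,
    operator.add)
print([template.format(count) for template, count in events.items()])
# ['Removing 1 nodes of type ray.worker.default (launch failed).']

The new test at the bottom of this commit asserts on exactly such a summary string to prove that only one node was terminated.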
@@ -602,7 +618,7 @@ class StandardAutoscaler:
         if TAG_RAY_USER_NODE_TYPE in node_tags:
             return node_tags[TAG_RAY_USER_NODE_TYPE]
         else:
-            return "unknown"
+            return "unknown_node_type"
 
     def _get_node_type_specific_fields(self, node_id: str,
                                        fields_key: str) -> Any:
@@ -152,7 +152,14 @@ class KubernetesNodeProvider(NodeProvider):
 
     def terminate_node(self, node_id):
         logger.info(log_prefix + "calling delete_namespaced_pod")
-        core_api().delete_namespaced_pod(node_id, self.namespace)
+        try:
+            core_api().delete_namespaced_pod(node_id, self.namespace)
+        except ApiException as e:
+            if e.status == 404:
+                logger.warning(log_prefix + f"Tried to delete pod {node_id},"
+                               " but the pod was not found (404).")
+            else:
+                raise
         try:
             core_api().delete_namespaced_service(node_id, self.namespace)
         except ApiException:
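The same idempotent-delete pattern as a standalone snippet against the official kubernetes Python client; the function name and kubeconfig loading here are placeholders, not part of this commit:

# Sketch: pod deletion treated as idempotent; 404 means already gone.
from kubernetes import client, config
from kubernetes.client.rest import ApiException

def delete_pod_if_present(name: str, namespace: str = "default") -> None:
    config.load_kube_config()  # assumes a local kubeconfig
    try:
        client.CoreV1Api().delete_namespaced_pod(name, namespace)
    except ApiException as e:
        if e.status != 404:
            raise  # only "not found" is safe to swallow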
@@ -11,6 +11,7 @@ from unittest.mock import Mock
 import yaml
+import copy
 from jsonschema.exceptions import ValidationError
 from typing import Dict, Callable
 
 import ray
 from ray.autoscaler._private.util import prepare_config, validate_config
@@ -51,85 +52,107 @@ class MockNode:
 
 
 class MockProcessRunner:
-    def __init__(self, fail_cmds=None):
+    def __init__(self, fail_cmds=None, cmd_to_callback=None, print_out=False):
         self.calls = []
+        self.cmd_to_callback = cmd_to_callback or {
+        }  # type: Dict[str, Callable]
+        self.print_out = print_out
         self.fail_cmds = fail_cmds or []
         self.call_response = {}
         self.ready_to_run = threading.Event()
         self.ready_to_run.set()
 
+        self.lock = threading.RLock()
+
     def check_call(self, cmd, *args, **kwargs):
-        self.ready_to_run.wait()
-        for token in self.fail_cmds:
-            if token in str(cmd):
-                raise CalledProcessError(1, token,
-                                         "Failing command on purpose")
-        self.calls.append(cmd)
+        with self.lock:
+            self.ready_to_run.wait()
+            self.calls.append(cmd)
+            if self.print_out:
+                print(f">>>Process runner: Executing \n {str(cmd)}")
+            for token in self.cmd_to_callback:
+                if token in str(cmd):
+                    # Trigger a callback if token is in cmd.
+                    # Can be used to simulate background events during a node
+                    # update (e.g. node disconnected).
+                    callback = self.cmd_to_callback[token]
+                    callback()
+
+            for token in self.fail_cmds:
+                if token in str(cmd):
+                    raise CalledProcessError(1, token,
+                                             "Failing command on purpose")
 
     def check_output(self, cmd):
-        self.check_call(cmd)
-        return_string = "command-output"
-        key_to_shrink = None
-        for pattern, response_list in self.call_response.items():
-            if pattern in str(cmd):
-                return_string = response_list[0]
-                key_to_shrink = pattern
-                break
-        if key_to_shrink:
-            self.call_response[key_to_shrink] = self.call_response[
-                key_to_shrink][1:]
-            if len(self.call_response[key_to_shrink]) == 0:
-                del self.call_response[key_to_shrink]
-
-        return return_string.encode()
+        with self.lock:
+            self.check_call(cmd)
+            return_string = "command-output"
+            key_to_shrink = None
+            for pattern, response_list in self.call_response.items():
+                if pattern in str(cmd):
+                    return_string = response_list[0]
+                    key_to_shrink = pattern
+                    break
+            if key_to_shrink:
+                self.call_response[key_to_shrink] = self.call_response[
+                    key_to_shrink][1:]
+                if len(self.call_response[key_to_shrink]) == 0:
+                    del self.call_response[key_to_shrink]
+
+            return return_string.encode()
 
     def assert_has_call(self, ip, pattern=None, exact=None):
-        assert pattern or exact, \
-            "Must specify either a pattern or exact match."
-        out = ""
-        if pattern is not None:
-            for cmd in self.command_history():
-                if ip in cmd:
-                    out += cmd
-                    out += "\n"
-            if pattern in out:
-                return True
-            else:
-                raise Exception(
-                    f"Did not find [{pattern}] in [{out}] for ip={ip}."
-                    f"\n\nFull output: {self.command_history()}")
-        elif exact is not None:
-            exact_cmd = " ".join(exact)
-            for cmd in self.command_history():
-                if ip in cmd:
-                    out += cmd
-                    out += "\n"
-                if cmd == exact_cmd:
-                    return True
-            raise Exception(
-                f"Did not find [{exact_cmd}] in [{out}] for ip={ip}."
-                f"\n\nFull output: {self.command_history()}")
+        with self.lock:
+            assert pattern or exact, \
+                "Must specify either a pattern or exact match."
+            out = ""
+            if pattern is not None:
+                for cmd in self.command_history():
+                    if ip in cmd:
+                        out += cmd
+                        out += "\n"
+                if pattern in out:
+                    return True
+                else:
+                    raise Exception(
+                        f"Did not find [{pattern}] in [{out}] for ip={ip}."
+                        f"\n\nFull output: {self.command_history()}")
+            elif exact is not None:
+                exact_cmd = " ".join(exact)
+                for cmd in self.command_history():
+                    if ip in cmd:
+                        out += cmd
+                        out += "\n"
+                    if cmd == exact_cmd:
+                        return True
+                raise Exception(
+                    f"Did not find [{exact_cmd}] in [{out}] for ip={ip}."
+                    f"\n\nFull output: {self.command_history()}")
 
     def assert_not_has_call(self, ip, pattern):
-        out = ""
-        for cmd in self.command_history():
-            if ip in cmd:
-                out += cmd
-                out += "\n"
-        if pattern in out:
-            raise Exception("Found [{}] in [{}] for {}".format(
-                pattern, out, ip))
-        else:
-            return True
+        with self.lock:
+            out = ""
+            for cmd in self.command_history():
+                if ip in cmd:
+                    out += cmd
+                    out += "\n"
+            if pattern in out:
+                raise Exception("Found [{}] in [{}] for {}".format(
+                    pattern, out, ip))
+            else:
+                return True
 
     def clear_history(self):
-        self.calls = []
+        with self.lock:
+            self.calls = []
 
     def command_history(self):
-        return [" ".join(cmd) for cmd in self.calls]
+        with self.lock:
+            return [" ".join(cmd) for cmd in self.calls]
 
     def respond_to_call(self, pattern, response_list):
-        self.call_response[pattern] = response_list
+        with self.lock:
+            self.call_response[pattern] = response_list
 
 
 class MockProvider(NodeProvider):
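Two details in the reworked mock are worth calling out. First, cmd_to_callback lets a test fire a side effect while a command is "running", which is how the new test below simulates a node being terminated mid-update. Second, the lock is an RLock rather than a plain Lock because check_output calls check_call while already holding the lock, so acquisition must be reentrant. A hypothetical usage, mirroring the new test (assumes provider is a MockProvider instance):

# Hypothetical usage: node 0 is terminated while its start command runs,
# and the command itself is made to fail.
runner = MockProcessRunner(
    fail_cmds=["ray_start_cmd"],
    cmd_to_callback={"ray_start_cmd": lambda: provider.terminate_node(0)})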
@@ -356,13 +379,13 @@ class AutoscalingTest(unittest.TestCase):
         shutil.rmtree(self.tmpdir)
         ray.shutdown()
 
-    def waitFor(self, condition, num_retries=50):
+    def waitFor(self, condition, num_retries=50, fail_msg=None):
         for _ in range(num_retries):
             if condition():
                 return
             time.sleep(.1)
-        raise RayTestTimeoutException(
-            "Timed out waiting for {}".format(condition))
+        fail_msg = fail_msg or "Timed out waiting for {}".format(condition)
+        raise RayTestTimeoutException(fail_msg)
 
     def waitForNodes(self, expected, comparison=None, tag_filters={}):
         MAX_ITER = 50
@@ -371,7 +394,7 @@ class AutoscalingTest(unittest.TestCase):
             if comparison is None:
                 comparison = self.assertEqual
             try:
-                comparison(n, expected)
+                comparison(n, expected, msg="Unexpected node quantity.")
                 return
             except Exception:
                 if i == MAX_ITER - 1:
@@ -2218,6 +2241,118 @@ MemAvailable: 33000000 kB
 
         assert node == 1
 
+    def testNodeTerminatedDuringUpdate(self):
+        """
+        Tests autoscaler handling a node getting terminated during an update
+        triggered by the node missing a heartbeat.
+
+        Extension of testRecoverUnhealthyWorkers.
+
+        In this test, two nodes miss a heartbeat.
+        One of them (node 0) is terminated during its recovery update.
+        The other (node 1) just fails its update.
+
+        When processing completed updates, the autoscaler terminates node 1
+        but does not try to terminate node 0 again.
+        """
+        cluster_config = copy.deepcopy(MOCK_DEFAULT_CONFIG)
+        cluster_config["available_node_types"]["ray.worker.default"][
+            "min_workers"] = 2
+        cluster_config["worker_start_ray_commands"] = ["ray_start_cmd"]
+
+        # Don't need the extra node type or a docker config.
+        cluster_config["head_node_type"] = ["ray.worker.default"]
+        del cluster_config["available_node_types"]["ray.head.default"]
+        del cluster_config["docker"]
+
+        config_path = self.write_config(cluster_config)
+
+        self.provider = MockProvider()
+        runner = MockProcessRunner()
+        lm = LoadMetrics()
+        autoscaler = StandardAutoscaler(
+            config_path,
+            lm,
+            max_failures=0,
+            process_runner=runner,
+            update_interval_s=0)
+
+        # Scale up to two up-to-date workers
+        autoscaler.update()
+        self.waitForNodes(2)
+        self.provider.finish_starting_nodes()
+        autoscaler.update()
+        self.waitForNodes(
+            2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
+
+        # Mark both nodes as unhealthy
+        for _ in range(5):
+            if autoscaler.updaters:
+                time.sleep(0.05)
+                autoscaler.update()
+        assert not autoscaler.updaters
+        lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
+        lm.last_heartbeat_time_by_ip["172.0.0.1"] = 0
+
+        # Set up process runner to terminate worker 0 during missed heartbeat
+        # recovery and also cause the updater to fail.
+        def terminate_worker_zero():
+            self.provider.terminate_node(0)
+
+        autoscaler.process_runner = MockProcessRunner(
+            fail_cmds=["ray_start_cmd"],
+            cmd_to_callback={"ray_start_cmd": terminate_worker_zero})
+
+        num_calls = len(autoscaler.process_runner.calls)
+        autoscaler.update()
+        # Wait for updaters spawned by last autoscaler update to finish.
+        self.waitFor(
+            lambda: all(not updater.is_alive()
+                        for updater in autoscaler.updaters.values()),
+            num_retries=500,
+            fail_msg="Last round of updaters didn't complete on time."
+        )
+        # Check that updaters processed some commands in the last autoscaler
+        # update.
+        assert len(autoscaler.process_runner.calls) > num_calls,\
+            "Did not get additional process runner calls on last autoscaler"\
+            " update."
+        # Missed heartbeat triggered recovery for both nodes.
+        events = autoscaler.event_summarizer.summary()
+        assert (
+            "Restarting 2 nodes of type "
+            "ray.worker.default (lost contact with raylet)." in events), events
+        # Node 0 was terminated during the last update.
+        # Node 1's updater failed, but node 1 won't be terminated until the
+        # next autoscaler update.
+        assert 0 not in autoscaler.workers(), "Node zero still non-terminated."
+        assert not self.provider.is_terminated(1),\
+            "Node one terminated prematurely."
+
+        autoscaler.update()
+        # Failed updates are now processed.
+        assert autoscaler.num_failed_updates[0] == 1,\
+            "Node zero update failure not registered"
+        assert autoscaler.num_failed_updates[1] == 1,\
+            "Node one update failure not registered"
+        # Completed-update-processing logic should have terminated node 1.
+        assert self.provider.is_terminated(1), "Node 1 not terminated on time."
+
+        events = autoscaler.event_summarizer.summary()
+        # Just one node (node_id 1) terminated in the last update.
+        # Validates that we didn't try to double-terminate node 0.
+        assert ("Removing 1 nodes of type "
+                "ray.worker.default (launch failed)." in events), events
+        # To be more explicit,
+        assert ("Removing 2 nodes of type "
+                "ray.worker.default (launch failed)." not in events), events
+
+        # Should get two new nodes after the next update.
+        autoscaler.update()
+        self.waitForNodes(2)
+        assert set(autoscaler.workers()) == {2, 3},\
+            "Unexpected node_ids"
+
 
 if __name__ == "__main__":
     import sys
@@ -12,6 +12,9 @@ import kubernetes
 import pytest
 import yaml
 
+from ray.autoscaler._private.kubernetes.node_provider import\
+    KubernetesNodeProvider
+
 IMAGE_ENV = "KUBERNETES_OPERATOR_TEST_IMAGE"
 IMAGE = os.getenv(IMAGE_ENV, "rayproject/ray:nightly")
 
@@ -97,6 +100,14 @@ def get_operator_config_path(file_name):
 
 class KubernetesOperatorTest(unittest.TestCase):
     def test_examples(self):
+
+        # Validate terminate_node error handling
+        provider = KubernetesNodeProvider({
+            "namespace": NAMESPACE
+        }, "default_cluster_name")
+        # 404 caught, no error
+        provider.terminate_node("no-such-node")
+
         with tempfile.NamedTemporaryFile("w+") as example_cluster_file, \
                 tempfile.NamedTemporaryFile("w+") as example_cluster2_file,\
                 tempfile.NamedTemporaryFile("w+") as operator_file,\