Mirror of https://github.com/vale981/ray, synced 2025-03-09 12:56:46 -04:00
[autoscaler] Fix update/terminate race condition (#15019)

Co-authored-by: AmeerHajAli <ameerh@berkeley.edu>
Parent: 3578d4e9d8
Commit: 42565d5bbe

4 changed files with 247 additions and 78 deletions
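Context for the race being fixed: a worker whose setup/initialization fails can also be terminated by the node provider while its updater is still running (for example, for idleness after a missed heartbeat), and the old code then tried to terminate it a second time when processing completed updates. The fix re-queries the provider and only terminates nodes it still reports as non-terminated. The following is a minimal illustrative sketch of that guard, not the committed diff (which follows below); the helper name and signature are hypothetical, while the attribute names mirror StandardAutoscaler.

def process_failed_updates(autoscaler, failed_nodes):
    # Illustrative only: re-check liveness with the provider before terminating,
    # so nodes already removed during their update are not terminated twice.
    non_terminated_nodes = autoscaler.workers()
    nodes_to_terminate = []
    for node_id in failed_nodes:
        if node_id in non_terminated_nodes:
            # Still present according to the provider: safe to terminate.
            nodes_to_terminate.append(node_id)
        # else: already terminated mid-update; terminating again would race.
    if nodes_to_terminate:
        autoscaler.provider.terminate_nodes(nodes_to_terminate)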
@@ -250,16 +250,17 @@ class StandardAutoscaler:
         for node_type, count in to_launch.items():
             self.launch_new_node(count, node_type=node_type)
 
-        nodes = self.workers()
+        if to_launch:
+            nodes = self.workers()
 
         # Process any completed updates
-        completed = []
+        completed_nodes = []
         for node_id, updater in self.updaters.items():
             if not updater.is_alive():
-                completed.append(node_id)
-        if completed:
-            nodes_to_terminate: List[NodeID] = []
-            for node_id in completed:
+                completed_nodes.append(node_id)
+        if completed_nodes:
+            failed_nodes = []
+            for node_id in completed_nodes:
                 if self.updaters[node_id].exitcode == 0:
                     self.num_successful_updates[node_id] += 1
                     # Mark the node as active to prevent the node recovery
@@ -267,19 +268,34 @@ class StandardAutoscaler:
                     self.load_metrics.mark_active(
                         self.provider.internal_ip(node_id))
                 else:
-                    logger.error(f"StandardAutoscaler: {node_id}: Terminating "
-                                 "failed to setup/initialize node.")
-                    self.event_summarizer.add(
-                        "Removing {} nodes of type " +
-                        self._get_node_type(node_id) + " (launch failed).",
-                        quantity=1,
-                        aggregate=operator.add)
-                    nodes_to_terminate.append(node_id)
-                    self.num_failed_updates[node_id] += 1
+                    failed_nodes.append(node_id)
+                self.num_failed_updates[node_id] += 1
+                self.node_tracker.untrack(node_id)
                 del self.updaters[node_id]
 
-            if nodes_to_terminate:
-                self.provider.terminate_nodes(nodes_to_terminate)
-                nodes = self.workers()
+            if failed_nodes:
+                # Some nodes in failed_nodes may have been terminated
+                # during an update (for being idle after missing a heartbeat).
+                # Only terminate currently non terminated nodes.
+                non_terminated_nodes = self.workers()
+                nodes_to_terminate: List[NodeID] = []
+                for node_id in failed_nodes:
+                    if node_id in non_terminated_nodes:
+                        nodes_to_terminate.append(node_id)
+                        logger.error(f"StandardAutoscaler: {node_id}:"
+                                     " Terminating. Failed to setup/initialize"
+                                     " node.")
+                        self.event_summarizer.add(
+                            "Removing {} nodes of type " +
+                            self._get_node_type(node_id) + " (launch failed).",
+                            quantity=1,
+                            aggregate=operator.add)
+                    else:
+                        logger.warning(f"StandardAutoscaler: {node_id}:"
+                                       " Failed to update node."
+                                       " Node has already been terminated.")
+                if nodes_to_terminate:
+                    self.provider.terminate_nodes(nodes_to_terminate)
+                    nodes = self.workers()
 
         # Update nodes with out-of-date files.
@@ -602,7 +618,7 @@ class StandardAutoscaler:
         if TAG_RAY_USER_NODE_TYPE in node_tags:
             return node_tags[TAG_RAY_USER_NODE_TYPE]
         else:
-            return "unknown"
+            return "unknown_node_type"
 
     def _get_node_type_specific_fields(self, node_id: str,
                                        fields_key: str) -> Any:
@@ -152,7 +152,14 @@ class KubernetesNodeProvider(NodeProvider):
 
     def terminate_node(self, node_id):
         logger.info(log_prefix + "calling delete_namespaced_pod")
-        core_api().delete_namespaced_pod(node_id, self.namespace)
+        try:
+            core_api().delete_namespaced_pod(node_id, self.namespace)
+        except ApiException as e:
+            if e.status == 404:
+                logger.warning(log_prefix + f"Tried to delete pod {node_id},"
+                               " but the pod was not found (404).")
+            else:
+                raise
         try:
            core_api().delete_namespaced_service(node_id, self.namespace)
         except ApiException:
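For reference, the same 404-tolerant deletion pattern can be reproduced standalone with the official kubernetes Python client. This is a hedged sketch under that assumption; `core_api()` in the hunk above is Ray's own wrapper, and the function name here is hypothetical.

from kubernetes import client, config
from kubernetes.client.rest import ApiException

def delete_pod_if_exists(name: str, namespace: str) -> None:
    # Sketch: tolerate a pod that is already gone, mirroring the hunk above.
    config.load_kube_config()
    core = client.CoreV1Api()
    try:
        core.delete_namespaced_pod(name, namespace)
    except ApiException as e:
        if e.status == 404:
            print(f"Pod {name} not found in {namespace}; nothing to delete.")
        else:
            raise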
@@ -11,6 +11,7 @@ from unittest.mock import Mock
 import yaml
 import copy
 from jsonschema.exceptions import ValidationError
+from typing import Dict, Callable
 
 import ray
 from ray.autoscaler._private.util import prepare_config, validate_config
@@ -51,22 +52,39 @@ class MockNode:
 
 
 class MockProcessRunner:
-    def __init__(self, fail_cmds=None):
+    def __init__(self, fail_cmds=None, cmd_to_callback=None, print_out=False):
         self.calls = []
+        self.cmd_to_callback = cmd_to_callback or {
+        }  # type: Dict[str, Callable]
+        self.print_out = print_out
         self.fail_cmds = fail_cmds or []
         self.call_response = {}
         self.ready_to_run = threading.Event()
         self.ready_to_run.set()
 
+        self.lock = threading.RLock()
+
     def check_call(self, cmd, *args, **kwargs):
-        self.ready_to_run.wait()
-        for token in self.fail_cmds:
-            if token in str(cmd):
-                raise CalledProcessError(1, token,
-                                         "Failing command on purpose")
-        self.calls.append(cmd)
+        with self.lock:
+            self.ready_to_run.wait()
+            self.calls.append(cmd)
+            if self.print_out:
+                print(f">>>Process runner: Executing \n {str(cmd)}")
+            for token in self.cmd_to_callback:
+                if token in str(cmd):
+                    # Trigger a callback if token is in cmd.
+                    # Can be used to simulate background events during a node
+                    # update (e.g. node disconnected).
+                    callback = self.cmd_to_callback[token]
+                    callback()
+
+            for token in self.fail_cmds:
+                if token in str(cmd):
+                    raise CalledProcessError(1, token,
+                                             "Failing command on purpose")
 
     def check_output(self, cmd):
-        self.check_call(cmd)
-        return_string = "command-output"
-        key_to_shrink = None
+        with self.lock:
+            self.check_call(cmd)
+            return_string = "command-output"
+            key_to_shrink = None
@@ -84,6 +102,7 @@ class MockProcessRunner:
         return return_string.encode()
 
     def assert_has_call(self, ip, pattern=None, exact=None):
-        assert pattern or exact, \
-            "Must specify either a pattern or exact match."
-        out = ""
+        with self.lock:
+            assert pattern or exact, \
+                "Must specify either a pattern or exact match."
+            out = ""
@@ -111,6 +130,7 @@ class MockProcessRunner:
             f"\n\nFull output: {self.command_history()}")
 
     def assert_not_has_call(self, ip, pattern):
-        out = ""
-        for cmd in self.command_history():
-            if ip in cmd:
+        with self.lock:
+            out = ""
+            for cmd in self.command_history():
+                if ip in cmd:
@@ -123,12 +143,15 @@ class MockProcessRunner:
             return True
 
     def clear_history(self):
-        self.calls = []
+        with self.lock:
+            self.calls = []
 
     def command_history(self):
-        return [" ".join(cmd) for cmd in self.calls]
+        with self.lock:
+            return [" ".join(cmd) for cmd in self.calls]
 
     def respond_to_call(self, pattern, response_list):
-        self.call_response[pattern] = response_list
+        with self.lock:
+            self.call_response[pattern] = response_list
 
 
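The new cmd_to_callback hook lets a test fire a side effect whenever an updater command containing a given token runs. A hypothetical usage sketch (here `provider` stands for a MockProvider instance; the test added below wires this up the same way):

# Any command containing "ray_start_cmd" both fails and triggers termination
# of node 0, simulating a node dying in the background mid-update.
runner = MockProcessRunner(
    fail_cmds=["ray_start_cmd"],
    cmd_to_callback={"ray_start_cmd": lambda: provider.terminate_node(0)})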
@@ -356,13 +379,13 @@ class AutoscalingTest(unittest.TestCase):
         shutil.rmtree(self.tmpdir)
         ray.shutdown()
 
-    def waitFor(self, condition, num_retries=50):
+    def waitFor(self, condition, num_retries=50, fail_msg=None):
         for _ in range(num_retries):
             if condition():
                 return
             time.sleep(.1)
-        raise RayTestTimeoutException(
-            "Timed out waiting for {}".format(condition))
+        fail_msg = fail_msg or "Timed out waiting for {}".format(condition)
+        raise RayTestTimeoutException(fail_msg)
 
     def waitForNodes(self, expected, comparison=None, tag_filters={}):
         MAX_ITER = 50
@@ -371,7 +394,7 @@ class AutoscalingTest(unittest.TestCase):
             if comparison is None:
                 comparison = self.assertEqual
             try:
-                comparison(n, expected)
+                comparison(n, expected, msg="Unexpected node quantity.")
                 return
             except Exception:
                 if i == MAX_ITER - 1:
@@ -2218,6 +2241,118 @@ MemAvailable: 33000000 kB
 
         assert node == 1
 
+    def testNodeTerminatedDuringUpdate(self):
+        """
+        Tests autoscaler handling a node getting terminated during an update
+        triggered by the node missing a heartbeat.
+
+        Extension of testRecoverUnhealthyWorkers.
+
+        In this test, two nodes miss a heartbeat.
+        One of them (node 0) is terminated during its recovery update.
+        The other (node 1) just fails its update.
+
+        When processing completed updates, the autoscaler terminates node 1
+        but does not try to terminate node 0 again.
+        """
+        cluster_config = copy.deepcopy(MOCK_DEFAULT_CONFIG)
+        cluster_config["available_node_types"]["ray.worker.default"][
+            "min_workers"] = 2
+        cluster_config["worker_start_ray_commands"] = ["ray_start_cmd"]
+
+        # Don't need the extra node type or a docker config.
+        cluster_config["head_node_type"] = ["ray.worker.default"]
+        del cluster_config["available_node_types"]["ray.head.default"]
+        del cluster_config["docker"]
+
+        config_path = self.write_config(cluster_config)
+
+        self.provider = MockProvider()
+        runner = MockProcessRunner()
+        lm = LoadMetrics()
+        autoscaler = StandardAutoscaler(
+            config_path,
+            lm,
+            max_failures=0,
+            process_runner=runner,
+            update_interval_s=0)
+
+        # Scale up to two up-to-date workers
+        autoscaler.update()
+        self.waitForNodes(2)
+        self.provider.finish_starting_nodes()
+        autoscaler.update()
+        self.waitForNodes(
+            2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
+
+        # Mark both nodes as unhealthy
+        for _ in range(5):
+            if autoscaler.updaters:
+                time.sleep(0.05)
+                autoscaler.update()
+        assert not autoscaler.updaters
+        lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
+        lm.last_heartbeat_time_by_ip["172.0.0.1"] = 0
+
+        # Set up process runner to terminate worker 0 during missed heartbeat
+        # recovery and also cause the updater to fail.
+        def terminate_worker_zero():
+            self.provider.terminate_node(0)
+
+        autoscaler.process_runner = MockProcessRunner(
+            fail_cmds=["ray_start_cmd"],
+            cmd_to_callback={"ray_start_cmd": terminate_worker_zero})
+
+        num_calls = len(autoscaler.process_runner.calls)
+        autoscaler.update()
+        # Wait for updaters spawned by last autoscaler update to finish.
+        self.waitFor(
+            lambda: all(not updater.is_alive()
+                        for updater in autoscaler.updaters.values()),
+            num_retries=500,
+            fail_msg="Last round of updaters didn't complete on time."
+        )
+        # Check that updaters processed some commands in the last autoscaler
+        # update.
+        assert len(autoscaler.process_runner.calls) > num_calls,\
+            "Did not get additional process runner calls on last autoscaler"\
+            " update."
+        # Missed heartbeat triggered recovery for both nodes.
+        events = autoscaler.event_summarizer.summary()
+        assert (
+            "Restarting 2 nodes of type "
+            "ray.worker.default (lost contact with raylet)." in events), events
+        # Node 0 was terminated during the last update.
+        # Node 1's updater failed, but node 1 won't be terminated until the
+        # next autoscaler update.
+        assert 0 not in autoscaler.workers(), "Node zero still non-terminated."
+        assert not self.provider.is_terminated(1),\
+            "Node one terminated prematurely."
+
+        autoscaler.update()
+        # Failed updates processed are now processed.
+        assert autoscaler.num_failed_updates[0] == 1,\
+            "Node zero update failure not registered"
+        assert autoscaler.num_failed_updates[1] == 1,\
+            "Node one update failure not registered"
+        # Completed-update-processing logic should have terminated node 1.
+        assert self.provider.is_terminated(1), "Node 1 not terminated on time."
+
+        events = autoscaler.event_summarizer.summary()
+        # Just one node (node_id 1) terminated in the last update.
+        # Validates that we didn't try to double-terminate node 0.
+        assert ("Removing 1 nodes of type "
+                "ray.worker.default (launch failed)." in events), events
+        # To be more explicit,
+        assert ("Removing 2 nodes of type "
+                "ray.worker.default (launch failed)." not in events), events
+
+        # Should get two new nodes after the next update.
+        autoscaler.update()
+        self.waitForNodes(2)
+        assert set(autoscaler.workers()) == {2, 3},\
+            "Unexpected node_ids"
+
 
 if __name__ == "__main__":
     import sys
@@ -12,6 +12,9 @@ import kubernetes
 import pytest
 import yaml
 
+from ray.autoscaler._private.kubernetes.node_provider import\
+    KubernetesNodeProvider
+
 IMAGE_ENV = "KUBERNETES_OPERATOR_TEST_IMAGE"
 IMAGE = os.getenv(IMAGE_ENV, "rayproject/ray:nightly")
 
@@ -97,6 +100,14 @@ def get_operator_config_path(file_name):
 
 class KubernetesOperatorTest(unittest.TestCase):
     def test_examples(self):
+
+        # Validate terminate_node error handling
+        provider = KubernetesNodeProvider({
+            "namespace": NAMESPACE
+        }, "default_cluster_name")
+        # 404 caught, no error
+        provider.terminate_node("no-such-node")
+
         with tempfile.NamedTemporaryFile("w+") as example_cluster_file, \
                 tempfile.NamedTemporaryFile("w+") as example_cluster2_file,\
                 tempfile.NamedTemporaryFile("w+") as operator_file,\