diff --git a/python/ray/autoscaler/_private/autoscaler.py b/python/ray/autoscaler/_private/autoscaler.py
index 8321e153a..71d6e2012 100644
--- a/python/ray/autoscaler/_private/autoscaler.py
+++ b/python/ray/autoscaler/_private/autoscaler.py
@@ -15,8 +15,7 @@ import collections
 from ray.autoscaler.tags import (
     TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
     TAG_RAY_FILE_MOUNTS_CONTENTS, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_KIND,
-    TAG_RAY_USER_NODE_TYPE, STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH,
-    STATUS_SYNCING_FILES, STATUS_SETTING_UP, STATUS_UP_TO_DATE,
+    TAG_RAY_USER_NODE_TYPE, STATUS_UP_TO_DATE,
     STATUS_UPDATE_FAILED, NODE_KIND_WORKER, NODE_KIND_UNMANAGED,
     NODE_KIND_HEAD)
 from ray.autoscaler._private.event_summarizer import EventSummarizer
 from ray.autoscaler._private.legacy_info_string import legacy_log_info_string
@@ -122,7 +121,7 @@ class StandardAutoscaler:
         self.process_runner = process_runner
         self.event_summarizer = event_summarizer or EventSummarizer()
 
-        # Map from node_id to NodeUpdater processes
+        # Map from node_id to NodeUpdater threads
         self.updaters = {}
         self.num_failed_updates = defaultdict(int)
         self.num_successful_updates = defaultdict(int)
@@ -130,6 +129,12 @@ class StandardAutoscaler:
         self.last_update_time = 0.0
         self.update_interval_s = update_interval_s
 
+        # Disable NodeUpdater threads if true.
+        # Should be set to true in situations where another component, such as
+        # a Kubernetes operator, is responsible for Ray setup on nodes.
+        self.disable_node_updaters = self.config["provider"].get(
+            "disable_node_updaters", False)
+
         # Node launchers
         self.launch_queue = queue.Queue()
         self.pending_launches = ConcurrentCounter()
@@ -245,10 +250,7 @@ class StandardAutoscaler:
                 nodes_to_terminate.append(node_id)
 
         if nodes_to_terminate:
-            self.provider.terminate_nodes(nodes_to_terminate)
-            for node in nodes_to_terminate:
-                self.node_tracker.untrack(node)
-                self.prom_metrics.stopped_nodes.inc()
+            self._terminate_nodes_and_cleanup(nodes_to_terminate)
             nodes = self.workers()
 
         # Terminate nodes if there are too many
@@ -266,10 +268,7 @@ class StandardAutoscaler:
                 nodes_to_terminate.append(to_terminate)
 
         if nodes_to_terminate:
-            self.provider.terminate_nodes(nodes_to_terminate)
-            for node in nodes_to_terminate:
-                self.node_tracker.untrack(node)
-                self.prom_metrics.stopped_nodes.inc()
+            self._terminate_nodes_and_cleanup(nodes_to_terminate)
             nodes = self.workers()
 
         to_launch = self.resource_demand_scheduler.get_nodes_to_launch(
@@ -338,9 +337,7 @@ class StandardAutoscaler:
                     " Failed to update node."
                    " Node has already been terminated.")
         if nodes_to_terminate:
-            self.prom_metrics.stopped_nodes.inc(
-                len(nodes_to_terminate))
-            self.provider.terminate_nodes(nodes_to_terminate)
+            self._terminate_nodes_and_cleanup(nodes_to_terminate)
             nodes = self.workers()
 
         # Update nodes with out-of-date files.
@@ -363,9 +360,13 @@ class StandardAutoscaler:
         for t in T:
             t.join()
 
-        # Attempt to recover unhealthy nodes
-        for node_id in nodes:
-            self.recover_if_needed(node_id, now)
+        if self.disable_node_updaters:
+            # If updaters are unavailable, terminate unhealthy nodes.
+            self.terminate_unhealthy_nodes(nodes, now)
+        else:
+            # Attempt to recover unhealthy nodes
+            for node_id in nodes:
+                self.recover_if_needed(node_id, now)
 
         self.prom_metrics.updating_nodes.set(len(self.updaters))
         num_recovering = 0
@@ -376,6 +377,13 @@ class StandardAutoscaler:
             logger.info(self.info_string())
             legacy_log_info_string(self, nodes)
 
+    def _terminate_nodes_and_cleanup(self, nodes_to_terminate: List[str]):
+        """Terminate specified nodes and clean associated autoscaler state."""
+        self.provider.terminate_nodes(nodes_to_terminate)
+        for node in nodes_to_terminate:
+            self.node_tracker.untrack(node)
+            self.prom_metrics.stopped_nodes.inc()
+
     def _sort_based_on_last_used(self, nodes: List[NodeID],
                                  last_used: Dict[str, float]) -> List[NodeID]:
         """Sort the nodes based on the last time they were used.
@@ -647,9 +655,10 @@ class StandardAutoscaler:
             return False
         return True
 
-    def recover_if_needed(self, node_id, now):
-        if not self.can_update(node_id):
-            return
+    def heartbeat_on_time(self, node_id: NodeID, now: float) -> bool:
+        """Determine whether we've received a heartbeat from a node within the
+        last AUTOSCALER_HEARTBEAT_TIMEOUT_S seconds.
+        """
         key = self.provider.internal_ip(node_id)
 
         if key in self.load_metrics.last_heartbeat_time_by_ip:
@@ -657,7 +666,43 @@ class StandardAutoscaler:
                 key]
             delta = now - last_heartbeat_time
             if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
-                return
+                return True
+        return False
+
+    def terminate_unhealthy_nodes(self, nodes: List[NodeID], now: float):
+        """Terminate nodes for which we haven't received a heartbeat on time.
+
+        Used when node updaters are not available for recovery.
+        """
+        nodes_to_terminate = []
+        for node_id in nodes:
+            node_status = self.provider.node_tags(node_id)[TAG_RAY_NODE_STATUS]
+            # We're not responsible for taking down
+            # nodes with pending or failed status:
+            if not node_status == STATUS_UP_TO_DATE:
+                continue
+            # Heartbeat indicates node is healthy:
+            if self.heartbeat_on_time(node_id, now):
+                continue
+            # Node is unhealthy, terminate:
+            logger.warning("StandardAutoscaler: "
+                           "{}: No recent heartbeat, "
+                           "terminating node.".format(node_id))
+            self.event_summarizer.add(
+                "Terminating {} nodes of type " + self._get_node_type(node_id)
+                + " (lost contact with raylet).",
+                quantity=1,
+                aggregate=operator.add)
+            nodes_to_terminate.append(node_id)
+
+        if nodes_to_terminate:
+            self._terminate_nodes_and_cleanup(nodes_to_terminate)
+
+    def recover_if_needed(self, node_id, now):
+        if not self.can_update(node_id):
+            return
+        if self.heartbeat_on_time(node_id, now):
+            return
 
         logger.warning("StandardAutoscaler: "
                        "{}: No recent heartbeat, "
                        "terminating node.".format(node_id))
@@ -783,6 +828,8 @@ class StandardAutoscaler:
         self.updaters[node_id] = updater
 
     def can_update(self, node_id):
+        if self.disable_node_updaters:
+            return False
         if node_id in self.updaters:
             return False
         if not self.launch_config_ok(node_id):
@@ -876,11 +923,8 @@ class StandardAutoscaler:
                 non_failed.add(node_id)
             else:
                 status = node_tags[TAG_RAY_NODE_STATUS]
-                pending_states = [
-                    STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH,
-                    STATUS_SYNCING_FILES, STATUS_SETTING_UP
-                ]
-                is_pending = status in pending_states
+                completed_states = [STATUS_UP_TO_DATE, STATUS_UPDATE_FAILED]
+                is_pending = status not in completed_states
                 if is_pending:
                     pending_nodes.append((ip, node_type, status))
                 non_failed.add(node_id)
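
For orientation, here is a simplified, self-contained sketch of the heartbeat check that the terminate/recover split above relies on. The names mirror the diff, but LoadMetrics is reduced to a plain dict, and the timeout constant below is an illustrative stand-in for AUTOSCALER_HEARTBEAT_TIMEOUT_S rather than the real value.

    import time

    # Illustrative stand-in for AUTOSCALER_HEARTBEAT_TIMEOUT_S (an assumption, not from the diff).
    HEARTBEAT_TIMEOUT_S = 30.0

    def heartbeat_on_time(last_heartbeat_time_by_ip, node_ip, now):
        """Return True iff a heartbeat from node_ip arrived within the timeout."""
        if node_ip in last_heartbeat_time_by_ip:
            delta = now - last_heartbeat_time_by_ip[node_ip]
            if delta < HEARTBEAT_TIMEOUT_S:
                return True
        return False

    # A node whose last heartbeat is older than the timeout counts as unhealthy and,
    # with updaters disabled, would be terminated rather than recovered.
    heartbeats = {"172.0.0.0": time.time() - 60.0}
    assert not heartbeat_on_time(heartbeats, "172.0.0.0", time.time())
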
diff --git a/python/ray/autoscaler/ray-schema.json b/python/ray/autoscaler/ray-schema.json
index 3cda16514..64e67c7c8 100644
--- a/python/ray/autoscaler/ray-schema.json
+++ b/python/ray/autoscaler/ray-schema.json
@@ -160,6 +160,10 @@
                 }
             }
         },
+        "disable_node_updaters": {
+            "type": "boolean",
+            "description": "Disables node updaters if set to True. Default is False. (For Kubernetes operator usage.)"
+        },
         "gcp_credentials": {
             "type": "object",
             "description": "Credentials for authenticating with the GCP client",
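
The schema change above only adds the provider-level flag. As a usage sketch (the surrounding config keys are illustrative, not part of this diff), a cluster config would opt in much like the tests below do programmatically:

    import copy

    # Hypothetical minimal config in the shape validated by ray-schema.json;
    # only "disable_node_updaters" comes from this diff, the rest is illustrative.
    base_config = {
        "cluster_name": "example",
        "provider": {
            "type": "kubernetes",
            "namespace": "ray",
        },
    }

    config = copy.deepcopy(base_config)
    # Hand node setup over to an external component such as the Kubernetes operator.
    config["provider"]["disable_node_updaters"] = True
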
diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py
index b95a2edfc..6ae444cdb 100644
--- a/python/ray/tests/test_autoscaler.py
+++ b/python/ray/tests/test_autoscaler.py
@@ -1002,8 +1002,10 @@ class AutoscalingTest(unittest.TestCase):
         runner.assert_has_call("172.0.0.4", pattern="rsync")
         runner.clear_history()
 
-    def testScaleUp(self):
-        config_path = self.write_config(SMALL_CLUSTER)
+    def ScaleUpHelper(self, disable_node_updaters):
+        config = copy.deepcopy(SMALL_CLUSTER)
+        config["provider"]["disable_node_updaters"] = disable_node_updaters
+        config_path = self.write_config(config)
         self.provider = MockProvider()
         runner = MockProcessRunner()
         mock_metrics = Mock(spec=AutoscalerPrometheusMetrics())
@@ -1022,13 +1024,33 @@ class AutoscalingTest(unittest.TestCase):
         assert mock_metrics.started_nodes.inc.call_count == 1
         mock_metrics.started_nodes.inc.assert_called_with(2)
         assert mock_metrics.worker_create_node_time.observe.call_count == 2
-        autoscaler.update()
         self.waitForNodes(2)
 
         # running_workers metric should be set to 2
         mock_metrics.running_workers.set.assert_called_with(2)
 
+        if disable_node_updaters:
+            # Node Updaters have NOT been invoked because they were explicitly
+            # disabled.
+            time.sleep(1)
+            assert len(runner.calls) == 0
+            # Nodes were created in uninitialized state and not updated.
+            self.waitForNodes(
+                2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UNINITIALIZED})
+        else:
+            # Node Updaters have been invoked.
+            self.waitFor(lambda: len(runner.calls) > 0)
+            # The updates failed. The key thing is that the updates completed.
+            self.waitForNodes(
+                2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
+
+    def testScaleUp(self):
+        self.ScaleUpHelper(disable_node_updaters=False)
+
+    def testScaleUpNoUpdaters(self):
+        self.ScaleUpHelper(disable_node_updaters=True)
+
     def testTerminateOutdatedNodesGracefully(self):
         config = SMALL_CLUSTER.copy()
         config["min_workers"] = 5
@@ -1889,6 +1911,62 @@ class AutoscalingTest(unittest.TestCase):
                 "ray-legacy-worker-node-type (lost contact with raylet)." in
                 events), events
 
+    def testTerminateUnhealthyWorkers(self):
+        """Test termination of unhealthy workers when
+        autoscaler.disable_node_updaters == True.
+
+        Similar to testRecoverUnhealthyWorkers.
+        """
+        config_path = self.write_config(SMALL_CLUSTER)
+        self.provider = MockProvider()
+        runner = MockProcessRunner()
+        runner.respond_to_call("json .Config.Env", ["[]" for i in range(3)])
+        lm = LoadMetrics()
+        mock_metrics = Mock(spec=AutoscalerPrometheusMetrics())
+        autoscaler = StandardAutoscaler(
+            config_path,
+            lm,
+            max_failures=0,
+            process_runner=runner,
+            update_interval_s=0,
+            prom_metrics=mock_metrics)
+        autoscaler.update()
+        self.waitForNodes(2)
+        self.provider.finish_starting_nodes()
+        autoscaler.update()
+        self.waitForNodes(
+            2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
+
+        # Mark a node as unhealthy
+        for _ in range(5):
+            if autoscaler.updaters:
+                time.sleep(0.05)
+                autoscaler.update()
+        assert not autoscaler.updaters
+        num_calls = len(runner.calls)
+        lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
+        # Turn off updaters.
+        autoscaler.disable_node_updaters = True
+        # Reduce min_workers to 1
+        autoscaler.config["available_node_types"][NODE_TYPE_LEGACY_WORKER][
+            "min_workers"] = 1
+        autoscaler.update()
+        # Stopped node metric incremented.
+        mock_metrics.stopped_nodes.inc.assert_called_once_with()
+        # One node left.
+        self.waitForNodes(1)
+
+        # Check the node removal event is generated.
+        autoscaler.update()
+        events = autoscaler.event_summarizer.summary()
+        assert ("Terminating 1 nodes of type "
+                "ray-legacy-worker-node-type (lost contact with raylet)." in
+                events), events
+
+        # No additional runner calls, since updaters were disabled.
+        time.sleep(1)
+        assert len(runner.calls) == num_calls
+
     def testExternalNodeScaler(self):
         config = SMALL_CLUSTER.copy()
         config["provider"] = {
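
A consequence of the terminate_unhealthy_nodes filter is that, with updaters disabled, whatever component performs node setup must still mark nodes as up to date for the heartbeat-based termination to apply; pending or failed nodes are deliberately left alone. A minimal sketch, assuming that external component uses the provider tag API (mark_node_ready is a hypothetical helper, not part of this diff):

    from ray.autoscaler.tags import STATUS_UP_TO_DATE, TAG_RAY_NODE_STATUS

    def mark_node_ready(provider, node_id):
        """Hypothetical hook for an external setup component (e.g. an operator):
        once Ray is running on the node, tag it up to date so that the
        autoscaler's heartbeat check governs its health from then on."""
        provider.set_node_tags(node_id, {TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE})
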