diff --git a/ci/lint/format.sh b/ci/lint/format.sh
index 7fc2d9fc7..d064a7bda 100755
--- a/ci/lint/format.sh
+++ b/ci/lint/format.sh
@@ -138,6 +138,7 @@ MYPY_FILES=(
     'autoscaler/sdk/__init__.py'
     'autoscaler/sdk/sdk.py'
     'autoscaler/_private/commands.py'
+    'autoscaler/_private/autoscaler.py'
     # TODO(dmitri) Fails with meaningless error, maybe due to a bug in the mypy version
     # in the CI. Type check once we get serious about type checking:
     #'ray_operator/operator.py'
diff --git a/python/ray/autoscaler/_private/autoscaler.py b/python/ray/autoscaler/_private/autoscaler.py
index 5cec0162e..022e9a067 100644
--- a/python/ray/autoscaler/_private/autoscaler.py
+++ b/python/ray/autoscaler/_private/autoscaler.py
@@ -89,7 +89,7 @@ UpdateInstructions = namedtuple(
     ["node_id", "setup_commands", "ray_start_commands", "docker_config"],
 )

-NodeLaunchData = Tuple[NodeTypeConfigDict, NodeCount, NodeType]
+NodeLaunchData = Tuple[NodeTypeConfigDict, NodeCount, Optional[NodeType]]


 @dataclass
@@ -226,9 +226,9 @@ class StandardAutoscaler:
         self.event_summarizer = event_summarizer or EventSummarizer()

         # Map from node_id to NodeUpdater threads
-        self.updaters = {}
-        self.num_failed_updates = defaultdict(int)
-        self.num_successful_updates = defaultdict(int)
+        self.updaters: Dict[NodeID, NodeUpdaterThread] = {}
+        self.num_failed_updates: Dict[NodeID, int] = defaultdict(int)
+        self.num_successful_updates: Dict[NodeID, int] = defaultdict(int)
         self.num_failures = 0
         self.last_update_time = 0.0
         self.update_interval_s = update_interval_s
@@ -332,6 +332,10 @@ class StandardAutoscaler:
             raise e

     def _update(self):
+        # For type checking, assert that these objects have been instantiated.
+        assert self.provider
+        assert self.resource_demand_scheduler
+
         now = time.time()
         # Throttle autoscaling updates to this interval to avoid exceeding
         # rate limits on API calls.
@@ -406,6 +410,10 @@ class StandardAutoscaler:
         Avoids terminating non-outdated nodes required by
         autoscaler.sdk.request_resources().
         """
+        # For type checking, assert that these objects have been instantiated.
+        assert self.non_terminated_nodes
+        assert self.provider
+
         last_used = self.load_metrics.last_used_time_by_ip

         horizon = now - (60 * self.config["idle_timeout_minutes"])
@@ -427,6 +435,7 @@ class StandardAutoscaler:
         node_type_counts = defaultdict(int)

         def keep_node(node_id: NodeID) -> None:
+            assert self.provider
             # Update per-type counts.
             tags = self.provider.node_tags(node_id)
             if TAG_RAY_USER_NODE_TYPE in tags:
@@ -498,6 +507,9 @@ class StandardAutoscaler:
     def schedule_node_termination(
         self, node_id: NodeID, reason_opt: Optional[str], logger_method: Callable
     ) -> None:
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
+
         if reason_opt is None:
             raise Exception("reason should be not None.")
         reason: str = reason_opt
@@ -520,6 +532,10 @@ class StandardAutoscaler:

     def terminate_scheduled_nodes(self):
         """Terminate scheduled nodes and clean associated autoscaler state."""
+        # For type checking, assert that these objects have been instantiated.
+        assert self.provider
+        assert self.non_terminated_nodes
+
         if not self.nodes_to_terminate:
             return

@@ -545,6 +561,9 @@ class StandardAutoscaler:
         the behavior may change to better reflect the name "Drain."
         See https://github.com/ray-project/ray/pull/19350.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
+
         # The GCS expects Raylet ids in the request, rather than NodeProvider
         # ids. To get the Raylet ids of the nodes we're draining, we make
         # the following translations of identifiers:
@@ -733,6 +752,8 @@ class StandardAutoscaler:
             unfulfilled: List of resource demands that would be unfulfilled
                 even after full scale-up.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.resource_demand_scheduler
         pending = []
         infeasible = []
         for bundle in unfulfilled:
@@ -778,6 +799,7 @@ class StandardAutoscaler:
         least_recently_used = -1

         def last_time_used(node_id: NodeID):
+            assert self.provider
             node_ip = self.provider.internal_ip(node_id)
             if node_ip not in last_used_copy:
                 return least_recently_used
@@ -800,19 +822,23 @@ class StandardAutoscaler:
             FrozenSet[NodeID]: a set of nodes (node ids) that we should NOT
                 terminate.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
+
         nodes_not_allowed_to_terminate: Set[NodeID] = set()
+        static_node_resources: Dict[
+            NodeIP, ResourceDict
+        ] = self.load_metrics.get_static_node_resources_by_ip()
+
         head_node_resources: ResourceDict = copy.deepcopy(
             self.available_node_types[self.config["head_node_type"]]["resources"]
         )
+        # TODO(ameer): this is somewhat duplicated in
+        # resource_demand_scheduler.py.
         if not head_node_resources:
             # Legacy yaml might include {} in the resources field.
-            # TODO(ameer): this is somewhat duplicated in
-            # resource_demand_scheduler.py.
-            static_nodes: Dict[
-                NodeIP, ResourceDict
-            ] = self.load_metrics.get_static_node_resources_by_ip()
             head_node_ip = self.provider.internal_ip(self.non_terminated_nodes.head_id)
-            head_node_resources = static_nodes.get(head_node_ip, {})
+            head_node_resources = static_node_resources.get(head_node_ip, {})

         max_node_resources: List[ResourceDict] = [head_node_resources]
         resource_demand_vector_worker_node_ids = []
@@ -826,11 +852,8 @@ class StandardAutoscaler:
             )
             if not node_resources:
                 # Legacy yaml might include {} in the resources field.
-                static_nodes: Dict[
-                    NodeIP, ResourceDict
-                ] = self.load_metrics.get_static_node_resources_by_ip()
                 node_ip = self.provider.internal_ip(node_id)
-                node_resources = static_nodes.get(node_ip, {})
+                node_resources = static_node_resources.get(node_ip, {})
             max_node_resources.append(node_resources)
             resource_demand_vector_worker_node_ids.append(node_id)
         # Since it is sorted based on last used, we "keep" nodes that are
@@ -887,6 +910,9 @@ class StandardAutoscaler:
             Optional[str]: reason for termination. Not None on
                 KeepOrTerminate.terminate, None otherwise.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
+
         tags = self.provider.node_tags(node_id)
         if TAG_RAY_USER_NODE_TYPE in tags:
             node_type = tags[TAG_RAY_USER_NODE_TYPE]
@@ -1061,6 +1087,9 @@ class StandardAutoscaler:
         """Determine whether we've received a heartbeat from a node within the
         last AUTOSCALER_HEARTBEAT_TIMEOUT_S seconds.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
+
         key = self.provider.internal_ip(node_id)

         if key in self.load_metrics.last_heartbeat_time_by_ip:
@@ -1074,6 +1103,10 @@ class StandardAutoscaler:
         """Terminated nodes for which we haven't received a heartbeat on time.
         These nodes are subsequently terminated.
         """
+        # For type checking, assert that these objects have been instantiated.
+        assert self.provider
+        assert self.non_terminated_nodes
+
         for node_id in self.non_terminated_nodes.worker_ids:
             node_status = self.provider.node_tags(node_id)[TAG_RAY_NODE_STATUS]
             # We're not responsible for taking down
@@ -1142,6 +1175,9 @@ class StandardAutoscaler:
         self.updaters[node_id] = updater

     def _get_node_type(self, node_id: str) -> str:
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
+
         node_tags = self.provider.node_tags(node_id)
         if TAG_RAY_USER_NODE_TYPE in node_tags:
             return node_tags[TAG_RAY_USER_NODE_TYPE]
@@ -1149,6 +1185,9 @@ class StandardAutoscaler:
             return "unknown_node_type"

     def _get_node_type_specific_fields(self, node_id: str, fields_key: str) -> Any:
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
+
         fields = self.config[fields_key]
         node_tags = self.provider.node_tags(node_id)
         if TAG_RAY_USER_NODE_TYPE in node_tags:
@@ -1298,9 +1337,12 @@ class StandardAutoscaler:
         Returns:
             AutoscalerSummary: The summary.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
+
         if not self.non_terminated_nodes:
             return None
-        active_nodes = Counter()
+        active_nodes: Dict[NodeType, int] = Counter()
         pending_nodes = []
         failed_nodes = []
         non_failed = set()
diff --git a/python/ray/autoscaler/_private/prom_metrics.py b/python/ray/autoscaler/_private/prom_metrics.py
index e2fc54c06..2d7e79e9f 100644
--- a/python/ray/autoscaler/_private/prom_metrics.py
+++ b/python/ray/autoscaler/_private/prom_metrics.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 try:
     from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram

@@ -5,7 +7,7 @@ try:
     # The metrics in this class should be kept in sync with
     # python/ray/tests/test_metrics_agent.py
     class AutoscalerPrometheusMetrics:
-        def __init__(self, registry: CollectorRegistry = None):
+        def __init__(self, registry: Optional[CollectorRegistry] = None):
             self.registry: CollectorRegistry = registry or CollectorRegistry(
                 auto_describe=True
             )
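
Note on the recurring asserts added above: StandardAutoscaler declares
attributes such as `provider`, `non_terminated_nodes`, and
`resource_demand_scheduler` as Optional because they start out as None and
are assigned later (e.g. in reset()/update()), so mypy cannot prove they are
non-None at each use site. An `assert` narrows the attribute's type from
Optional[T] to T for the remainder of the method. A minimal sketch of the
pattern, using illustrative stand-in classes rather than Ray's real ones:

    from typing import Dict, Optional

    class NodeProvider:
        """Stand-in for Ray's NodeProvider; illustrative only."""

        def node_tags(self, node_id: str) -> Dict[str, str]:
            return {}

    class Autoscaler:
        def __init__(self) -> None:
            # Deferred initialization: None until reset() runs, so the
            # declared type must admit None.
            self.provider: Optional[NodeProvider] = None

        def reset(self) -> None:
            self.provider = NodeProvider()

        def _update(self) -> Dict[str, str]:
            # Without this assert, mypy reports:
            #   Item "None" of "Optional[NodeProvider]" has no
            #   attribute "node_tags"
            assert self.provider
            return self.provider.node_tags("head-node")

The prom_metrics.py hunk fixes the inverse problem: annotating a None
default as `registry: CollectorRegistry = None` relies on implicit-Optional
semantics that strict mypy configurations (no_implicit_optional) reject, so
the annotation is spelled out as `Optional[CollectorRegistry] = None`.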