[autoscaler][weekend nits] autoscaler.py type checking and other lint issues (#26646)

I run several linters, including mypy, in my local environment.
This PR fixes style nits in autoscaler.py to silence those linters.

It also adds a mypy check for autoscaler.py in CI.
Dmitri Gekhtman 2022-07-18 13:27:19 -07:00 committed by GitHub
parent df421ad499
commit c4160ec34b
3 changed files with 61 additions and 16 deletions


@@ -138,6 +138,7 @@ MYPY_FILES=(
     'autoscaler/sdk/__init__.py'
     'autoscaler/sdk/sdk.py'
     'autoscaler/_private/commands.py'
+    'autoscaler/_private/autoscaler.py'
     # TODO(dmitri) Fails with meaningless error, maybe due to a bug in the mypy version
     # in the CI. Type check once we get serious about type checking:
     #'ray_operator/operator.py'
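
For context, the check added here amounts to running mypy over an explicit allowlist of files. A minimal sketch of that idea using mypy's programmatic API (the real CI script may invoke mypy differently; only the paths are taken from the list above):

    # Sketch: type-check an allowlist of files, as the CI entry above does.
    from mypy import api  # mypy's programmatic entry point

    MYPY_FILES = [
        "autoscaler/sdk/__init__.py",
        "autoscaler/sdk/sdk.py",
        "autoscaler/_private/commands.py",
        "autoscaler/_private/autoscaler.py",
    ]

    def run_mypy(files):
        # api.run returns a (stdout, stderr, exit_status) tuple.
        stdout, _stderr, status = api.run(files)
        if status != 0:
            print(stdout)
        return status

    if __name__ == "__main__":
        raise SystemExit(run_mypy(MYPY_FILES))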


@@ -89,7 +89,7 @@ UpdateInstructions = namedtuple(
     ["node_id", "setup_commands", "ray_start_commands", "docker_config"],
 )
-NodeLaunchData = Tuple[NodeTypeConfigDict, NodeCount, NodeType]
+NodeLaunchData = Tuple[NodeTypeConfigDict, NodeCount, Optional[NodeType]]
 @dataclass
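
This hunk widens the last element of NodeLaunchData to Optional[NodeType], presumably because some code path constructs the tuple with node_type=None, which mypy rejects under the old alias. A self-contained illustration, with the three aliases simplified to stand-ins:

    # Illustration of the widened alias; the aliases are simplified
    # stand-ins for the real NodeTypeConfigDict/NodeCount/NodeType.
    from typing import Any, Dict, Optional, Tuple

    NodeTypeConfigDict = Dict[str, Any]
    NodeCount = int
    NodeType = str

    NodeLaunchData = Tuple[NodeTypeConfigDict, NodeCount, Optional[NodeType]]

    def make_launch_data(
        config: NodeTypeConfigDict, count: NodeCount, node_type: Optional[NodeType]
    ) -> NodeLaunchData:
        # Under Tuple[..., NodeType], mypy flags this return whenever
        # node_type can be None.
        return (config, count, node_type)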
@@ -226,9 +226,9 @@ class StandardAutoscaler:
         self.event_summarizer = event_summarizer or EventSummarizer()
         # Map from node_id to NodeUpdater threads
-        self.updaters = {}
-        self.num_failed_updates = defaultdict(int)
-        self.num_successful_updates = defaultdict(int)
+        self.updaters: Dict[NodeID, NodeUpdaterThread] = {}
+        self.num_failed_updates: Dict[NodeID, int] = defaultdict(int)
+        self.num_successful_updates: Dict[NodeID, int] = defaultdict(int)
         self.num_failures = 0
         self.last_update_time = 0.0
         self.update_interval_s = update_interval_s
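
These annotations exist because mypy cannot infer element types from an empty literal or from defaultdict(int) alone. A standalone sketch of the idiom (the Tracker class and value types are illustrative):

    # Sketch: annotating containers that start empty so mypy knows
    # their element types. NodeID mirrors the alias in autoscaler.py.
    from collections import defaultdict
    from typing import Dict

    NodeID = str

    class Tracker:
        def __init__(self) -> None:
            # A bare {} makes mypy report "Need type annotation".
            self.updaters: Dict[NodeID, object] = {}
            # defaultdict subclasses dict, so this assignment checks.
            self.num_failed_updates: Dict[NodeID, int] = defaultdict(int)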
@@ -332,6 +332,10 @@ class StandardAutoscaler:
             raise e

     def _update(self):
+        # For type checking, assert that these objects have been instantiated.
+        assert self.provider
+        assert self.resource_demand_scheduler
         now = time.time()
         # Throttle autoscaling updates to this interval to avoid exceeding
         # rate limits on API calls.
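
This assert pattern recurs throughout the diff: attributes such as self.provider are declared Optional because they are populated after __init__, and a bare assert narrows Optional[T] to T for the rest of the method. A minimal sketch, with Provider standing in for the real NodeProvider:

    # Minimal sketch of assert-based narrowing, as used in this diff.
    from typing import List, Optional

    class Provider:
        def non_terminated_nodes(self) -> List[str]:
            return []

    class Autoscaler:
        def __init__(self) -> None:
            # Populated later during initialization, hence Optional.
            self.provider: Optional[Provider] = None

        def _update(self) -> None:
            # Narrows self.provider from Optional[Provider] to Provider;
            # without it, mypy rejects the attribute access below.
            assert self.provider
            self.provider.non_terminated_nodes()

Note that the asserts are repeated inside nested helpers further down (keep_node, last_time_used), since mypy's narrowing does not carry into closures.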
@@ -406,6 +410,10 @@ class StandardAutoscaler:
         Avoids terminating non-outdated nodes required by
         autoscaler.sdk.request_resources().
         """
+        # For type checking, assert that these objects have been instantiated.
+        assert self.non_terminated_nodes
+        assert self.provider
         last_used = self.load_metrics.last_used_time_by_ip
         horizon = now - (60 * self.config["idle_timeout_minutes"])
@@ -427,6 +435,7 @@ class StandardAutoscaler:
         node_type_counts = defaultdict(int)

         def keep_node(node_id: NodeID) -> None:
+            assert self.provider
             # Update per-type counts.
             tags = self.provider.node_tags(node_id)
             if TAG_RAY_USER_NODE_TYPE in tags:
@@ -498,6 +507,9 @@ class StandardAutoscaler:
     def schedule_node_termination(
         self, node_id: NodeID, reason_opt: Optional[str], logger_method: Callable
     ) -> None:
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
         if reason_opt is None:
             raise Exception("reason should not be None.")
         reason: str = reason_opt
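
The pre-existing lines above show the file's other narrowing idiom: raising on None lets mypy treat reason_opt as str afterwards, and the annotated rebinding makes that explicit. A standalone sketch:

    # Sketch: narrowing an Optional parameter by raising on None,
    # mirroring schedule_node_termination above.
    from typing import Optional

    def schedule_termination(node_id: str, reason_opt: Optional[str]) -> None:
        if reason_opt is None:
            raise ValueError("reason should not be None.")
        # mypy now knows reason_opt is str.
        reason: str = reason_opt
        print(f"Terminating {node_id}: {reason}")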
@@ -520,6 +532,10 @@ class StandardAutoscaler:
     def terminate_scheduled_nodes(self):
         """Terminate scheduled nodes and clean associated autoscaler state."""
+        # For type checking, assert that these objects have been instantiated.
+        assert self.provider
+        assert self.non_terminated_nodes
         if not self.nodes_to_terminate:
             return
@@ -545,6 +561,9 @@ class StandardAutoscaler:
         the behavior may change to better reflect the name "Drain."
         See https://github.com/ray-project/ray/pull/19350.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
         # The GCS expects Raylet ids in the request, rather than NodeProvider
         # ids. To get the Raylet ids of the nodes we're draining, we make
         # the following translations of identifiers:
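
The translation this comment describes is roughly NodeProvider id -> internal IP -> Raylet id. A hypothetical sketch of that two-step lookup (ip_by_node_id and raylet_id_by_ip are illustrative stand-ins, not the real provider/LoadMetrics attributes):

    # Hypothetical sketch of the id translation described above.
    from typing import Dict, List

    ip_by_node_id: Dict[str, str] = {"i-abc123": "10.0.0.7"}
    raylet_id_by_ip: Dict[str, bytes] = {"10.0.0.7": b"raylet-id"}

    def raylet_ids_to_drain(node_ids: List[str]) -> List[bytes]:
        # Step 1: provider node id -> internal IP; step 2: IP -> Raylet id.
        return [raylet_id_by_ip[ip_by_node_id[n]] for n in node_ids]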
@@ -733,6 +752,8 @@ class StandardAutoscaler:
             unfulfilled: List of resource demands that would be unfulfilled
                 even after full scale-up.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.resource_demand_scheduler
         pending = []
         infeasible = []
         for bundle in unfulfilled:
@@ -778,6 +799,7 @@ class StandardAutoscaler:
         least_recently_used = -1

         def last_time_used(node_id: NodeID):
+            assert self.provider
             node_ip = self.provider.internal_ip(node_id)
             if node_ip not in last_used_copy:
                 return least_recently_used
@@ -800,19 +822,23 @@ class StandardAutoscaler:
             FrozenSet[NodeID]: a set of nodes (node ids) that
                 we should NOT terminate.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
         nodes_not_allowed_to_terminate: Set[NodeID] = set()
+        static_node_resources: Dict[
+            NodeIP, ResourceDict
+        ] = self.load_metrics.get_static_node_resources_by_ip()
         head_node_resources: ResourceDict = copy.deepcopy(
             self.available_node_types[self.config["head_node_type"]]["resources"]
         )
+        # TODO(ameer): this is somewhat duplicated in
+        # resource_demand_scheduler.py.
         if not head_node_resources:
             # Legacy yaml might include {} in the resources field.
-            # TODO(ameer): this is somewhat duplicated in
-            # resource_demand_scheduler.py.
-            static_nodes: Dict[
-                NodeIP, ResourceDict
-            ] = self.load_metrics.get_static_node_resources_by_ip()
             head_node_ip = self.provider.internal_ip(self.non_terminated_nodes.head_id)
-            head_node_resources = static_nodes.get(head_node_ip, {})
+            head_node_resources = static_node_resources.get(head_node_ip, {})

         max_node_resources: List[ResourceDict] = [head_node_resources]
         resource_demand_vector_worker_node_ids = []
@@ -826,11 +852,8 @@ class StandardAutoscaler:
                 )
             if not node_resources:
                 # Legacy yaml might include {} in the resources field.
-                static_nodes: Dict[
-                    NodeIP, ResourceDict
-                ] = self.load_metrics.get_static_node_resources_by_ip()
                 node_ip = self.provider.internal_ip(node_id)
-                node_resources = static_nodes.get(node_ip, {})
+                node_resources = static_node_resources.get(node_ip, {})
             max_node_resources.append(node_resources)
             resource_demand_vector_worker_node_ids.append(node_id)

         # Since it is sorted based on last used, we "keep" nodes that are
@@ -887,6 +910,9 @@ class StandardAutoscaler:
             Optional[str]: reason for termination. Not None on
                 KeepOrTerminate.terminate, None otherwise.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
         tags = self.provider.node_tags(node_id)
         if TAG_RAY_USER_NODE_TYPE in tags:
             node_type = tags[TAG_RAY_USER_NODE_TYPE]
@@ -1061,6 +1087,9 @@ class StandardAutoscaler:
         """Determine whether we've received a heartbeat from a node within the
         last AUTOSCALER_HEARTBEAT_TIMEOUT_S seconds.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
         key = self.provider.internal_ip(node_id)
         if key in self.load_metrics.last_heartbeat_time_by_ip:
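
The check described by this docstring amounts to comparing a node's last heartbeat timestamp against a timeout. A hypothetical sketch (the timeout value is made up, not Ray's actual default):

    # Hypothetical sketch of the heartbeat-timeout check.
    import time
    from typing import Dict

    AUTOSCALER_HEARTBEAT_TIMEOUT_S = 30  # illustrative value

    def heartbeat_on_time(
        last_heartbeat_time_by_ip: Dict[str, float], node_ip: str, now: float
    ) -> bool:
        # A node is healthy if it heartbeated within the timeout window.
        last = last_heartbeat_time_by_ip.get(node_ip)
        return last is not None and (now - last) < AUTOSCALER_HEARTBEAT_TIMEOUT_S

    print(heartbeat_on_time({"10.0.0.7": time.time()}, "10.0.0.7", time.time()))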
@@ -1074,6 +1103,10 @@ class StandardAutoscaler:
         """Determine nodes for which we haven't received a heartbeat on time.
         These nodes are subsequently terminated.
         """
+        # For type checking, assert that these objects have been instantiated.
+        assert self.provider
+        assert self.non_terminated_nodes
         for node_id in self.non_terminated_nodes.worker_ids:
             node_status = self.provider.node_tags(node_id)[TAG_RAY_NODE_STATUS]
             # We're not responsible for taking down
@@ -1142,6 +1175,9 @@ class StandardAutoscaler:
             self.updaters[node_id] = updater

     def _get_node_type(self, node_id: str) -> str:
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
         node_tags = self.provider.node_tags(node_id)
         if TAG_RAY_USER_NODE_TYPE in node_tags:
             return node_tags[TAG_RAY_USER_NODE_TYPE]
@@ -1149,6 +1185,9 @@ class StandardAutoscaler:
             return "unknown_node_type"

     def _get_node_type_specific_fields(self, node_id: str, fields_key: str) -> Any:
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
         fields = self.config[fields_key]
         node_tags = self.provider.node_tags(node_id)
         if TAG_RAY_USER_NODE_TYPE in node_tags:
@@ -1298,9 +1337,12 @@ class StandardAutoscaler:
         Returns:
             AutoscalerSummary: The summary.
         """
+        # For type checking, assert that this object has been instantiated.
+        assert self.provider
         if not self.non_terminated_nodes:
             return None
-        active_nodes = Counter()
+        active_nodes: Dict[NodeType, int] = Counter()
         pending_nodes = []
         failed_nodes = []
         non_failed = set()
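
The Counter annotation works because collections.Counter subclasses dict, so Dict[NodeType, int] is a sound, if loose, type for it; typing.Counter[NodeType] would be the precise alternative. A quick sketch:

    # Sketch: two valid annotations for a Counter.
    import collections
    from typing import Counter, Dict

    NodeType = str

    # The loose form used in the diff: Counter is a dict subclass.
    active_nodes: Dict[NodeType, int] = collections.Counter()
    active_nodes["worker"] += 1

    # The precise form.
    precise: Counter[NodeType] = collections.Counter()
    precise["head"] += 1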


@@ -1,3 +1,5 @@
+from typing import Optional
+
 try:
     from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram
@@ -5,7 +7,7 @@ try:
 # The metrics in this class should be kept in sync with
 # python/ray/tests/test_metrics_agent.py
 class AutoscalerPrometheusMetrics:
-    def __init__(self, registry: CollectorRegistry = None):
+    def __init__(self, registry: Optional[CollectorRegistry] = None):
         self.registry: CollectorRegistry = registry or CollectorRegistry(
             auto_describe=True
         )
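
This last hunk fixes an implicit-Optional annotation: a parameter defaulting to None must be typed Optional explicitly, which stricter mypy configurations enforce. A minimal sketch, with Registry standing in for prometheus_client's CollectorRegistry:

    # Sketch of the implicit-Optional fix.
    from typing import Optional

    class Registry:
        pass

    class Metrics:
        # Before: "registry: Registry = None", rejected under mypy's
        # no_implicit_optional checking.
        def __init__(self, registry: Optional[Registry] = None) -> None:
            # Fall back to a fresh registry when none is supplied.
            self.registry: Registry = registry or Registry()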