[AWS] Abort if AZs & SubnetIds mismatch (#22001)

If a user simultaneously selects AZs to use & specifies Subnets not in those AZs, raise an error!
This commit is contained in:
Ian Rodney 2022-04-21 11:07:59 -07:00 committed by GitHub
parent 02b0d82cf8
commit 0c16bbd245
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 195 additions and 54 deletions

View file

@ -6,7 +6,7 @@ import itertools
import json import json
import os import os
import time import time
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Set, Tuple
import logging import logging
import boto3 import boto3
@ -434,23 +434,43 @@ def _key_assert_msg(node_type: str) -> str:
) )
def _configure_subnet(config): def _usable_subnet_ids(
ec2 = _resource("ec2", config) user_specified_subnets: Optional[List[Any]],
use_internal_ips = config["provider"].get("use_internal_ips", False) all_subnets: List[Any],
azs: Optional[str],
vpc_id_of_sg: Optional[str],
use_internal_ips: bool,
node_type_key: str,
) -> Tuple[List[str], str]:
"""Prunes subnets down to those that meet the following criteria.
# If head or worker security group is specified, filter down to subnets Subnets must be:
# belonging to the same VPC as the security group. * 'Available' according to AWS.
sg_ids = [] * Public, unless `use_internal_ips` is specified.
for node_type in config["available_node_types"].values(): * In one of the AZs, if AZs are provided.
node_config = node_type["node_config"] * In the given VPC, if a VPC is specified for Security Groups.
sg_ids.extend(node_config.get("SecurityGroupIds", []))
if sg_ids: Returns:
vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config) List[str]: Subnets that are usable.
else: str: VPC ID of the first subnet.
vpc_id_of_sg = None """
def _are_user_subnets_pruned(current_subnets: List[Any]) -> bool:
return user_specified_subnets is not None and len(current_subnets) != len(
user_specified_subnets
)
def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]:
current_subnet_ids = {s.subnet_id for s in current_subnets}
user_specified_subnet_ids = {s.subnet_id for s in user_specified_subnets}
return user_specified_subnet_ids - current_subnet_ids
try: try:
candidate_subnets = ec2.subnets.all() candidate_subnets = (
user_specified_subnets
if user_specified_subnets is not None
else all_subnets
)
if vpc_id_of_sg: if vpc_id_of_sg:
candidate_subnets = [ candidate_subnets = [
s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
@ -471,16 +491,21 @@ def _configure_subnet(config):
if not subnets: if not subnets:
cli_logger.abort( cli_logger.abort(
"No usable subnets found, try manually creating an instance in " f"No usable subnets found for node type {node_type_key}, try "
"your specified region to populate the list of subnets " "manually creating an instance in your specified region to "
"and trying this again.\n" "populate the list of subnets and trying this again.\n"
"Note that the subnet must map public IPs " "Note that the subnet must map public IPs "
"on instance launch unless you set `use_internal_ips: true` in " "on instance launch unless you set `use_internal_ips: true` in "
"the `provider` config." "the `provider` config."
) )
elif _are_user_subnets_pruned(subnets):
cli_logger.abort(
f"The specified subnets for node type {node_type_key} are not "
f"usable: {_get_pruned_subnets(subnets)}"
)
if "availability_zone" in config["provider"]: if azs is not None:
azs = config["provider"]["availability_zone"].split(",") azs = [az.strip() for az in azs.split(",")]
subnets = [ subnets = [
s s
for az in azs # Iterate over AZs first to maintain the ordering for az in azs # Iterate over AZs first to maintain the ordering
@ -489,11 +514,19 @@ def _configure_subnet(config):
] ]
if not subnets: if not subnets:
cli_logger.abort( cli_logger.abort(
"No usable subnets matching availability zone {} found.\n" f"No usable subnets matching availability zone {azs} found "
"Choose a different availability zone or try " f"for node type {node_type_key}.\nChoose a different "
"manually creating an instance in your specified region " "availability zone or try manually creating an instance in "
"to populate the list of subnets and trying this again.", "your specified region to populate the list of subnets and "
config["provider"]["availability_zone"], "trying this again."
)
elif _are_user_subnets_pruned(subnets):
cli_logger.abort(
f"MISMATCH between specified subnets and Availability Zones! "
"The following Availability Zones were specified in the "
f"`provider section`: {azs}.\n The following subnets for node "
f"type `{node_type_key}` have no matching availability zone: "
f"{list(_get_pruned_subnets(subnets))}."
) )
# Use subnets in only one VPC, so that _configure_security_groups only # Use subnets in only one VPC, so that _configure_security_groups only
@ -501,17 +534,79 @@ def _configure_subnet(config):
# to set up security groups in all of the user's VPCs and set up networking # to set up security groups in all of the user's VPCs and set up networking
# rules to allow traffic between these groups. # rules to allow traffic between these groups.
# See https://github.com/ray-project/ray/pull/14868. # See https://github.com/ray-project/ray/pull/14868.
subnet_ids = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id] first_subnet_vpc_id = subnets[0].vpc_id
subnets = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
if _are_user_subnets_pruned(subnets):
subnet_vpcs = {s.subnet_id: s.vpc_id for s in user_specified_subnets}
cli_logger.abort(
f"Subnets specified in more than one VPC for node type `{node_type_key}`! "
f"Please ensure that all subnets share the same VPC and retry your "
"request. Subnet VPCs: {}",
subnet_vpcs,
)
return subnets, first_subnet_vpc_id
def _configure_subnet(config):
    """Validate user-specified subnets and fill in default subnets elsewhere.

    Node types whose `node_config` already has `SubnetIds` are checked first
    with `_usable_subnet_ids`; their `SubnetIds` are left untouched. Node
    types without `SubnetIds` are handled afterwards and receive the subnet
    IDs that `_usable_subnet_ids` selects from all subnets of the EC2
    resource.

    Args:
        config: Cluster config. Reads `provider` and `available_node_types`;
            `node_config["SubnetIds"]` is written in place for node types
            that did not specify it.

    Returns:
        The same `config` object that was passed in.
    """
    ec2 = _resource("ec2", config)
    node_types = config["available_node_types"]

    # If head or worker security group is specified, filter down to subnets
    # belonging to the same VPC as the security group.
    sg_ids = [
        sg_id
        for node_type in node_types.values()
        for sg_id in node_type["node_config"].get("SecurityGroupIds", [])
    ]
    vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config) if sg_ids else None

    # map from node type key -> source of SubnetIds field
    subnet_src_info = {}
    _set_config_info(subnet_src=subnet_src_info)
    all_subnets = list(ec2.subnets.all())

    def _select_subnets(user_subnets, sg_vpc_id, node_type_key):
        # Single place where the provider-level settings are forwarded, so
        # both passes below call `_usable_subnet_ids` the same way.
        return _usable_subnet_ids(
            user_subnets,
            all_subnets,
            azs=config["provider"].get("availability_zone"),
            vpc_id_of_sg=sg_vpc_id,
            use_internal_ips=config["provider"].get("use_internal_ips", False),
            node_type_key=node_type_key,
        )

    # separate node types with and without user-specified subnets
    with_user_subnets = [
        (key, node_type)
        for key, node_type in node_types.items()
        if "SubnetIds" in node_type["node_config"]
    ]
    without_user_subnets = [
        (key, node_type)
        for key, node_type in node_types.items()
        if "SubnetIds" not in node_type["node_config"]
    ]

    last_user_vpc_id = None

    # First pass: node types with user-specified subnets. The call is made
    # for its validation (it aborts on unusable subnets) and for the VPC ID;
    # the returned subnet IDs are not written back.
    for key, node_type in with_user_subnets:
        requested_ids = tuple(node_type["node_config"]["SubnetIds"])
        user_subnets = _get_subnets_or_die(ec2, requested_ids)
        _, last_user_vpc_id = _select_subnets(user_subnets, vpc_id_of_sg, key)
        subnet_src_info[key] = "config"

    # lock-in a good VPC shared by the last set of user-specified subnets...
    if last_user_vpc_id and not vpc_id_of_sg:
        vpc_id_of_sg = last_user_vpc_id

    # Second pass: node types without user-specified subnets get defaults.
    for key, node_type in without_user_subnets:
        default_subnet_ids, _ = _select_subnets(None, vpc_id_of_sg, key)
        subnet_src_info[key] = "default"
        node_type["node_config"]["SubnetIds"] = default_subnet_ids

    return config
@ -653,17 +748,27 @@ def _get_or_create_vpc_security_groups(conf, node_types):
} }
def _get_vpc_id_or_die(ec2, subnet_id: str):
    """Return the VPC ID of the subnet identified by `subnet_id`.

    The subnet is resolved through `_get_subnets_or_die`; the result is then
    checked with `cli_logger.doassert` to contain exactly one subnet.

    Args:
        ec2: EC2 resource handed through to `_get_subnets_or_die`.
        subnet_id: ID of the subnet whose VPC is wanted.

    Returns:
        The `vpc_id` attribute of the single matching subnet.
    """
    matches = _get_subnets_or_die(ec2, (subnet_id,))
    found = len(matches)
    cli_logger.doassert(
        found == 1,
        f"Expected 1 subnet with ID `{subnet_id}` but found {found}",
    )
    return matches[0].vpc_id
@lru_cache()
def _get_subnets_or_die(ec2, subnet_ids: Tuple[str, ...]):
    """Look up subnets by ID, failing unless every requested ID is found.

    Args:
        ec2: EC2 resource whose `subnets.filter` is queried.
        subnet_ids: Subnet IDs to resolve. Must be a tuple (hashable), since
            results are memoized per argument set by `lru_cache`.

    Returns:
        List of the subnet objects yielded by `ec2.subnets.filter`.

    Raises:
        AssertionError: If the number of subnets found differs from the
            number of distinct IDs requested (unless `cli_logger.doassert`
            has already aborted).
    """
    subnets = list(
        ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}])
    )
    # Compare against the number of *distinct* IDs: a subnet ID listed twice
    # is expected to match a single subnet, and must not be reported as
    # missing. NOTE(review): assumes EC2 returns each matching subnet once.
    expected_count = len(set(subnet_ids))
    # TODO: better error message
    cli_logger.doassert(
        len(subnets) == expected_count, "Not all subnet IDs found: {}", subnet_ids
    )
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if len(subnets) != expected_count:
        raise AssertionError("Subnet ID not found: {}".format(subnet_ids))
    return subnets
def _get_security_group(config, vpc_id, group_name): def _get_security_group(config, vpc_id, group_name):

View file

@ -46,7 +46,7 @@ available_node_types:
node_config: node_config:
NetworkInterfaces: NetworkInterfaces:
- DeviceIndex: 0 # Primary network interface. - DeviceIndex: 0 # Primary network interface.
SubnetId: subnet-00000000 # Replace with your Subnet ID. SubnetId: subnet-0000000 # Replace with your Subnet ID.
# Head node network interfaces can optionally associate fixed private # Head node network interfaces can optionally associate fixed private
# addresses with the head node. # addresses with the head node.
PrivateIpAddress: 172.31.64.10 # Replace with an IP in your subnet. PrivateIpAddress: 172.31.64.10 # Replace with an IP in your subnet.

View file

@ -1,11 +1,12 @@
import copy import copy
from click.exceptions import ClickException
import pytest import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from ray.autoscaler._private.aws.config import ( from ray.autoscaler._private.aws.config import (
_configure_subnet, _configure_subnet,
_get_vpc_id_or_die, _get_subnets_or_die,
bootstrap_aws, bootstrap_aws,
log_to_cli, log_to_cli,
DEFAULT_AMI, DEFAULT_AMI,
@ -51,7 +52,7 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub) stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)
# describe the subnet in use while determining its vpc # describe the subnet in use while determining its vpc
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET) stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
# given no existing security groups within the VPC... # given no existing security groups within the VPC...
stubs.describe_no_security_groups(ec2_client_stub) stubs.describe_no_security_groups(ec2_client_stub)
# expect to create a security group on the VPC # expect to create a security group on the VPC
@ -80,7 +81,7 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
# given our mocks and an example config file as input... # given our mocks and an example config file as input...
# expect the config to be loaded, validated, and bootstrapped successfully # expect the config to be loaded, validated, and bootstrapped successfully
config = helpers.bootstrap_aws_example_config_file("example-full.yaml") config = helpers.bootstrap_aws_example_config_file("example-full.yaml")
_get_vpc_id_or_die.cache_clear() _get_subnets_or_die.cache_clear()
# We've filtered down to only one subnet id -- only one of the thousand # We've filtered down to only one subnet id -- only one of the thousand
# subnets generated by ec2.subnets.all() belongs to the right VPC. # subnets generated by ec2.subnets.all() belongs to the right VPC.
@ -90,15 +91,25 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
assert node_config["SecurityGroupIds"] == [DEFAULT_SG["GroupId"]] assert node_config["SecurityGroupIds"] == [DEFAULT_SG["GroupId"]]
def test_create_sg_different_vpc_same_rules(iam_client_stub, ec2_client_stub): @pytest.mark.parametrize(
"correct_az",
[True, False],
)
def test_create_sg_different_vpc_same_rules(
iam_client_stub, ec2_client_stub, correct_az: bool
):
# use default stubs to skip ahead to security group configuration # use default stubs to skip ahead to security group configuration
stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub) stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)
default_subnet = copy.deepcopy(DEFAULT_SUBNET)
if not correct_az:
default_subnet["AvailabilityZone"] = "us-west-2b"
# given head and worker nodes with custom subnets defined... # given head and worker nodes with custom subnets defined...
# expect to first describe the worker subnet ID
stubs.describe_subnets_echo(ec2_client_stub, AUX_SUBNET)
# expect to second describe the head subnet ID # expect to second describe the head subnet ID
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET) stubs.describe_subnets_echo(ec2_client_stub, [default_subnet])
# expect to first describe the worker subnet ID
stubs.describe_subnets_echo(ec2_client_stub, [AUX_SUBNET])
# given no existing security groups within the VPC... # given no existing security groups within the VPC...
stubs.describe_no_security_groups(ec2_client_stub) stubs.describe_no_security_groups(ec2_client_stub)
# expect to first create a security group on the worker node VPC # expect to first create a security group on the worker node VPC
@ -133,7 +144,19 @@ def test_create_sg_different_vpc_same_rules(iam_client_stub, ec2_client_stub):
# given our mocks and an example config file as input... # given our mocks and an example config file as input...
# expect the config to be loaded, validated, and bootstrapped successfully # expect the config to be loaded, validated, and bootstrapped successfully
config = helpers.bootstrap_aws_example_config_file("example-subnets.yaml") error = None
try:
config = helpers.bootstrap_aws_example_config_file("example-subnets.yaml")
except ClickException as e:
error = e
_get_subnets_or_die.cache_clear()
if not correct_az:
assert isinstance(error, ClickException), "Did not get a ClickException!"
iam_client_stub._queue.clear()
ec2_client_stub._queue.clear()
return
# expect the bootstrapped config to show different head and worker security # expect the bootstrapped config to show different head and worker security
# groups residing on different subnets # groups residing on different subnets
@ -158,7 +181,7 @@ def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub, ec2_clien
stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub) stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)
# expect to describe the head subnet ID # expect to describe the head subnet ID
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET) stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
# given no existing security groups within the VPC... # given no existing security groups within the VPC...
stubs.describe_no_security_groups(ec2_client_stub) stubs.describe_no_security_groups(ec2_client_stub)
# expect to create a security group on the head node VPC # expect to create a security group on the head node VPC
@ -181,7 +204,7 @@ def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub, ec2_clien
# expect the next read of a head security group property to reload it # expect the next read of a head security group property to reload it
stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES) stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)
_get_vpc_id_or_die.cache_clear() _get_subnets_or_die.cache_clear()
# given our mocks and an example config file as input... # given our mocks and an example config file as input...
# expect the config to be loaded, validated, and bootstrapped successfully # expect the config to be loaded, validated, and bootstrapped successfully
config = helpers.bootstrap_aws_example_config_file("example-security-group.yaml") config = helpers.bootstrap_aws_example_config_file("example-security-group.yaml")
@ -365,7 +388,7 @@ def test_create_sg_multinode(iam_client_stub, ec2_client_stub):
# test_create_sg_with_custom_inbound_rules_and_name. # test_create_sg_with_custom_inbound_rules_and_name.
# expect to describe the head subnet ID # expect to describe the head subnet ID
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET) stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
# given no existing security groups within the VPC... # given no existing security groups within the VPC...
stubs.describe_no_security_groups(ec2_client_stub) stubs.describe_no_security_groups(ec2_client_stub)
# expect to create a security group on the head node VPC # expect to create a security group on the head node VPC
@ -388,7 +411,7 @@ def test_create_sg_multinode(iam_client_stub, ec2_client_stub):
# expect the next read of a head security group property to reload it # expect the next read of a head security group property to reload it
stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES) stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)
_get_vpc_id_or_die.cache_clear() _get_subnets_or_die.cache_clear()
# given our mocks and the config as input... # given our mocks and the config as input...
# expect the config to be validated and bootstrapped successfully # expect the config to be validated and bootstrapped successfully
@ -525,6 +548,16 @@ def test_network_interfaces(
# use a default stub to skip subnet configuration # use a default stub to skip subnet configuration
stubs.configure_subnet_default(ec2_client_stub) stubs.configure_subnet_default(ec2_client_stub)
stubs.describe_subnets_echo(
ec2_client_stub,
[DEFAULT_SUBNET, {**DEFAULT_SUBNET, "SubnetId": "subnet-11111111"}],
)
stubs.describe_subnets_echo(
ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-22222222"}]
)
stubs.describe_subnets_echo(
ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-33333333"}]
)
# given our mocks and an example config file as input... # given our mocks and an example config file as input...
# expect the config to be loaded, validated, and bootstrapped successfully # expect the config to be loaded, validated, and bootstrapped successfully

View file

@ -1,3 +1,4 @@
from typing import Dict, List
import ray import ray
import copy import copy
@ -78,13 +79,15 @@ def skip_to_configure_sg(ec2_client_stub, iam_client_stub):
configure_subnet_default(ec2_client_stub) configure_subnet_default(ec2_client_stub)
def describe_subnets_echo(ec2_client_stub, subnets: List[Dict[str, str]]):
    """Queue a `describe_subnets` stub response that echoes `subnets` back.

    The stub expects a `subnet-id` filter listing the `SubnetId` of every
    given subnet, in order, and responds with those same subnets.

    Args:
        ec2_client_stub: Stub object exposing `add_response`.
        subnets: Subnet descriptions, each containing a `SubnetId` key. A
            single subnet dict is also accepted and treated as a one-element
            list, so callers written for the single-subnet form keep working.
    """
    if isinstance(subnets, dict):
        # Iterating a dict would walk its keys and fail on `s["SubnetId"]`.
        subnets = [subnets]
    subnet_ids = [s["SubnetId"] for s in subnets]
    ec2_client_stub.add_response(
        "describe_subnets",
        expected_params={"Filters": [{"Name": "subnet-id", "Values": subnet_ids}]},
        service_response={"Subnets": subnets},
    )