[AWS] Abort if AZs & SubnetIds mismatch (#22001)

If a user simultaneously selects AZs to use & specifies Subnets not in those AZs, raise an error!
This commit is contained in:
Ian Rodney 2022-04-21 11:07:59 -07:00 committed by GitHub
parent 02b0d82cf8
commit 0c16bbd245
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 195 additions and 54 deletions

View file

@ -6,7 +6,7 @@ import itertools
import json
import os
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set, Tuple
import logging
import boto3
@ -434,23 +434,43 @@ def _key_assert_msg(node_type: str) -> str:
)
def _configure_subnet(config):
ec2 = _resource("ec2", config)
use_internal_ips = config["provider"].get("use_internal_ips", False)
def _usable_subnet_ids(
user_specified_subnets: Optional[List[Any]],
all_subnets: List[Any],
azs: Optional[str],
vpc_id_of_sg: Optional[str],
use_internal_ips: bool,
node_type_key: str,
) -> Tuple[List[str], str]:
"""Prunes subnets down to those that meet the following criteria.
# If head or worker security group is specified, filter down to subnets
# belonging to the same VPC as the security group.
sg_ids = []
for node_type in config["available_node_types"].values():
node_config = node_type["node_config"]
sg_ids.extend(node_config.get("SecurityGroupIds", []))
if sg_ids:
vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
else:
vpc_id_of_sg = None
Subnets must be:
* 'Available' according to AWS.
* Public, unless `use_internal_ips` is specified.
* In one of the AZs, if AZs are provided.
* In the given VPC, if a VPC is specified for Security Groups.
Returns:
List[str]: Subnets that are usable.
str: VPC ID of the first subnet.
"""
def _are_user_subnets_pruned(current_subnets: List[Any]) -> bool:
return user_specified_subnets is not None and len(current_subnets) != len(
user_specified_subnets
)
def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]:
current_subnet_ids = {s.subnet_id for s in current_subnets}
user_specified_subnet_ids = {s.subnet_id for s in user_specified_subnets}
return user_specified_subnet_ids - current_subnet_ids
try:
candidate_subnets = ec2.subnets.all()
candidate_subnets = (
user_specified_subnets
if user_specified_subnets is not None
else all_subnets
)
if vpc_id_of_sg:
candidate_subnets = [
s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
@ -471,16 +491,21 @@ def _configure_subnet(config):
if not subnets:
cli_logger.abort(
"No usable subnets found, try manually creating an instance in "
"your specified region to populate the list of subnets "
"and trying this again.\n"
f"No usable subnets found for node type {node_type_key}, try "
"manually creating an instance in your specified region to "
"populate the list of subnets and trying this again.\n"
"Note that the subnet must map public IPs "
"on instance launch unless you set `use_internal_ips: true` in "
"the `provider` config."
)
elif _are_user_subnets_pruned(subnets):
cli_logger.abort(
f"The specified subnets for node type {node_type_key} are not "
f"usable: {_get_pruned_subnets(subnets)}"
)
if "availability_zone" in config["provider"]:
azs = config["provider"]["availability_zone"].split(",")
if azs is not None:
azs = [az.strip() for az in azs.split(",")]
subnets = [
s
for az in azs # Iterate over AZs first to maintain the ordering
@ -489,11 +514,19 @@ def _configure_subnet(config):
]
if not subnets:
cli_logger.abort(
"No usable subnets matching availability zone {} found.\n"
"Choose a different availability zone or try "
"manually creating an instance in your specified region "
"to populate the list of subnets and trying this again.",
config["provider"]["availability_zone"],
f"No usable subnets matching availability zone {azs} found "
f"for node type {node_type_key}.\nChoose a different "
"availability zone or try manually creating an instance in "
"your specified region to populate the list of subnets and "
"trying this again."
)
elif _are_user_subnets_pruned(subnets):
cli_logger.abort(
f"MISMATCH between specified subnets and Availability Zones! "
"The following Availability Zones were specified in the "
f"`provider section`: {azs}.\n The following subnets for node "
f"type `{node_type_key}` have no matching availability zone: "
f"{list(_get_pruned_subnets(subnets))}."
)
# Use subnets in only one VPC, so that _configure_security_groups only
@ -501,17 +534,79 @@ def _configure_subnet(config):
# to set up security groups in all of the user's VPCs and set up networking
# rules to allow traffic between these groups.
# See https://github.com/ray-project/ray/pull/14868.
subnet_ids = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
first_subnet_vpc_id = subnets[0].vpc_id
subnets = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
if _are_user_subnets_pruned(subnets):
subnet_vpcs = {s.subnet_id: s.vpc_id for s in user_specified_subnets}
cli_logger.abort(
f"Subnets specified in more than one VPC for node type `{node_type_key}`! "
f"Please ensure that all subnets share the same VPC and retry your "
"request. Subnet VPCs: {}",
subnet_vpcs,
)
return subnets, first_subnet_vpc_id
def _configure_subnet(config):
ec2 = _resource("ec2", config)
# If head or worker security group is specified, filter down to subnets
# belonging to the same VPC as the security group.
sg_ids = []
for node_type in config["available_node_types"].values():
node_config = node_type["node_config"]
sg_ids.extend(node_config.get("SecurityGroupIds", []))
if sg_ids:
vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
else:
vpc_id_of_sg = None
# map from node type key -> source of SubnetIds field
subnet_src_info = {}
_set_config_info(subnet_src=subnet_src_info)
all_subnets = list(ec2.subnets.all())
# separate node types with and without user-specified subnets
node_types_subnets = []
node_types_no_subnets = []
for key, node_type in config["available_node_types"].items():
node_config = node_type["node_config"]
if "SubnetIds" not in node_config:
subnet_src_info[key] = "default"
node_config["SubnetIds"] = subnet_ids
if "SubnetIds" in node_type["node_config"]:
node_types_subnets.append((key, node_type))
else:
subnet_src_info[key] = "config"
node_types_no_subnets.append((key, node_type))
vpc_id = None
# iterate over node types with user-specified subnets first...
for key, node_type in node_types_subnets:
node_config = node_type["node_config"]
user_subnets = _get_subnets_or_die(ec2, tuple(node_config["SubnetIds"]))
subnet_ids, vpc_id = _usable_subnet_ids(
user_subnets,
all_subnets,
azs=config["provider"].get("availability_zone"),
vpc_id_of_sg=vpc_id_of_sg,
use_internal_ips=config["provider"].get("use_internal_ips", False),
node_type_key=key,
)
subnet_src_info[key] = "config"
# lock-in a good VPC shared by the last set of user-specified subnets...
if vpc_id and not vpc_id_of_sg:
vpc_id_of_sg = vpc_id
# iterate over node types without user-specified subnets last...
for key, node_type in node_types_no_subnets:
node_config = node_type["node_config"]
subnet_ids, vpc_id = _usable_subnet_ids(
None,
all_subnets,
azs=config["provider"].get("availability_zone"),
vpc_id_of_sg=vpc_id_of_sg,
use_internal_ips=config["provider"].get("use_internal_ips", False),
node_type_key=key,
)
subnet_src_info[key] = "default"
node_config["SubnetIds"] = subnet_ids
return config
@ -653,17 +748,27 @@ def _get_or_create_vpc_security_groups(conf, node_types):
}
def _get_vpc_id_or_die(ec2, subnet_id: str):
subnets = _get_subnets_or_die(ec2, (subnet_id,))
cli_logger.doassert(
len(subnets) == 1,
f"Expected 1 subnet with ID `{subnet_id}` but found {len(subnets)}",
)
return subnets[0].vpc_id
@lru_cache()
def _get_vpc_id_or_die(ec2, subnet_id):
subnet = list(
ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": [subnet_id]}])
def _get_subnets_or_die(ec2, subnet_ids: Tuple[str]):
subnets = list(
ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}])
)
# TODO: better error message
cli_logger.doassert(len(subnet) == 1, "Subnet ID not found: {}", subnet_id)
assert len(subnet) == 1, "Subnet ID not found: {}".format(subnet_id)
subnet = subnet[0]
return subnet.vpc_id
cli_logger.doassert(
len(subnets) == len(subnet_ids), "Not all subnet IDs found: {}", subnet_ids
)
assert len(subnets) == len(subnet_ids), "Subnet ID not found: {}".format(subnet_ids)
return subnets
def _get_security_group(config, vpc_id, group_name):

View file

@ -46,7 +46,7 @@ available_node_types:
node_config:
NetworkInterfaces:
- DeviceIndex: 0 # Primary network interface.
SubnetId: subnet-00000000 # Replace with your Subnet ID.
SubnetId: subnet-0000000 # Replace with your Subnet ID.
# Head node network interfaces can optionally associate fixed private
# addresses with the head node.
PrivateIpAddress: 172.31.64.10 # Replace with an IP in your subnet.

View file

@ -1,11 +1,12 @@
import copy
from click.exceptions import ClickException
import pytest
from unittest.mock import Mock, patch
from ray.autoscaler._private.aws.config import (
_configure_subnet,
_get_vpc_id_or_die,
_get_subnets_or_die,
bootstrap_aws,
log_to_cli,
DEFAULT_AMI,
@ -51,7 +52,7 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)
# describe the subnet in use while determining its vpc
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET)
stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
# given no existing security groups within the VPC...
stubs.describe_no_security_groups(ec2_client_stub)
# expect to create a security group on the VPC
@ -80,7 +81,7 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
# given our mocks and an example config file as input...
# expect the config to be loaded, validated, and bootstrapped successfully
config = helpers.bootstrap_aws_example_config_file("example-full.yaml")
_get_vpc_id_or_die.cache_clear()
_get_subnets_or_die.cache_clear()
# We've filtered down to only one subnet id -- only one of the thousand
# subnets generated by ec2.subnets.all() belongs to the right VPC.
@ -90,15 +91,25 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
assert node_config["SecurityGroupIds"] == [DEFAULT_SG["GroupId"]]
def test_create_sg_different_vpc_same_rules(iam_client_stub, ec2_client_stub):
@pytest.mark.parametrize(
"correct_az",
[True, False],
)
def test_create_sg_different_vpc_same_rules(
iam_client_stub, ec2_client_stub, correct_az: bool
):
# use default stubs to skip ahead to security group configuration
stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)
default_subnet = copy.deepcopy(DEFAULT_SUBNET)
if not correct_az:
default_subnet["AvailabilityZone"] = "us-west-2b"
# given head and worker nodes with custom subnets defined...
# expect to first describe the worker subnet ID
stubs.describe_subnets_echo(ec2_client_stub, AUX_SUBNET)
# expect to second describe the head subnet ID
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET)
stubs.describe_subnets_echo(ec2_client_stub, [default_subnet])
# expect to first describe the worker subnet ID
stubs.describe_subnets_echo(ec2_client_stub, [AUX_SUBNET])
# given no existing security groups within the VPC...
stubs.describe_no_security_groups(ec2_client_stub)
# expect to first create a security group on the worker node VPC
@ -133,7 +144,19 @@ def test_create_sg_different_vpc_same_rules(iam_client_stub, ec2_client_stub):
# given our mocks and an example config file as input...
# expect the config to be loaded, validated, and bootstrapped successfully
config = helpers.bootstrap_aws_example_config_file("example-subnets.yaml")
error = None
try:
config = helpers.bootstrap_aws_example_config_file("example-subnets.yaml")
except ClickException as e:
error = e
_get_subnets_or_die.cache_clear()
if not correct_az:
assert isinstance(error, ClickException), "Did not get a ClickException!"
iam_client_stub._queue.clear()
ec2_client_stub._queue.clear()
return
# expect the bootstrapped config to show different head and worker security
# groups residing on different subnets
@ -158,7 +181,7 @@ def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub, ec2_clien
stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)
# expect to describe the head subnet ID
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET)
stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
# given no existing security groups within the VPC...
stubs.describe_no_security_groups(ec2_client_stub)
# expect to create a security group on the head node VPC
@ -181,7 +204,7 @@ def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub, ec2_clien
# expect the next read of a head security group property to reload it
stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)
_get_vpc_id_or_die.cache_clear()
_get_subnets_or_die.cache_clear()
# given our mocks and an example config file as input...
# expect the config to be loaded, validated, and bootstrapped successfully
config = helpers.bootstrap_aws_example_config_file("example-security-group.yaml")
@ -365,7 +388,7 @@ def test_create_sg_multinode(iam_client_stub, ec2_client_stub):
# test_create_sg_with_custom_inbound_rules_and_name.
# expect to describe the head subnet ID
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET)
stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
# given no existing security groups within the VPC...
stubs.describe_no_security_groups(ec2_client_stub)
# expect to create a security group on the head node VPC
@ -388,7 +411,7 @@ def test_create_sg_multinode(iam_client_stub, ec2_client_stub):
# expect the next read of a head security group property to reload it
stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)
_get_vpc_id_or_die.cache_clear()
_get_subnets_or_die.cache_clear()
# given our mocks and the config as input...
# expect the config to be validated and bootstrapped successfully
@ -525,6 +548,16 @@ def test_network_interfaces(
# use a default stub to skip subnet configuration
stubs.configure_subnet_default(ec2_client_stub)
stubs.describe_subnets_echo(
ec2_client_stub,
[DEFAULT_SUBNET, {**DEFAULT_SUBNET, "SubnetId": "subnet-11111111"}],
)
stubs.describe_subnets_echo(
ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-22222222"}]
)
stubs.describe_subnets_echo(
ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-33333333"}]
)
# given our mocks and an example config file as input...
# expect the config to be loaded, validated, and bootstrapped successfully

View file

@ -1,3 +1,4 @@
from typing import Dict, List
import ray
import copy
@ -78,13 +79,15 @@ def skip_to_configure_sg(ec2_client_stub, iam_client_stub):
configure_subnet_default(ec2_client_stub)
def describe_subnets_echo(ec2_client_stub, subnet):
def describe_subnets_echo(ec2_client_stub, subnets: List[Dict[str, str]]):
ec2_client_stub.add_response(
"describe_subnets",
expected_params={
"Filters": [{"Name": "subnet-id", "Values": [subnet["SubnetId"]]}]
"Filters": [
{"Name": "subnet-id", "Values": [s["SubnetId"] for s in subnets]}
]
},
service_response={"Subnets": [subnet]},
service_response={"Subnets": subnets},
)