mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[AWS] Abort if AZs & SubnetIds mismatch (#22001)
If a user specifies both Availability Zones (AZs) and SubnetIds, and any of the specified subnets is not in one of those AZs, abort with an error instead of continuing.
This commit is contained in:
parent
02b0d82cf8
commit
0c16bbd245
4 changed files with 195 additions and 54 deletions
|
@ -6,7 +6,7 @@ import itertools
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
import logging
|
||||
|
||||
import boto3
|
||||
|
@ -434,23 +434,43 @@ def _key_assert_msg(node_type: str) -> str:
|
|||
)
|
||||
|
||||
|
||||
def _configure_subnet(config):
|
||||
ec2 = _resource("ec2", config)
|
||||
use_internal_ips = config["provider"].get("use_internal_ips", False)
|
||||
def _usable_subnet_ids(
|
||||
user_specified_subnets: Optional[List[Any]],
|
||||
all_subnets: List[Any],
|
||||
azs: Optional[str],
|
||||
vpc_id_of_sg: Optional[str],
|
||||
use_internal_ips: bool,
|
||||
node_type_key: str,
|
||||
) -> Tuple[List[str], str]:
|
||||
"""Prunes subnets down to those that meet the following criteria.
|
||||
|
||||
# If head or worker security group is specified, filter down to subnets
|
||||
# belonging to the same VPC as the security group.
|
||||
sg_ids = []
|
||||
for node_type in config["available_node_types"].values():
|
||||
node_config = node_type["node_config"]
|
||||
sg_ids.extend(node_config.get("SecurityGroupIds", []))
|
||||
if sg_ids:
|
||||
vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
|
||||
else:
|
||||
vpc_id_of_sg = None
|
||||
Subnets must be:
|
||||
* 'Available' according to AWS.
|
||||
* Public, unless `use_internal_ips` is specified.
|
||||
* In one of the AZs, if AZs are provided.
|
||||
* In the given VPC, if a VPC is specified for Security Groups.
|
||||
|
||||
Returns:
|
||||
List[str]: Subnets that are usable.
|
||||
str: VPC ID of the first subnet.
|
||||
"""
|
||||
|
||||
def _are_user_subnets_pruned(current_subnets: List[Any]) -> bool:
    """Report whether filtering dropped any user-specified subnet.

    Compares only the lengths of ``current_subnets`` and the enclosing
    function's ``user_specified_subnets``. Always False when the user did
    not specify subnets (``user_specified_subnets`` is None).
    """
    if user_specified_subnets is None:
        return False
    return len(current_subnets) != len(user_specified_subnets)
|
||||
|
||||
def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]:
    """Return the IDs of user-specified subnets absent from ``current_subnets``.

    Both ``current_subnets`` and the enclosing function's
    ``user_specified_subnets`` are read through their ``subnet_id``
    attribute; the result is the set difference (user-specified minus
    current).
    """
    surviving_ids = {subnet.subnet_id for subnet in current_subnets}
    requested_ids = {subnet.subnet_id for subnet in user_specified_subnets}
    return requested_ids.difference(surviving_ids)
|
||||
|
||||
try:
|
||||
candidate_subnets = ec2.subnets.all()
|
||||
candidate_subnets = (
|
||||
user_specified_subnets
|
||||
if user_specified_subnets is not None
|
||||
else all_subnets
|
||||
)
|
||||
if vpc_id_of_sg:
|
||||
candidate_subnets = [
|
||||
s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
|
||||
|
@ -471,16 +491,21 @@ def _configure_subnet(config):
|
|||
|
||||
if not subnets:
|
||||
cli_logger.abort(
|
||||
"No usable subnets found, try manually creating an instance in "
|
||||
"your specified region to populate the list of subnets "
|
||||
"and trying this again.\n"
|
||||
f"No usable subnets found for node type {node_type_key}, try "
|
||||
"manually creating an instance in your specified region to "
|
||||
"populate the list of subnets and trying this again.\n"
|
||||
"Note that the subnet must map public IPs "
|
||||
"on instance launch unless you set `use_internal_ips: true` in "
|
||||
"the `provider` config."
|
||||
)
|
||||
elif _are_user_subnets_pruned(subnets):
|
||||
cli_logger.abort(
|
||||
f"The specified subnets for node type {node_type_key} are not "
|
||||
f"usable: {_get_pruned_subnets(subnets)}"
|
||||
)
|
||||
|
||||
if "availability_zone" in config["provider"]:
|
||||
azs = config["provider"]["availability_zone"].split(",")
|
||||
if azs is not None:
|
||||
azs = [az.strip() for az in azs.split(",")]
|
||||
subnets = [
|
||||
s
|
||||
for az in azs # Iterate over AZs first to maintain the ordering
|
||||
|
@ -489,11 +514,19 @@ def _configure_subnet(config):
|
|||
]
|
||||
if not subnets:
|
||||
cli_logger.abort(
|
||||
"No usable subnets matching availability zone {} found.\n"
|
||||
"Choose a different availability zone or try "
|
||||
"manually creating an instance in your specified region "
|
||||
"to populate the list of subnets and trying this again.",
|
||||
config["provider"]["availability_zone"],
|
||||
f"No usable subnets matching availability zone {azs} found "
|
||||
f"for node type {node_type_key}.\nChoose a different "
|
||||
"availability zone or try manually creating an instance in "
|
||||
"your specified region to populate the list of subnets and "
|
||||
"trying this again."
|
||||
)
|
||||
elif _are_user_subnets_pruned(subnets):
|
||||
cli_logger.abort(
|
||||
f"MISMATCH between specified subnets and Availability Zones! "
|
||||
"The following Availability Zones were specified in the "
|
||||
f"`provider section`: {azs}.\n The following subnets for node "
|
||||
f"type `{node_type_key}` have no matching availability zone: "
|
||||
f"{list(_get_pruned_subnets(subnets))}."
|
||||
)
|
||||
|
||||
# Use subnets in only one VPC, so that _configure_security_groups only
|
||||
|
@ -501,17 +534,79 @@ def _configure_subnet(config):
|
|||
# to set up security groups in all of the user's VPCs and set up networking
|
||||
# rules to allow traffic between these groups.
|
||||
# See https://github.com/ray-project/ray/pull/14868.
|
||||
subnet_ids = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
|
||||
first_subnet_vpc_id = subnets[0].vpc_id
|
||||
subnets = [s.subnet_id for s in subnets if s.vpc_id == subnets[0].vpc_id]
|
||||
if _are_user_subnets_pruned(subnets):
|
||||
subnet_vpcs = {s.subnet_id: s.vpc_id for s in user_specified_subnets}
|
||||
cli_logger.abort(
|
||||
f"Subnets specified in more than one VPC for node type `{node_type_key}`! "
|
||||
f"Please ensure that all subnets share the same VPC and retry your "
|
||||
"request. Subnet VPCs: {}",
|
||||
subnet_vpcs,
|
||||
)
|
||||
return subnets, first_subnet_vpc_id
|
||||
|
||||
|
||||
def _configure_subnet(config):
|
||||
ec2 = _resource("ec2", config)
|
||||
|
||||
# If head or worker security group is specified, filter down to subnets
|
||||
# belonging to the same VPC as the security group.
|
||||
sg_ids = []
|
||||
for node_type in config["available_node_types"].values():
|
||||
node_config = node_type["node_config"]
|
||||
sg_ids.extend(node_config.get("SecurityGroupIds", []))
|
||||
if sg_ids:
|
||||
vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
|
||||
else:
|
||||
vpc_id_of_sg = None
|
||||
|
||||
# map from node type key -> source of SubnetIds field
|
||||
subnet_src_info = {}
|
||||
_set_config_info(subnet_src=subnet_src_info)
|
||||
all_subnets = list(ec2.subnets.all())
|
||||
# separate node types with and without user-specified subnets
|
||||
node_types_subnets = []
|
||||
node_types_no_subnets = []
|
||||
for key, node_type in config["available_node_types"].items():
|
||||
node_config = node_type["node_config"]
|
||||
if "SubnetIds" not in node_config:
|
||||
subnet_src_info[key] = "default"
|
||||
node_config["SubnetIds"] = subnet_ids
|
||||
if "SubnetIds" in node_type["node_config"]:
|
||||
node_types_subnets.append((key, node_type))
|
||||
else:
|
||||
subnet_src_info[key] = "config"
|
||||
node_types_no_subnets.append((key, node_type))
|
||||
|
||||
vpc_id = None
|
||||
|
||||
# iterate over node types with user-specified subnets first...
|
||||
for key, node_type in node_types_subnets:
|
||||
node_config = node_type["node_config"]
|
||||
user_subnets = _get_subnets_or_die(ec2, tuple(node_config["SubnetIds"]))
|
||||
subnet_ids, vpc_id = _usable_subnet_ids(
|
||||
user_subnets,
|
||||
all_subnets,
|
||||
azs=config["provider"].get("availability_zone"),
|
||||
vpc_id_of_sg=vpc_id_of_sg,
|
||||
use_internal_ips=config["provider"].get("use_internal_ips", False),
|
||||
node_type_key=key,
|
||||
)
|
||||
subnet_src_info[key] = "config"
|
||||
|
||||
# lock-in a good VPC shared by the last set of user-specified subnets...
|
||||
if vpc_id and not vpc_id_of_sg:
|
||||
vpc_id_of_sg = vpc_id
|
||||
|
||||
# iterate over node types without user-specified subnets last...
|
||||
for key, node_type in node_types_no_subnets:
|
||||
node_config = node_type["node_config"]
|
||||
subnet_ids, vpc_id = _usable_subnet_ids(
|
||||
None,
|
||||
all_subnets,
|
||||
azs=config["provider"].get("availability_zone"),
|
||||
vpc_id_of_sg=vpc_id_of_sg,
|
||||
use_internal_ips=config["provider"].get("use_internal_ips", False),
|
||||
node_type_key=key,
|
||||
)
|
||||
subnet_src_info[key] = "default"
|
||||
node_config["SubnetIds"] = subnet_ids
|
||||
|
||||
return config
|
||||
|
||||
|
@ -653,17 +748,27 @@ def _get_or_create_vpc_security_groups(conf, node_types):
|
|||
}
|
||||
|
||||
|
||||
def _get_vpc_id_or_die(ec2, subnet_id: str):
    """Return the ``vpc_id`` of the subnet identified by ``subnet_id``.

    The lookup is delegated to ``_get_subnets_or_die`` with a one-element
    tuple, and ``cli_logger.doassert`` is used to check that exactly one
    subnet came back before its ``vpc_id`` is read.

    Args:
        ec2: EC2 resource handle forwarded to ``_get_subnets_or_die``.
        subnet_id: ID of the subnet whose VPC is wanted.
    """
    subnets = _get_subnets_or_die(ec2, (subnet_id,))
    exactly_one_match = len(subnets) == 1
    cli_logger.doassert(
        exactly_one_match,
        f"Expected 1 subnet with ID `{subnet_id}` but found {len(subnets)}",
    )
    return subnets[0].vpc_id
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _get_vpc_id_or_die(ec2, subnet_id):
|
||||
subnet = list(
|
||||
ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": [subnet_id]}])
|
||||
def _get_subnets_or_die(ec2, subnet_ids: Tuple[str]):
|
||||
subnets = list(
|
||||
ec2.subnets.filter(Filters=[{"Name": "subnet-id", "Values": list(subnet_ids)}])
|
||||
)
|
||||
|
||||
# TODO: better error message
|
||||
cli_logger.doassert(len(subnet) == 1, "Subnet ID not found: {}", subnet_id)
|
||||
assert len(subnet) == 1, "Subnet ID not found: {}".format(subnet_id)
|
||||
subnet = subnet[0]
|
||||
return subnet.vpc_id
|
||||
cli_logger.doassert(
|
||||
len(subnets) == len(subnet_ids), "Not all subnet IDs found: {}", subnet_ids
|
||||
)
|
||||
assert len(subnets) == len(subnet_ids), "Subnet ID not found: {}".format(subnet_ids)
|
||||
return subnets
|
||||
|
||||
|
||||
def _get_security_group(config, vpc_id, group_name):
|
||||
|
|
|
@ -46,7 +46,7 @@ available_node_types:
|
|||
node_config:
|
||||
NetworkInterfaces:
|
||||
- DeviceIndex: 0 # Primary network interface.
|
||||
SubnetId: subnet-00000000 # Replace with your Subnet ID.
|
||||
SubnetId: subnet-0000000 # Replace with your Subnet ID.
|
||||
# Head node network interfaces can optionally associate fixed private
|
||||
# addresses with the head node.
|
||||
PrivateIpAddress: 172.31.64.10 # Replace with an IP in your subnet.
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import copy
|
||||
|
||||
from click.exceptions import ClickException
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from ray.autoscaler._private.aws.config import (
|
||||
_configure_subnet,
|
||||
_get_vpc_id_or_die,
|
||||
_get_subnets_or_die,
|
||||
bootstrap_aws,
|
||||
log_to_cli,
|
||||
DEFAULT_AMI,
|
||||
|
@ -51,7 +52,7 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
|
|||
stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)
|
||||
|
||||
# describe the subnet in use while determining its vpc
|
||||
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET)
|
||||
stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
|
||||
# given no existing security groups within the VPC...
|
||||
stubs.describe_no_security_groups(ec2_client_stub)
|
||||
# expect to create a security group on the VPC
|
||||
|
@ -80,7 +81,7 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
|
|||
# given our mocks and an example config file as input...
|
||||
# expect the config to be loaded, validated, and bootstrapped successfully
|
||||
config = helpers.bootstrap_aws_example_config_file("example-full.yaml")
|
||||
_get_vpc_id_or_die.cache_clear()
|
||||
_get_subnets_or_die.cache_clear()
|
||||
|
||||
# We've filtered down to only one subnet id -- only one of the thousand
|
||||
# subnets generated by ec2.subnets.all() belongs to the right VPC.
|
||||
|
@ -90,15 +91,25 @@ def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
|
|||
assert node_config["SecurityGroupIds"] == [DEFAULT_SG["GroupId"]]
|
||||
|
||||
|
||||
def test_create_sg_different_vpc_same_rules(iam_client_stub, ec2_client_stub):
|
||||
@pytest.mark.parametrize(
|
||||
"correct_az",
|
||||
[True, False],
|
||||
)
|
||||
def test_create_sg_different_vpc_same_rules(
|
||||
iam_client_stub, ec2_client_stub, correct_az: bool
|
||||
):
|
||||
# use default stubs to skip ahead to security group configuration
|
||||
stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)
|
||||
|
||||
default_subnet = copy.deepcopy(DEFAULT_SUBNET)
|
||||
if not correct_az:
|
||||
default_subnet["AvailabilityZone"] = "us-west-2b"
|
||||
|
||||
# given head and worker nodes with custom subnets defined...
|
||||
# expect to first describe the worker subnet ID
|
||||
stubs.describe_subnets_echo(ec2_client_stub, AUX_SUBNET)
|
||||
# expect to second describe the head subnet ID
|
||||
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET)
|
||||
stubs.describe_subnets_echo(ec2_client_stub, [default_subnet])
|
||||
# expect to first describe the worker subnet ID
|
||||
stubs.describe_subnets_echo(ec2_client_stub, [AUX_SUBNET])
|
||||
# given no existing security groups within the VPC...
|
||||
stubs.describe_no_security_groups(ec2_client_stub)
|
||||
# expect to first create a security group on the worker node VPC
|
||||
|
@ -133,7 +144,19 @@ def test_create_sg_different_vpc_same_rules(iam_client_stub, ec2_client_stub):
|
|||
|
||||
# given our mocks and an example config file as input...
|
||||
# expect the config to be loaded, validated, and bootstrapped successfully
|
||||
config = helpers.bootstrap_aws_example_config_file("example-subnets.yaml")
|
||||
error = None
|
||||
try:
|
||||
config = helpers.bootstrap_aws_example_config_file("example-subnets.yaml")
|
||||
except ClickException as e:
|
||||
error = e
|
||||
|
||||
_get_subnets_or_die.cache_clear()
|
||||
|
||||
if not correct_az:
|
||||
assert isinstance(error, ClickException), "Did not get a ClickException!"
|
||||
iam_client_stub._queue.clear()
|
||||
ec2_client_stub._queue.clear()
|
||||
return
|
||||
|
||||
# expect the bootstrapped config to show different head and worker security
|
||||
# groups residing on different subnets
|
||||
|
@ -158,7 +181,7 @@ def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub, ec2_clien
|
|||
stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)
|
||||
|
||||
# expect to describe the head subnet ID
|
||||
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET)
|
||||
stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
|
||||
# given no existing security groups within the VPC...
|
||||
stubs.describe_no_security_groups(ec2_client_stub)
|
||||
# expect to create a security group on the head node VPC
|
||||
|
@ -181,7 +204,7 @@ def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub, ec2_clien
|
|||
# expect the next read of a head security group property to reload it
|
||||
stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)
|
||||
|
||||
_get_vpc_id_or_die.cache_clear()
|
||||
_get_subnets_or_die.cache_clear()
|
||||
# given our mocks and an example config file as input...
|
||||
# expect the config to be loaded, validated, and bootstrapped successfully
|
||||
config = helpers.bootstrap_aws_example_config_file("example-security-group.yaml")
|
||||
|
@ -365,7 +388,7 @@ def test_create_sg_multinode(iam_client_stub, ec2_client_stub):
|
|||
# test_create_sg_with_custom_inbound_rules_and_name.
|
||||
|
||||
# expect to describe the head subnet ID
|
||||
stubs.describe_subnets_echo(ec2_client_stub, DEFAULT_SUBNET)
|
||||
stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
|
||||
# given no existing security groups within the VPC...
|
||||
stubs.describe_no_security_groups(ec2_client_stub)
|
||||
# expect to create a security group on the head node VPC
|
||||
|
@ -388,7 +411,7 @@ def test_create_sg_multinode(iam_client_stub, ec2_client_stub):
|
|||
# expect the next read of a head security group property to reload it
|
||||
stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)
|
||||
|
||||
_get_vpc_id_or_die.cache_clear()
|
||||
_get_subnets_or_die.cache_clear()
|
||||
|
||||
# given our mocks and the config as input...
|
||||
# expect the config to be validated and bootstrapped successfully
|
||||
|
@ -525,6 +548,16 @@ def test_network_interfaces(
|
|||
|
||||
# use a default stub to skip subnet configuration
|
||||
stubs.configure_subnet_default(ec2_client_stub)
|
||||
stubs.describe_subnets_echo(
|
||||
ec2_client_stub,
|
||||
[DEFAULT_SUBNET, {**DEFAULT_SUBNET, "SubnetId": "subnet-11111111"}],
|
||||
)
|
||||
stubs.describe_subnets_echo(
|
||||
ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-22222222"}]
|
||||
)
|
||||
stubs.describe_subnets_echo(
|
||||
ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-33333333"}]
|
||||
)
|
||||
|
||||
# given our mocks and an example config file as input...
|
||||
# expect the config to be loaded, validated, and bootstrapped successfully
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Dict, List
|
||||
import ray
|
||||
import copy
|
||||
|
||||
|
@ -78,13 +79,15 @@ def skip_to_configure_sg(ec2_client_stub, iam_client_stub):
|
|||
configure_subnet_default(ec2_client_stub)
|
||||
|
||||
|
||||
def describe_subnets_echo(ec2_client_stub, subnet):
|
||||
def describe_subnets_echo(ec2_client_stub, subnets: List[Dict[str, str]]):
|
||||
ec2_client_stub.add_response(
|
||||
"describe_subnets",
|
||||
expected_params={
|
||||
"Filters": [{"Name": "subnet-id", "Values": [subnet["SubnetId"]]}]
|
||||
"Filters": [
|
||||
{"Name": "subnet-id", "Values": [s["SubnetId"] for s in subnets]}
|
||||
]
|
||||
},
|
||||
service_response={"Subnets": [subnet]},
|
||||
service_response={"Subnets": subnets},
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue