mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[autoscaler][AWS] Make sure subnets belong to same VPC as user-specified security groups (#13558)
* initial commit * Filter subnets by security groups' VPCs * fix stubs * wip * Fix inbound rule logic. Tests WIP. * wip * unit test * example yaml * Unit test tests for bug being fixed * Update python/ray/tests/aws/utils/constants.py Co-authored-by: Thomas Desrosiers <681004+thomasdesr@users.noreply.github.com> Co-authored-by: Thomas Desrosiers <681004+thomasdesr@users.noreply.github.com>
This commit is contained in:
parent
28cf5f91e3
commit
40234ad631
5 changed files with 137 additions and 4 deletions
|
@ -5,6 +5,7 @@ import itertools
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
import logging
|
||||
|
||||
import boto3
|
||||
|
@ -357,9 +358,23 @@ def _configure_subnet(config):
|
|||
ec2 = _resource("ec2", config)
|
||||
use_internal_ips = config["provider"].get("use_internal_ips", False)
|
||||
|
||||
# If head or worker security group is specified, filter down to subnets
|
||||
# belonging to the same VPC as the security group.
|
||||
sg_ids = (config["head_node"].get("SecurityGroupIds", []) +
|
||||
config["worker_nodes"].get("SecurityGroupIds", []))
|
||||
if sg_ids:
|
||||
vpc_id_of_sg = _get_vpc_id_of_sg(sg_ids, config)
|
||||
else:
|
||||
vpc_id_of_sg = None
|
||||
|
||||
try:
|
||||
candidate_subnets = ec2.subnets.all()
|
||||
if vpc_id_of_sg:
|
||||
candidate_subnets = [
|
||||
s for s in candidate_subnets if s.vpc_id == vpc_id_of_sg
|
||||
]
|
||||
subnets = sorted(
|
||||
(s for s in ec2.subnets.all() if s.state == "available" and (
|
||||
(s for s in candidate_subnets if s.state == "available" and (
|
||||
use_internal_ips or s.map_public_ip_on_launch)),
|
||||
reverse=True, # sort from Z-A
|
||||
key=lambda subnet: subnet.availability_zone)
|
||||
|
@ -414,6 +429,34 @@ def _configure_subnet(config):
|
|||
return config
|
||||
|
||||
|
||||
def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str:
    """Returns the VPC id of the security groups with the provided security
    group ids.

    Errors if the provided security groups belong to multiple VPCs.
    Errors if no security group with any of the provided ids is identified.

    Args:
        sg_ids: Security group ids to look up. Duplicates are collapsed
            before the lookup is made.
        config: Cluster config, passed to ``_resource`` to obtain the EC2
            resource.

    Returns:
        The single VPC id shared by every security group that was found.
    """
    # Deduplicate AND sort: the iteration order of a set of strings changes
    # from one interpreter run to the next, which would make the filter sent
    # to EC2 (and any stubbed expectation of it) non-deterministic.
    unique_sg_ids = sorted(set(sg_ids))

    ec2 = _resource("ec2", config)
    group_id_filter = [{"Name": "group-id", "Values": unique_sg_ids}]
    matching_groups = ec2.security_groups.filter(Filters=group_id_filter)
    vpc_ids = list({sg.vpc_id for sg in matching_groups})

    # Each condition is checked twice: once through cli_logger and once with
    # a plain assert.
    # NOTE(review): presumably doassert reports the message through the CLI
    # logger and the bare assert is a fallback -- confirm against cli_logger.
    multiple_vpc_msg = ("All security groups specified in the cluster config "
                        "should belong to the same VPC.")
    cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg)
    assert len(vpc_ids) <= 1, multiple_vpc_msg

    no_sg_msg = ("Failed to detect a security group with id equal to any of "
                 "the configured SecurityGroupIds.")
    cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg)
    assert len(vpc_ids) > 0, no_sg_msg

    return vpc_ids[0]
|
||||
|
||||
|
||||
def _configure_security_group(config):
|
||||
_set_config_info(
|
||||
head_security_group_src="config", workers_security_group_src="config")
|
||||
|
@ -566,6 +609,13 @@ def _create_security_group(config, vpc_id, group_name):
|
|||
|
||||
def _upsert_security_group_rules(conf, security_groups):
|
||||
sgids = {sg.id for sg in security_groups.values()}
|
||||
|
||||
# Update sgids to include user-specified security groups.
|
||||
# This is necessary if the user specifies the head node type's security
|
||||
# groups but not the worker's, or vice-versa.
|
||||
for node_type in NODE_KIND_CONFIG_KEYS.values():
|
||||
sgids.update(conf[node_type].get("SecurityGroupIds", []))
|
||||
|
||||
# sort security group items for deterministic inbound rule config order
|
||||
# (mainly supports more precise stub-based boto3 unit testing)
|
||||
for node_type, sg in sorted(security_groups.items()):
|
||||
|
@ -583,7 +633,7 @@ def _update_inbound_rules(target_security_group, sgids, config):
|
|||
|
||||
|
||||
def _create_default_inbound_rules(sgids, extended_rules=[]):
|
||||
intracluster_rules = _create_default_instracluster_inbound_rules(sgids)
|
||||
intracluster_rules = _create_default_intracluster_inbound_rules(sgids)
|
||||
ssh_rules = _create_default_ssh_inbound_rules()
|
||||
merged_rules = itertools.chain(
|
||||
intracluster_rules,
|
||||
|
@ -593,7 +643,7 @@ def _create_default_inbound_rules(sgids, extended_rules=[]):
|
|||
return list(merged_rules)
|
||||
|
||||
|
||||
def _create_default_instracluster_inbound_rules(intracluster_sgids):
|
||||
def _create_default_intracluster_inbound_rules(intracluster_sgids):
|
||||
return [{
|
||||
"FromPort": -1,
|
||||
"ToPort": -1,
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
cluster_name: sg
|
||||
|
||||
max_workers: 1
|
||||
|
||||
provider:
|
||||
type: aws
|
||||
region: us-west-2
|
||||
availability_zone: us-west-2a
|
||||
|
||||
auth:
|
||||
ssh_user: ubuntu
|
||||
|
||||
# If required, head and worker nodes can exist on subnets in different VPCs and
|
||||
# communicate via VPC peering.
|
||||
|
||||
# VPC peering overview: https://docs.aws.amazon.com/vpc/latest/userguide/vpc-peering.html.
|
||||
# Setup VPC peering: https://docs.aws.amazon.com/vpc/latest/peering/create-vpc-peering-connection.html.
|
||||
# Configure VPC peering route tables: https://docs.aws.amazon.com/vpc/latest/peering/vpc-peering-routing.html.
|
||||
|
||||
# To enable external SSH connectivity, you should also ensure that your VPC
|
||||
# is configured to assign public IPv4 addresses to every EC2 instance
|
||||
# assigned to it.
|
||||
head_node:
|
||||
SecurityGroupIds:
|
||||
- sg-1234abcd # Replace with an actual security group id.
|
||||
|
||||
worker_nodes:
|
||||
SecurityGroupIds:
|
||||
- sg-1234abcd # Replace with an actual security group id.
|
||||
|
||||
|
|
@ -113,6 +113,26 @@ def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub,
|
|||
ec2_client_stub.assert_no_pending_responses()
|
||||
|
||||
|
||||
def test_subnet_given_head_and_worker_sg(iam_client_stub, ec2_client_stub):
    """Bootstraps the head-and-worker security group example config and
    checks that both node kinds get only DEFAULT_SUBNET's id."""
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # Queue the security group lookup first, then a subnet listing that
    # spans a thousand subnets in different VPCs.
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    config = helpers.bootstrap_aws_example_config_file(
        "example-head-and-worker-security-group.yaml")

    # Only the one subnet in the right VPC should have been filled in,
    # for the head node first and then for the workers.
    expected_subnet_ids = [DEFAULT_SUBNET["SubnetId"]]
    for node_kind in ("head_node", "worker_nodes"):
        assert config[node_kind]["SubnetIds"] == expected_subnet_ids

    # Every queued IAM and EC2 stub response must have been consumed.
    for client_stub in (iam_client_stub, ec2_client_stub):
        client_stub.assert_no_pending_responses()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Run this module's tests verbosely and exit with pytest's status code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|
||||
|
|
|
@ -50,6 +50,19 @@ DEFAULT_SUBNET = {
|
|||
"VpcId": "vpc-0000000",
|
||||
}
|
||||
|
||||
|
||||
def subnet_in_vpc(vpc_num):
    """Builds a shallow copy of DEFAULT_SUBNET that sits in a numbered VPC.

    The copy's "VpcId" is "vpc-" followed by vpc_num zero-padded to seven
    digits; every other entry is taken from DEFAULT_SUBNET unchanged.
    """
    return {**DEFAULT_SUBNET, "VpcId": f"vpc-{vpc_num:07d}"}
|
||||
|
||||
|
||||
# 999 generated subnets (VPC numbers 1 through 999) followed by
# DEFAULT_SUBNET itself, for a thousand entries in total.
A_THOUSAND_SUBNETS_IN_DIFFERENT_VPCS = [
    *(subnet_in_vpc(vpc_num) for vpc_num in range(1, 1000)),
    DEFAULT_SUBNET,
]
|
||||
|
||||
# Secondary EC2 subnet to expose to tests as required.
|
||||
AUX_SUBNET = {
|
||||
"AvailabilityZone": "us-west-2a",
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import ray
|
||||
from ray.tests.aws.utils.mocks import mock_path_exists_key_pair
|
||||
from ray.tests.aws.utils.constants import DEFAULT_INSTANCE_PROFILE, \
|
||||
DEFAULT_KEY_PAIR, DEFAULT_SUBNET
|
||||
DEFAULT_KEY_PAIR, DEFAULT_SUBNET, A_THOUSAND_SUBNETS_IN_DIFFERENT_VPCS
|
||||
|
||||
from unittest import mock
|
||||
|
||||
|
@ -41,6 +41,13 @@ def configure_subnet_default(ec2_client_stub):
|
|||
service_response={"Subnets": [DEFAULT_SUBNET]})
|
||||
|
||||
|
||||
def describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub):
    """Queues one describe_subnets response on the stub, expecting no
    request parameters and answering with
    A_THOUSAND_SUBNETS_IN_DIFFERENT_VPCS."""
    subnets_response = {"Subnets": A_THOUSAND_SUBNETS_IN_DIFFERENT_VPCS}
    ec2_client_stub.add_response(
        "describe_subnets",
        expected_params={},
        service_response=subnets_response)
|
||||
|
||||
|
||||
def skip_to_configure_sg(ec2_client_stub, iam_client_stub):
|
||||
configure_iam_role_default(iam_client_stub)
|
||||
configure_key_pair_default(ec2_client_stub)
|
||||
|
@ -66,6 +73,18 @@ def describe_no_security_groups(ec2_client_stub):
|
|||
service_response={})
|
||||
|
||||
|
||||
def describe_a_security_group(ec2_client_stub, security_group):
    """Queues one describe_security_groups response on the stub.

    The stubbed request must filter on "group-id" with exactly the given
    security group's "GroupId"; the response contains only that group.
    """
    group_id_filter = {
        "Name": "group-id",
        "Values": [security_group["GroupId"]],
    }
    ec2_client_stub.add_response(
        "describe_security_groups",
        expected_params={"Filters": [group_id_filter]},
        service_response={"SecurityGroups": [security_group]})
|
||||
|
||||
|
||||
def create_sg_echo(ec2_client_stub, security_group):
|
||||
ec2_client_stub.add_response(
|
||||
"create_security_group",
|
||||
|
|
Loading…
Add table
Reference in a new issue