def _get_vpc_id_of_sg(sg_ids: List[str], config: Dict[str, Any]) -> str:
    """Returns the id of the VPC shared by the given security groups.

    Args:
        sg_ids: Security group ids to look up. Duplicate ids are ignored.
        config: Cluster config; passed to `_resource` to obtain the EC2
            resource used for the lookup.

    Returns:
        The id of the single VPC that every matched security group
        belongs to.

    Raises:
        AssertionError: If the matched security groups belong to multiple
            VPCs, or if no security group with any of the provided ids is
            found.
    """
    # Deduplicate AND sort: set iteration order is arbitrary, so sorting
    # keeps the filter sent to EC2 deterministic (this mainly supports
    # stub-based boto3 unit tests that match exact request parameters).
    sg_ids = sorted(set(sg_ids))

    ec2 = _resource("ec2", config)
    filters = [{"Name": "group-id", "Values": sg_ids}]
    security_groups = ec2.security_groups.filter(Filters=filters)
    # Collapse to the distinct VPCs the matched groups live in.
    vpc_ids = list({sg.vpc_id for sg in security_groups})

    # NOTE(review): each check is reported through cli_logger.doassert and
    # then repeated as a plain assert -- presumably because doassert does
    # not always raise; confirm against cli_logger before simplifying.
    multiple_vpc_msg = "All security groups specified in the cluster config "\
        "should belong to the same VPC."
    cli_logger.doassert(len(vpc_ids) <= 1, multiple_vpc_msg)
    assert len(vpc_ids) <= 1, multiple_vpc_msg

    no_sg_msg = "Failed to detect a security group with id equal to any of "\
        "the configured SecurityGroupIds."
    cli_logger.doassert(len(vpc_ids) > 0, no_sg_msg)
    assert len(vpc_ids) > 0, no_sg_msg

    return vpc_ids[0]
+ for node_type in NODE_KIND_CONFIG_KEYS.values(): + sgids.update(conf[node_type].get("SecurityGroupIds", [])) + # sort security group items for deterministic inbound rule config order # (mainly supports more precise stub-based boto3 unit testing) for node_type, sg in sorted(security_groups.items()): @@ -583,7 +633,7 @@ def _update_inbound_rules(target_security_group, sgids, config): def _create_default_inbound_rules(sgids, extended_rules=[]): - intracluster_rules = _create_default_instracluster_inbound_rules(sgids) + intracluster_rules = _create_default_intracluster_inbound_rules(sgids) ssh_rules = _create_default_ssh_inbound_rules() merged_rules = itertools.chain( intracluster_rules, @@ -593,7 +643,7 @@ def _create_default_inbound_rules(sgids, extended_rules=[]): return list(merged_rules) -def _create_default_instracluster_inbound_rules(intracluster_sgids): +def _create_default_intracluster_inbound_rules(intracluster_sgids): return [{ "FromPort": -1, "ToPort": -1, diff --git a/python/ray/autoscaler/aws/example-head-and-worker-security-group.yaml b/python/ray/autoscaler/aws/example-head-and-worker-security-group.yaml new file mode 100644 index 000000000..b940366a0 --- /dev/null +++ b/python/ray/autoscaler/aws/example-head-and-worker-security-group.yaml @@ -0,0 +1,31 @@ +cluster_name: sg + +max_workers: 1 + +provider: + type: aws + region: us-west-2 + availability_zone: us-west-2a + +auth: + ssh_user: ubuntu + +# If required, head and worker nodes can exist on subnets in different VPCs and +# communicate via VPC peering. + +# VPC peering overview: https://docs.aws.amazon.com/vpc/latest/userguide/vpc-peering.html. +# Setup VPC peering: https://docs.aws.amazon.com/vpc/latest/peering/create-vpc-peering-connection.html. +# Configure VPC peering route tables: https://docs.aws.amazon.com/vpc/latest/peering/vpc-peering-routing.html. 
def test_subnet_given_head_and_worker_sg(iam_client_stub, ec2_client_stub):
    """Only the subnet in the security group's VPC should be selected."""
    # Queue the default IAM role and key pair responses.
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # Queue one security group, followed by a thousand subnets that live
    # in different VPCs.
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    bootstrapped = helpers.bootstrap_aws_example_config_file(
        "example-head-and-worker-security-group.yaml")

    # Both node kinds must end up with just the single subnet that sits in
    # the right VPC.
    for node_key in ("head_node", "worker_nodes"):
        assert bootstrapped[node_key]["SubnetIds"] == [
            DEFAULT_SUBNET["SubnetId"]
        ]

    # Neither the IAM nor the EC2 stub queue may have responses left over.
    for client_stub in (iam_client_stub, ec2_client_stub):
        client_stub.assert_no_pending_responses()
def describe_a_security_group(ec2_client_stub, security_group):
    """Queue a `describe_security_groups` response on the EC2 stub.

    The queued response returns `security_group` as the only group, and the
    stub is told to expect a request filtered on `group-id` with exactly the
    `GroupId` of `security_group`.
    """
    group_id_filter = {
        "Name": "group-id",
        "Values": [security_group["GroupId"]],
    }
    ec2_client_stub.add_response(
        "describe_security_groups",
        expected_params={"Filters": [group_id_filter]},
        service_response={"SecurityGroups": [security_group]})