diff --git a/python/ray/autoscaler/_private/aws/config.py b/python/ray/autoscaler/_private/aws/config.py index a2da1a7a1..d3297f943 100644 --- a/python/ray/autoscaler/_private/aws/config.py +++ b/python/ray/autoscaler/_private/aws/config.py @@ -782,19 +782,21 @@ def _configure_subnets_and_groups_from_network_interfaces(node_cfg): conflict_keys = ["SubnetId", "SubnetIds", "SecurityGroupIds"] if any(conflict in node_cfg for conflict in conflict_keys): raise ValueError( - "If NetworkInterfaces are defined, subnets and security groups" + "If NetworkInterfaces are defined, subnets and security groups " "must ONLY be given in each NetworkInterface.") - if not all(_subnets_in_network_config(node_cfg)): + subnets = _subnets_in_network_config(node_cfg) + if not all(subnets): raise ValueError( "NetworkInterfaces are defined but at least one is missing a " "subnet. Please ensure all interfaces have a subnet assigned.") - if not all(_security_groups_in_network_config(node_cfg)): + security_groups = _security_groups_in_network_config(node_cfg) + if not all(security_groups): raise ValueError( "NetworkInterfaces are defined but at least one is missing a " "security group. Please ensure all interfaces have a security " "group assigned.") - node_cfg["SubnetIds"] = _subnets_in_network_config(node_cfg) - node_cfg["SecurityGroupIds"] = _security_groups_in_network_config(node_cfg) + node_cfg["SubnetIds"] = subnets + node_cfg["SecurityGroupIds"] = list(itertools.chain(*security_groups)) def _subnets_in_network_config(config): @@ -804,10 +806,7 @@ def _subnets_in_network_config(config): def _security_groups_in_network_config(config): - lists = [ - ni.get("Groups", []) for ni in config.get("NetworkInterfaces", []) - ] - return list(itertools.chain(*lists)) + return [ni.get("Groups", []) for ni in config.get("NetworkInterfaces", [])] def _client(name, config): diff --git a/python/ray/tests/aws/test_autoscaler_aws.py b/python/ray/tests/aws/test_autoscaler_aws.py index 1cb803a4a..8a5bdda18 100644 --- a/python/ray/tests/aws/test_autoscaler_aws.py +++ b/python/ray/tests/aws/test_autoscaler_aws.py @@ -524,6 +524,56 @@ def test_network_interfaces(ec2_client_stub, iam_client_stub, ec2_client_stub_max_retries.assert_no_pending_responses() +def test_network_interface_conflict_keys(): + # If NetworkInterfaces are defined, SubnetId and SecurityGroupIds + # can't be specified in the same node type config. + conflict_kv_pairs = [("SubnetId", "subnet-0000000"), + ("SubnetIds", ["subnet-0000000", "subnet-1111111"]), + ("SecurityGroupIds", ["sg-1234abcd", "sg-dcba4321"])] + expected_error_msg = "If NetworkInterfaces are defined, subnets and " \ + "security groups must ONLY be given in each " \ + "NetworkInterface." + for conflict_kv_pair in conflict_kv_pairs: + config = helpers.load_aws_example_config_file( + "example-network-interfaces.yaml") + head_name = config["head_node_type"] + head_node_cfg = config["available_node_types"][head_name][ + "node_config"] + head_node_cfg[conflict_kv_pair[0]] = conflict_kv_pair[1] + with pytest.raises(ValueError, match=expected_error_msg): + helpers.bootstrap_aws_config(config) + + +def test_network_interface_missing_subnet(): + # If NetworkInterfaces are defined, each must have a subnet ID + expected_error_msg = "NetworkInterfaces are defined but at least one is " \ + "missing a subnet. Please ensure all interfaces " \ + "have a subnet assigned." + config = helpers.load_aws_example_config_file( + "example-network-interfaces.yaml") + for name, node_type in config["available_node_types"].items(): + node_cfg = node_type["node_config"] + for network_interface_cfg in node_cfg["NetworkInterfaces"]: + network_interface_cfg.pop("SubnetId") + with pytest.raises(ValueError, match=expected_error_msg): + helpers.bootstrap_aws_config(config) + + +def test_network_interface_missing_security_group(): + # If NetworkInterfaces are defined, each must have security groups + expected_error_msg = "NetworkInterfaces are defined but at least one is " \ + "missing a security group. Please ensure all " \ + "interfaces have a security group assigned." + config = helpers.load_aws_example_config_file( + "example-network-interfaces.yaml") + for name, node_type in config["available_node_types"].items(): + node_cfg = node_type["node_config"] + for network_interface_cfg in node_cfg["NetworkInterfaces"]: + network_interface_cfg.pop("Groups") + with pytest.raises(ValueError, match=expected_error_msg): + helpers.bootstrap_aws_config(config) + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__]))