diff --git a/python/ray/autoscaler/aws/config.py b/python/ray/autoscaler/aws/config.py index 364e1ebd1..b103937ac 100644 --- a/python/ray/autoscaler/aws/config.py +++ b/python/ray/autoscaler/aws/config.py @@ -37,13 +37,21 @@ assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \ "Boto3 version >= 1.4.8 required, try `pip install -U boto3`" -def key_pair(i, region): - """Returns the ith default (aws_key_pair_name, key_pair_path).""" +def key_pair(i, region, key_name): + """ + If key_name is not None, key_pair will be named after key_name. + Returns the ith default (aws_key_pair_name, key_pair_path). + """ if i == 0: - return ("{}_{}".format(RAY, region), - os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region))) - return ("{}_{}_{}".format(RAY, i, region), - os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region))) + key_pair_name = ("{}_{}".format(RAY, region) + if key_name is None else key_name) + return (key_pair_name, + os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name))) + + key_pair_name = ("{}_{}_{}".format(RAY, i, region) + if key_name is None else key_name + "_key-{}".format(i)) + return (key_pair_name, + os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name))) # Suppress excessive connection dropped logs from boto @@ -136,7 +144,11 @@ def _configure_key_pair(config): # Try a few times to get or create a good key pair. MAX_NUM_KEYS = 30 for i in range(MAX_NUM_KEYS): - key_name, key_path = key_pair(i, config["provider"]["region"]) + + key_name = config["provider"].get("key_pair", {}).get("key_name") + + key_name, key_path = key_pair(i, config["provider"]["region"], + key_name) key = _get_key(key_name, config) # Found a good key. @@ -236,7 +248,7 @@ def _configure_security_group(config): assert security_group, "Failed to create security group" if not security_group.ip_permissions: - security_group.authorize_ingress(IpPermissions=[{ + IpPermissions = [{ "FromPort": -1, "ToPort": -1, "IpProtocol": "-1", @@ -250,7 +262,13 @@ def _configure_security_group(config): "IpRanges": [{ "CidrIp": "0.0.0.0/0" }] - }]) + }] + + additional_IpPermissions = config["provider"].get( + "security_group", {}).get("IpPermissions", []) + IpPermissions.extend(additional_IpPermissions) + + security_group.authorize_ingress(IpPermissions=IpPermissions) if "SecurityGroupIds" not in config["head_node"]: logger.info( @@ -359,10 +377,19 @@ def _get_key(key_name, config): def _client(name, config): boto_config = Config(retries={"max_attempts": BOTO_MAX_RETRIES}) - return boto3.client(name, config["provider"]["region"], config=boto_config) + aws_credentials = config["provider"].get("aws_credentials", {}) + return boto3.client( + name, + config["provider"]["region"], + config=boto_config, + **aws_credentials) def _resource(name, config): boto_config = Config(retries={"max_attempts": BOTO_MAX_RETRIES}) + aws_credentials = config["provider"].get("aws_credentials", {}) return boto3.resource( - name, config["provider"]["region"], config=boto_config) + name, + config["provider"]["region"], + config=boto_config, + **aws_credentials) diff --git a/python/ray/autoscaler/aws/node_provider.py b/python/ray/autoscaler/aws/node_provider.py index 0d275d22d..c19ec1ac5 100644 --- a/python/ray/autoscaler/aws/node_provider.py +++ b/python/ray/autoscaler/aws/node_provider.py @@ -34,10 +34,12 @@ def from_aws_format(tags): return tags -def make_ec2_client(region, max_retries): +def make_ec2_client(region, max_retries, aws_credentials=None): """Make client, retrying requests up to `max_retries`.""" config = Config(retries={"max_attempts": max_retries}) - return boto3.resource("ec2", region_name=region, config=config) + aws_credentials = aws_credentials or {} + return boto3.resource( + "ec2", region_name=region, config=config, **aws_credentials) class AWSNodeProvider(NodeProvider): @@ -45,10 +47,16 @@ class AWSNodeProvider(NodeProvider): NodeProvider.__init__(self, provider_config, cluster_name) self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True) + aws_credentials = provider_config.get("aws_credentials") + self.ec2 = make_ec2_client( - region=provider_config["region"], max_retries=BOTO_MAX_RETRIES) + region=provider_config["region"], + max_retries=BOTO_MAX_RETRIES, + aws_credentials=aws_credentials) self.ec2_fail_fast = make_ec2_client( - region=provider_config["region"], max_retries=0) + region=provider_config["region"], + max_retries=0, + aws_credentials=aws_credentials) # Try availability zones round-robin, starting from random offset self.subnet_idx = random.randint(0, 100) diff --git a/python/ray/autoscaler/ray-schema.json b/python/ray/autoscaler/ray-schema.json index cd17133e9..d48f4f231 100644 --- a/python/ray/autoscaler/ray-schema.json +++ b/python/ray/autoscaler/ray-schema.json @@ -58,7 +58,7 @@ "type": "object", "description": "Cloud-provider specific configuration.", "required": [ "type" ], - "additionalProperties": false, + "additionalProperties": true, "properties": { "type": { "type": "string", @@ -128,10 +128,6 @@ "type": "object", "description": "k8s autoscaler permissions, if using k8s" }, - "extra_config": { - "type": "object", - "description": "provider-specific config" - }, "cache_stopped_nodes": { "type": "boolean", "description": " Whether to try to reuse previously stopped nodes instead of launching nodes. This will also cause the autoscaler to stop nodes instead of terminating them. Only implemented for AWS." diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index 7ef722d88..ba7fd39ab 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -328,11 +328,6 @@ class AutoscalingTest(unittest.TestCase): validate_config(config) del config["blah"] - config["provider"]["blah"] = "blah" - with pytest.raises(ValidationError): - validate_config(config) - del config["provider"]["blah"] - del config["provider"] with pytest.raises(ValidationError): validate_config(config)