[autoscaler] Allowing users to provide extra configs for AWS (#7844)

* Allowing users to provide custom key names & security group inbound rules

* linting

* getting aws credentials passed in

* one more thing

* one more thing part 2

* formatting

* addressing comments

* update

* update

* update

* update

* update

* update

* remove tests

* rerun tests

Co-authored-by: Allen Yin <allenyin@anyscale.io>
This commit is contained in:
Allen 2020-04-04 18:36:51 -07:00 committed by GitHub
parent 630b3b1752
commit 3c91ff1f63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 25 deletions

View file

@ -37,13 +37,21 @@ assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
"Boto3 version >= 1.4.8 required, try `pip install -U boto3`"
def key_pair(i, region, key_name=None):
    """Return the ith candidate (aws_key_pair_name, key_pair_path).

    Args:
        i: Zero-based index of the candidate key pair.
        region: AWS region name embedded in the default key pair name.
        key_name: Optional user-supplied base name. When None, the default
            naming built from RAY and the region is used. Defaults to None
            so that two-argument callers keep working.

    Returns:
        A tuple of the key pair name and the user-expanded path
        "~/.ssh/<key pair name>.pem".
    """
    if i == 0:
        key_pair_name = ("{}_{}".format(RAY, region)
                         if key_name is None else key_name)
    else:
        # Later candidates carry the index so that each attempt gets a
        # distinct name and a distinct .pem file.
        key_pair_name = ("{}_{}_{}".format(RAY, i, region)
                         if key_name is None else
                         key_name + "_key-{}".format(i))
    return (key_pair_name,
            os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)))
# Suppress excessive connection dropped logs from boto
@ -136,7 +144,11 @@ def _configure_key_pair(config):
# Try a few times to get or create a good key pair.
MAX_NUM_KEYS = 30
for i in range(MAX_NUM_KEYS):
key_name, key_path = key_pair(i, config["provider"]["region"])
key_name = config["provider"].get("key_pair", {}).get("key_name")
key_name, key_path = key_pair(i, config["provider"]["region"],
key_name)
key = _get_key(key_name, config)
# Found a good key.
@ -236,7 +248,7 @@ def _configure_security_group(config):
assert security_group, "Failed to create security group"
if not security_group.ip_permissions:
security_group.authorize_ingress(IpPermissions=[{
IpPermissions = [{
"FromPort": -1,
"ToPort": -1,
"IpProtocol": "-1",
@ -250,7 +262,13 @@ def _configure_security_group(config):
"IpRanges": [{
"CidrIp": "0.0.0.0/0"
}]
}])
}]
additional_IpPermissions = config["provider"].get(
"security_group", {}).get("IpPermissions", [])
IpPermissions.extend(additional_IpPermissions)
security_group.authorize_ingress(IpPermissions=IpPermissions)
if "SecurityGroupIds" not in config["head_node"]:
logger.info(
@ -359,10 +377,19 @@ def _get_key(key_name, config):
def _client(name, config):
    """Create a boto3 client for the named AWS service.

    Args:
        name: Service name passed straight through to ``boto3.client``.
        config: Cluster config. ``config["provider"]["region"]`` is required;
            ``config["provider"]["aws_credentials"]`` is an optional mapping
            of extra keyword arguments forwarded to ``boto3.client``.

    Returns:
        The object returned by ``boto3.client``.
    """
    boto_config = Config(retries={"max_attempts": BOTO_MAX_RETRIES})
    # `or {}` also covers an explicit null entry, which would otherwise
    # raise a TypeError when unpacked with `**`.
    aws_credentials = config["provider"].get("aws_credentials") or {}
    return boto3.client(
        name,
        config["provider"]["region"],
        config=boto_config,
        **aws_credentials)
def _resource(name, config):
    """Create a boto3 resource for the named AWS service.

    Args:
        name: Service name passed straight through to ``boto3.resource``.
        config: Cluster config. ``config["provider"]["region"]`` is required;
            ``config["provider"]["aws_credentials"]`` is an optional mapping
            of extra keyword arguments forwarded to ``boto3.resource``.

    Returns:
        The object returned by ``boto3.resource``.
    """
    boto_config = Config(retries={"max_attempts": BOTO_MAX_RETRIES})
    # `or {}` also covers an explicit null entry, which would otherwise
    # raise a TypeError when unpacked with `**`.
    aws_credentials = config["provider"].get("aws_credentials") or {}
    return boto3.resource(
        name,
        config["provider"]["region"],
        config=boto_config,
        **aws_credentials)

View file

@ -34,10 +34,12 @@ def from_aws_format(tags):
return tags
def make_ec2_client(region, max_retries, aws_credentials=None):
    """Make an EC2 boto3 resource, retrying requests up to `max_retries`.

    Args:
        region: AWS region name, passed as ``region_name``.
        max_retries: Value used for botocore's ``max_attempts`` retry setting.
        aws_credentials: Optional mapping of extra keyword arguments
            forwarded to ``boto3.resource``; None or empty means none.

    Returns:
        The object returned by ``boto3.resource("ec2", ...)``.
    """
    config = Config(retries={"max_attempts": max_retries})
    # Normalize None to an empty mapping so it can be unpacked with `**`.
    aws_credentials = aws_credentials or {}
    return boto3.resource(
        "ec2", region_name=region, config=config, **aws_credentials)
class AWSNodeProvider(NodeProvider):
@ -45,10 +47,16 @@ class AWSNodeProvider(NodeProvider):
NodeProvider.__init__(self, provider_config, cluster_name)
self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes",
True)
aws_credentials = provider_config.get("aws_credentials")
self.ec2 = make_ec2_client(
region=provider_config["region"], max_retries=BOTO_MAX_RETRIES)
region=provider_config["region"],
max_retries=BOTO_MAX_RETRIES,
aws_credentials=aws_credentials)
self.ec2_fail_fast = make_ec2_client(
region=provider_config["region"], max_retries=0)
region=provider_config["region"],
max_retries=0,
aws_credentials=aws_credentials)
# Try availability zones round-robin, starting from random offset
self.subnet_idx = random.randint(0, 100)

View file

@ -58,7 +58,7 @@
"type": "object",
"description": "Cloud-provider specific configuration.",
"required": [ "type" ],
"additionalProperties": false,
"additionalProperties": true,
"properties": {
"type": {
"type": "string",
@ -128,10 +128,6 @@
"type": "object",
"description": "k8s autoscaler permissions, if using k8s"
},
"extra_config": {
"type": "object",
"description": "provider-specific config"
},
"cache_stopped_nodes": {
"type": "boolean",
"description": " Whether to try to reuse previously stopped nodes instead of launching nodes. This will also cause the autoscaler to stop nodes instead of terminating them. Only implemented for AWS."

View file

@ -328,11 +328,6 @@ class AutoscalingTest(unittest.TestCase):
validate_config(config)
del config["blah"]
config["provider"]["blah"] = "blah"
with pytest.raises(ValidationError):
validate_config(config)
del config["provider"]["blah"]
del config["provider"]
with pytest.raises(ValidationError):
validate_config(config)