mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[autoscaler] Allowing users to provide extra configs for AWS (#7844)
* Allowing users to provide custom key names & security group inbound rules * linting * getting aws credentials passed in * one more thing * one more thing part 2 * formatting * addressing comments * update * update * update * update * update * update * remove tests * rerun tests Co-authored-by: Allen Yin <allenyin@anyscale.io>
This commit is contained in:
parent
630b3b1752
commit
3c91ff1f63
4 changed files with 51 additions and 25 deletions
|
@ -37,13 +37,21 @@ assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
|
|||
"Boto3 version >= 1.4.8 required, try `pip install -U boto3`"
|
||||
|
||||
|
||||
def key_pair(i, region):
|
||||
"""Returns the ith default (aws_key_pair_name, key_pair_path)."""
|
||||
def key_pair(i, region, key_name):
|
||||
"""
|
||||
If key_name is not None, key_pair will be named after key_name.
|
||||
Returns the ith default (aws_key_pair_name, key_pair_path).
|
||||
"""
|
||||
if i == 0:
|
||||
return ("{}_{}".format(RAY, region),
|
||||
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
|
||||
return ("{}_{}_{}".format(RAY, i, region),
|
||||
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
|
||||
key_pair_name = ("{}_{}".format(RAY, region)
|
||||
if key_name is None else key_name)
|
||||
return (key_pair_name,
|
||||
os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)))
|
||||
|
||||
key_pair_name = ("{}_{}_{}".format(RAY, i, region)
|
||||
if key_name is None else key_name + "_key-{}".format(i))
|
||||
return (key_pair_name,
|
||||
os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)))
|
||||
|
||||
|
||||
# Suppress excessive connection dropped logs from boto
|
||||
|
@ -136,7 +144,11 @@ def _configure_key_pair(config):
|
|||
# Try a few times to get or create a good key pair.
|
||||
MAX_NUM_KEYS = 30
|
||||
for i in range(MAX_NUM_KEYS):
|
||||
key_name, key_path = key_pair(i, config["provider"]["region"])
|
||||
|
||||
key_name = config["provider"].get("key_pair", {}).get("key_name")
|
||||
|
||||
key_name, key_path = key_pair(i, config["provider"]["region"],
|
||||
key_name)
|
||||
key = _get_key(key_name, config)
|
||||
|
||||
# Found a good key.
|
||||
|
@ -236,7 +248,7 @@ def _configure_security_group(config):
|
|||
assert security_group, "Failed to create security group"
|
||||
|
||||
if not security_group.ip_permissions:
|
||||
security_group.authorize_ingress(IpPermissions=[{
|
||||
IpPermissions = [{
|
||||
"FromPort": -1,
|
||||
"ToPort": -1,
|
||||
"IpProtocol": "-1",
|
||||
|
@ -250,7 +262,13 @@ def _configure_security_group(config):
|
|||
"IpRanges": [{
|
||||
"CidrIp": "0.0.0.0/0"
|
||||
}]
|
||||
}])
|
||||
}]
|
||||
|
||||
additional_IpPermissions = config["provider"].get(
|
||||
"security_group", {}).get("IpPermissions", [])
|
||||
IpPermissions.extend(additional_IpPermissions)
|
||||
|
||||
security_group.authorize_ingress(IpPermissions=IpPermissions)
|
||||
|
||||
if "SecurityGroupIds" not in config["head_node"]:
|
||||
logger.info(
|
||||
|
@ -359,10 +377,19 @@ def _get_key(key_name, config):
|
|||
|
||||
def _client(name, config):
|
||||
boto_config = Config(retries={"max_attempts": BOTO_MAX_RETRIES})
|
||||
return boto3.client(name, config["provider"]["region"], config=boto_config)
|
||||
aws_credentials = config["provider"].get("aws_credentials", {})
|
||||
return boto3.client(
|
||||
name,
|
||||
config["provider"]["region"],
|
||||
config=boto_config,
|
||||
**aws_credentials)
|
||||
|
||||
|
||||
def _resource(name, config):
|
||||
boto_config = Config(retries={"max_attempts": BOTO_MAX_RETRIES})
|
||||
aws_credentials = config["provider"].get("aws_credentials", {})
|
||||
return boto3.resource(
|
||||
name, config["provider"]["region"], config=boto_config)
|
||||
name,
|
||||
config["provider"]["region"],
|
||||
config=boto_config,
|
||||
**aws_credentials)
|
||||
|
|
|
@ -34,10 +34,12 @@ def from_aws_format(tags):
|
|||
return tags
|
||||
|
||||
|
||||
def make_ec2_client(region, max_retries):
|
||||
def make_ec2_client(region, max_retries, aws_credentials=None):
|
||||
"""Make client, retrying requests up to `max_retries`."""
|
||||
config = Config(retries={"max_attempts": max_retries})
|
||||
return boto3.resource("ec2", region_name=region, config=config)
|
||||
aws_credentials = aws_credentials or {}
|
||||
return boto3.resource(
|
||||
"ec2", region_name=region, config=config, **aws_credentials)
|
||||
|
||||
|
||||
class AWSNodeProvider(NodeProvider):
|
||||
|
@ -45,10 +47,16 @@ class AWSNodeProvider(NodeProvider):
|
|||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes",
|
||||
True)
|
||||
aws_credentials = provider_config.get("aws_credentials")
|
||||
|
||||
self.ec2 = make_ec2_client(
|
||||
region=provider_config["region"], max_retries=BOTO_MAX_RETRIES)
|
||||
region=provider_config["region"],
|
||||
max_retries=BOTO_MAX_RETRIES,
|
||||
aws_credentials=aws_credentials)
|
||||
self.ec2_fail_fast = make_ec2_client(
|
||||
region=provider_config["region"], max_retries=0)
|
||||
region=provider_config["region"],
|
||||
max_retries=0,
|
||||
aws_credentials=aws_credentials)
|
||||
|
||||
# Try availability zones round-robin, starting from random offset
|
||||
self.subnet_idx = random.randint(0, 100)
|
||||
|
|
|
@ -58,7 +58,7 @@
|
|||
"type": "object",
|
||||
"description": "Cloud-provider specific configuration.",
|
||||
"required": [ "type" ],
|
||||
"additionalProperties": false,
|
||||
"additionalProperties": true,
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
|
@ -128,10 +128,6 @@
|
|||
"type": "object",
|
||||
"description": "k8s autoscaler permissions, if using k8s"
|
||||
},
|
||||
"extra_config": {
|
||||
"type": "object",
|
||||
"description": "provider-specific config"
|
||||
},
|
||||
"cache_stopped_nodes": {
|
||||
"type": "boolean",
|
||||
"description": " Whether to try to reuse previously stopped nodes instead of launching nodes. This will also cause the autoscaler to stop nodes instead of terminating them. Only implemented for AWS."
|
||||
|
|
|
@ -328,11 +328,6 @@ class AutoscalingTest(unittest.TestCase):
|
|||
validate_config(config)
|
||||
del config["blah"]
|
||||
|
||||
config["provider"]["blah"] = "blah"
|
||||
with pytest.raises(ValidationError):
|
||||
validate_config(config)
|
||||
del config["provider"]["blah"]
|
||||
|
||||
del config["provider"]
|
||||
with pytest.raises(ValidationError):
|
||||
validate_config(config)
|
||||
|
|
Loading…
Add table
Reference in a new issue