[autoscaler] GCP node provider (#2061)
* Google Cloud Platform scaffolding
* Add minimal gcp config example
* Add googleapiclient discoveries, update gcp.config constants
* Rename and update gcp.config key pair name function
* Implement gcp.config._configure_project
* Fix the create project get project flow
* Implement gcp.config._configure_iam_role
* Implement service account iam binding
* Implement gcp.config._configure_key_pair
* Implement rsa key pair generation
* Implement gcp.config._configure_subnet
* Save work-in-progress gcp.config._configure_firewall_rules. These are likely to be not needed at all. Saving them if we happen to need them later.
* Remove unnecessary firewall configuration
* Update example-minimal.yaml configuration
* Add new wait_for_compute_operation, rename old wait_for_operation
* Temporarily rename autoscaler tags due to gcp incompatibility
* Implement initial gcp.node_provider.nodes
* Still missing filter support
* Implement initial gcp.node_provider.create_node
* Implement another compute wait operation (wait_For_compute_zone_operation). TODO: figure out if we can remove the function.
* Implement initial gcp.node_provider._node and node status functions
* Implement initial gcp.node_provider.terminate_node
* Implement node tagging and ip getter methods for nodes
* Temporarily rename tags due to gcp incompatibility
* Tiny tweaks for autoscaler.updater
* Remove unused config from gcp node_provider
* Add new example-full example to gcp, update load_gcp_example_config
* Implement label filtering for gcp.node_provider.nodes
* Revert unnecessary change in ssh command
* Revert "Temporarily rename tags due to gcp incompatibility" This reverts commit e2fe634c5d11d705c0f5d3e76c80c37394bb23fb.
* Revert "Temporarily rename autoscaler tags due to gcp incompatibility" This reverts commit c938ee435f4b75854a14e78242ad7f1d1ed8ad4b.
* Refactor autoscaler tagging to support multiple tag specs
* Remove missing cryptography imports
* Update quote function import
* Fix threading issue in gcp.config with the compute discovery object
* Add gcs support for log_sync
* Fix the labels/tags naming discrepancy
* Add expanduser to file_mounts hashing
* Fix gcp.node_provider.internal_ip
* Add uuid to node name
* Remove 'set -i' from updater ssh command
* Also add TODO with the context and reason for the change.
* Update ssh key creation in autoscaler.gcp.config
* Fix wait_for_compute_zone_operation's threading issue. Google discovery api's compute object is not thread safe, and thus needs to be recreated for each thread. This moves the `wait_for_compute_zone_operation` under `autoscaler.gcp.config`, and adds compute as its argument.
* Address pr feedback from @ericl
* Expand local file mount paths in NodeUpdater
* Add ssh_user name to key names
* Update updater ssh to attempt 'set -i' and fall back if that fails
* Update gcp/example-full.yaml
* Fix wait crm operation in gcp.config
* Update gcp/example-minimal.yaml to match aws/example-minimal.yaml
* Fix gcp/example-full.yaml comment indentation
* Add gcp/example-full.yaml to setup files
* Update example-full.yaml command
* Revert "Refactor autoscaler tagging to support multiple tag specs" This reverts commit 9cf48409ca2e5b66f800153853072c706fa502f6.
* Update tag spec to only use characters [0-9a-z_-]
* Change the tag values to conform gcp spec
* Add project_id in the ssh key name
* Replace '_' with '-' in autoscaler tag names
* Revert "Update updater ssh to attempt 'set -i' and fall back if that fails" This reverts commit 23a0066c5254449e49746bd5e43b94b66f32bfb4.
* Revert "Remove 'set -i' from updater ssh command" This reverts commit 5fa034cdf79fa7f8903691518c0d75699c630172.
* Add fallback to `set -i` in force_interactive command
* Update autoscaler tests to match current implementation
* Update GCPNodeProvider.create_node to include hash in instance name
* Add support for creating multiple instance on one create_node call
* Clean TODOs
* Update styles
* Replace single quotes with double quotes
* Some minor indentation fixes etc.
* Remove unnecessary comment. Fix indentation.
* Yapfify files that fail flake8 test
* Yapfify more files
* Update project_id handling in gcp node provider
* temporary yapf mod
* Revert "temporary yapf mod" This reverts commit b6744e4e15d4d936d1a14f4bf155ed1d3bb14126.
* Fix autoscaler/updater.py lint error, remove unused variable
parent 117107cb15
commit 74dc14d1fc
13 changed files with 936 additions and 62 deletions
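One of the fixes listed above is that the googleapiclient "compute" discovery object is not thread safe, so wait_for_compute_zone_operation now takes the compute client as an argument and each caller builds its own. A minimal sketch of that pattern, assuming google-api-python-client is installed and default credentials are configured; project, zone, and instance names are hypothetical:

# Sketch only: each worker thread builds its own discovery client instead of
# sharing a module-level one, mirroring the threading fix described above.
import threading
from googleapiclient import discovery

def delete_instance(project_id, zone, instance_name):
    # A fresh client per thread avoids sharing the non-thread-safe object.
    compute = discovery.build("compute", "v1")
    operation = compute.instances().delete(
        project=project_id, zone=zone, instance=instance_name).execute()
    # In the diff below this would be followed by
    # wait_for_compute_zone_operation(compute, project_id, operation, zone).
    return operation

threads = [
    threading.Thread(target=delete_instance,
                     args=("my-project", "us-west1-a", name))
    for name in ["ray-default-worker-aaaaaaaa", "ray-default-worker-bbbbbbbb"]
]
for t in threads:
    t.start()
for t in threads:
    t.join()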
|
@ -22,8 +22,9 @@ from ray.autoscaler.node_provider import get_node_provider, \
|
|||
get_default_config
|
||||
from ray.autoscaler.updater import NodeUpdaterProcess
|
||||
from ray.autoscaler.docker import dockerize_if_needed
|
||||
from ray.autoscaler.tags import TAG_RAY_LAUNCH_CONFIG, \
|
||||
TAG_RAY_RUNTIME_CONFIG, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE, TAG_NAME
|
||||
from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
|
||||
TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE,
|
||||
TAG_RAY_NODE_NAME)
|
||||
import ray.services as services
|
||||
|
||||
REQUIRED, OPTIONAL = True, False
|
||||
|
@ -58,6 +59,7 @@ CLUSTER_CONFIG_SCHEMA = {
|
|||
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
|
||||
"module": (str,
|
||||
OPTIONAL), # module, if using external node provider
|
||||
"project_id": (None, OPTIONAL), # gcp project id, if using gcp
|
||||
},
|
||||
REQUIRED),
|
||||
|
||||
|
@ -244,6 +246,14 @@ class StandardAutoscaler(object):
|
|||
self.last_update_time = 0.0
|
||||
self.update_interval_s = update_interval_s
|
||||
|
||||
# Expand local file_mounts to allow ~ in the paths. This can't be done
|
||||
# earlier when the config is written since we might be on different
|
||||
# platform and the expansion would result in wrong path.
|
||||
self.config["file_mounts"] = {
|
||||
remote: os.path.expanduser(local)
|
||||
for remote, local in self.config["file_mounts"].items()
|
||||
}
|
||||
|
||||
for local_path in self.config["file_mounts"].values():
|
||||
assert os.path.exists(local_path)
|
||||
|
||||
|
@ -254,8 +264,8 @@ class StandardAutoscaler(object):
|
|||
self.reload_config(errors_fatal=False)
|
||||
self._update()
|
||||
except Exception as e:
|
||||
print("StandardAutoscaler: Error during autoscaling: {}",
|
||||
traceback.format_exc())
|
||||
print("StandardAutoscaler: Error during autoscaling: {}"
|
||||
"".format(traceback.format_exc()))
|
||||
self.num_failures += 1
|
||||
if self.num_failures > self.max_failures:
|
||||
print("*** StandardAutoscaler: Too many errors, abort. ***")
|
||||
|
@ -446,9 +456,10 @@ class StandardAutoscaler(object):
|
|||
num_before = len(self.workers())
|
||||
self.provider.create_node(
|
||||
self.config["worker_nodes"], {
|
||||
TAG_NAME: "ray-{}-worker".format(self.config["cluster_name"]),
|
||||
TAG_RAY_NODE_TYPE: "Worker",
|
||||
TAG_RAY_NODE_STATUS: "Uninitialized",
|
||||
TAG_RAY_NODE_NAME: "ray-{}-worker".format(
|
||||
self.config["cluster_name"]),
|
||||
TAG_RAY_NODE_TYPE: "worker",
|
||||
TAG_RAY_NODE_STATUS: "uninitialized",
|
||||
TAG_RAY_LAUNCH_CONFIG: self.launch_hash,
|
||||
}, count)
|
||||
if len(self.workers()) <= num_before:
|
||||
|
@ -456,7 +467,7 @@ class StandardAutoscaler(object):
|
|||
|
||||
def workers(self):
|
||||
return self.provider.nodes(tag_filters={
|
||||
TAG_RAY_NODE_TYPE: "Worker",
|
||||
TAG_RAY_NODE_TYPE: "worker",
|
||||
})
|
||||
|
||||
def debug_string(self, nodes=None):
|
||||
|
@ -565,7 +576,7 @@ def hash_runtime_conf(file_mounts, extra_objs):
|
|||
with open(os.path.join(dirpath, name), "rb") as f:
|
||||
hasher.update(f.read())
|
||||
else:
|
||||
with open(path, 'r') as f:
|
||||
with open(os.path.expanduser(path), "r") as f:
|
||||
hasher.update(f.read().encode("utf-8"))
|
||||
|
||||
hasher.update(json.dumps(sorted(file_mounts.items())).encode("utf-8"))
|
||||
|
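The autoscaler.py hunks above expand "~" in file_mounts when the config is loaded (rather than when it is written, since the writing machine may be a different platform) and assert that the expanded local paths exist before hashing them. A tiny sketch of that behavior, with a hypothetical mounts dict:

# Sketch: the same expansion StandardAutoscaler applies before hashing/syncing.
import os

file_mounts = {"/home/ubuntu/data": "~/data"}  # hypothetical remote: local mapping

file_mounts = {
    remote: os.path.expanduser(local)
    for remote, local in file_mounts.items()
}

for local_path in file_mounts.values():
    assert os.path.exists(local_path), local_path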
|
|
@ -19,7 +19,7 @@ from ray.autoscaler.autoscaler import validate_config, hash_runtime_conf, \
|
|||
hash_launch_conf, fillout_defaults
|
||||
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
|
||||
TAG_NAME
|
||||
TAG_RAY_NODE_NAME
|
||||
from ray.autoscaler.updater import NodeUpdaterProcess
|
||||
|
||||
|
||||
|
@ -57,7 +57,7 @@ def teardown_cluster(config_file, yes):
|
|||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "Head",
|
||||
TAG_RAY_NODE_TYPE: "head",
|
||||
}
|
||||
for node in provider.nodes(head_node_tags):
|
||||
print("Terminating head node {}".format(node))
|
||||
|
@ -76,7 +76,7 @@ def get_or_create_head_node(config, no_restart, yes):
|
|||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "Head",
|
||||
TAG_RAY_NODE_TYPE: "head",
|
||||
}
|
||||
nodes = provider.nodes(head_node_tags)
|
||||
if len(nodes) > 0:
|
||||
|
@ -98,7 +98,8 @@ def get_or_create_head_node(config, no_restart, yes):
|
|||
provider.terminate_node(head_node)
|
||||
print("Launching new head node...")
|
||||
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
|
||||
head_node_tags[TAG_NAME] = "ray-{}-head".format(config["cluster_name"])
|
||||
head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
|
||||
config["cluster_name"])
|
||||
provider.create_node(config["head_node"], head_node_tags, 1)
|
||||
|
||||
nodes = provider.nodes(head_node_tags)
|
||||
|
@ -185,7 +186,7 @@ def get_head_node_ip(config_file):
|
|||
config = yaml.load(open(config_file).read())
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
head_node_tags = {
|
||||
TAG_RAY_NODE_TYPE: "Head",
|
||||
TAG_RAY_NODE_TYPE: "head",
|
||||
}
|
||||
nodes = provider.nodes(head_node_tags)
|
||||
if len(nodes) > 0:
|
||||
|
|
python/ray/autoscaler/gcp/__init__.py (new file, 0 lines)
python/ray/autoscaler/gcp/config.py (new file, 427 lines)
|
@ -0,0 +1,427 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from googleapiclient import discovery, errors
|
||||
|
||||
crm = discovery.build("cloudresourcemanager", "v1")
|
||||
iam = discovery.build("iam", "v1")
|
||||
compute = discovery.build("compute", "v1")
|
||||
|
||||
VERSION = "v1"
|
||||
|
||||
RAY = "ray-autoscaler"
|
||||
DEFAULT_SERVICE_ACCOUNT_ID = RAY + "-sa-" + VERSION
|
||||
SERVICE_ACCOUNT_EMAIL_TEMPLATE = (
|
||||
"{account_id}@{project_id}.iam.gserviceaccount.com")
|
||||
DEFAULT_SERVICE_ACCOUNT_CONFIG = {
|
||||
"displayName": "Ray Autoscaler Service Account ({})".format(VERSION),
|
||||
}
|
||||
DEFAULT_SERVICE_ACCOUNT_ROLES = ("roles/storage.objectAdmin",
|
||||
"roles/compute.admin")
|
||||
|
||||
MAX_POLLS = 12
|
||||
POLL_INTERVAL = 5
|
||||
|
||||
|
||||
def wait_for_crm_operation(operation):
|
||||
"""Poll for cloud resource manager operation until finished."""
|
||||
print("Waiting for operation {} to finish...".format(operation))
|
||||
|
||||
for _ in range(MAX_POLLS):
|
||||
result = crm.operations().get(name=operation["name"]).execute()
|
||||
if "error" in result:
|
||||
raise Exception(result["error"])
|
||||
|
||||
if "done" in result and result["done"]:
|
||||
print("Done.")
|
||||
break
|
||||
|
||||
time.sleep(POLL_INTERVAL)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def wait_for_compute_global_operation(project_name, operation):
|
||||
"""Poll for global compute operation until finished."""
|
||||
print("Waiting for operation {} to finish...".format(operation["name"]))
|
||||
|
||||
for _ in range(MAX_POLLS):
|
||||
result = compute.globalOperations().get(
|
||||
project=project_name,
|
||||
operation=operation["name"],
|
||||
).execute()
|
||||
if "error" in result:
|
||||
raise Exception(result["error"])
|
||||
|
||||
if result["status"] == "DONE":
|
||||
print("Done.")
|
||||
break
|
||||
|
||||
time.sleep(POLL_INTERVAL)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def key_pair_name(i, region, project_id, ssh_user):
|
||||
"""Returns the ith default gcp_key_pair_name."""
|
||||
key_name = "{}_gcp_{}_{}_{}".format(RAY, region, project_id, ssh_user, i)
|
||||
return key_name
|
||||
|
||||
|
||||
def key_pair_paths(key_name):
|
||||
"""Returns public and private key paths for a given key_name."""
|
||||
public_key_path = os.path.expanduser("~/.ssh/{}.pub".format(key_name))
|
||||
private_key_path = os.path.expanduser("~/.ssh/{}.pem".format(key_name))
|
||||
return public_key_path, private_key_path
|
||||
|
||||
|
||||
def generate_rsa_key_pair():
|
||||
"""Create public and private ssh-keys."""
|
||||
|
||||
key = rsa.generate_private_key(
|
||||
backend=default_backend(), public_exponent=65537, key_size=2048)
|
||||
|
||||
public_key = key.public_key().public_bytes(
|
||||
serialization.Encoding.OpenSSH,
|
||||
serialization.PublicFormat.OpenSSH).decode("utf-8")
|
||||
|
||||
pem = key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption()).decode("utf-8")
|
||||
|
||||
return public_key, pem
|
||||
|
||||
|
||||
def bootstrap_gcp(config):
|
||||
config = _configure_project(config)
|
||||
config = _configure_iam_role(config)
|
||||
config = _configure_key_pair(config)
|
||||
config = _configure_subnet(config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_project(config):
|
||||
"""Setup a Google Cloud Platform Project.
|
||||
|
||||
Google Cloud Platform organizes all the resources, such as storage
buckets, users, and instances, under projects. This is different from
aws ec2 where everything is global.
|
||||
"""
|
||||
project_id = config["provider"].get("project_id")
|
||||
assert config["provider"]["project_id"] is not None, (
|
||||
"'project_id' must be set in the 'provider' section of the autoscaler"
|
||||
" config. Notice that the project id must be globally unique.")
|
||||
|
||||
project = _get_project(project_id)
|
||||
|
||||
if project is None:
|
||||
# Project not found, try creating it
|
||||
_create_project(project_id)
|
||||
project = _get_project(project_id)
|
||||
|
||||
assert project is not None, "Failed to create project"
|
||||
assert project["lifecycleState"] == "ACTIVE", (
|
||||
"Project status needs to be ACTIVE, got {}".format(
|
||||
project["lifecycleState"]))
|
||||
|
||||
config["provider"]["project_id"] = project["projectId"]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_iam_role(config):
|
||||
"""Setup a gcp service account with IAM roles.
|
||||
|
||||
Creates a gcp service account and binds IAM roles which allow it to
control storage/compute services. Specifically, the head node needs to have
an IAM role that allows it to create further gce instances and store items
in google cloud storage.
|
||||
|
||||
TODO: Allow the name/id of the service account to be configured
|
||||
"""
|
||||
email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
|
||||
account_id=DEFAULT_SERVICE_ACCOUNT_ID,
|
||||
project_id=config["provider"]["project_id"])
|
||||
service_account = _get_service_account(email, config)
|
||||
|
||||
if service_account is None:
|
||||
print("Creating new service account {}".format(
|
||||
DEFAULT_SERVICE_ACCOUNT_ID))
|
||||
|
||||
service_account = _create_service_account(
|
||||
DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config)
|
||||
|
||||
assert service_account is not None, "Failed to create service account"
|
||||
|
||||
_add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES)
|
||||
|
||||
config["head_node"]["serviceAccounts"] = [{
|
||||
"email": service_account["email"],
|
||||
# NOTE: The amount of access is determined by the scope + IAM
|
||||
# role of the service account. Even if the cloud-platform scope
|
||||
# gives (scope) access to the whole cloud-platform, the service
|
||||
# account is limited by the IAM rights specified below.
|
||||
"scopes": ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
}]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_key_pair(config):
|
||||
"""Configure SSH access, using an existing key pair if possible.
|
||||
|
||||
Creates a project-wide ssh key that can be used to access all the instances
|
||||
unless explicitly prohibited by instance config.
|
||||
|
||||
The ssh-keys created by ray are of format:
|
||||
|
||||
[USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME]
|
||||
|
||||
where:
|
||||
|
||||
[USERNAME] is the user for the SSH key, specified in the config.
|
||||
[KEY_VALUE] is the public SSH key value.
|
||||
"""
|
||||
|
||||
if "ssh_private_key" in config["auth"]:
|
||||
return config
|
||||
|
||||
ssh_user = config["auth"]["ssh_user"]
|
||||
|
||||
project = compute.projects().get(
|
||||
project=config["provider"]["project_id"]).execute()
|
||||
|
||||
# Key pairs associated with project meta data. The key pairs are general,
|
||||
# and not just ssh keys.
|
||||
ssh_keys_str = next(
|
||||
(item for item in project["commonInstanceMetadata"].get("items", [])
|
||||
if item["key"] == "ssh-keys"), {}).get("value", "")
|
||||
|
||||
ssh_keys = ssh_keys_str.split("\n") if ssh_keys_str else []
|
||||
|
||||
# Try a few times to get or create a good key pair.
|
||||
key_found = False
|
||||
for i in range(10):
|
||||
key_name = key_pair_name(i, config["provider"]["region"],
|
||||
config["provider"]["project_id"], ssh_user)
|
||||
public_key_path, private_key_path = key_pair_paths(key_name)
|
||||
|
||||
for ssh_key in ssh_keys:
|
||||
key_parts = ssh_key.split(" ")
|
||||
if len(key_parts) != 3:
|
||||
continue
|
||||
|
||||
if key_parts[2] == ssh_user and os.path.exists(private_key_path):
|
||||
# Found a key
|
||||
key_found = True
|
||||
break
|
||||
|
||||
# Create a key since it doesn't exist locally or in GCP
|
||||
if not key_found and not os.path.exists(private_key_path):
|
||||
print("Creating new key pair {}".format(key_name))
|
||||
public_key, private_key = generate_rsa_key_pair()
|
||||
|
||||
_create_project_ssh_key_pair(project, public_key, ssh_user)
|
||||
|
||||
with open(private_key_path, "w") as f:
|
||||
f.write(private_key)
|
||||
os.chmod(private_key_path, 0o600)
|
||||
|
||||
with open(public_key_path, "w") as f:
|
||||
f.write(public_key)
|
||||
|
||||
key_found = True
|
||||
|
||||
break
|
||||
|
||||
if key_found:
|
||||
break
|
||||
|
||||
assert key_found, "SSH keypair for user {} not found for {}".format(
|
||||
ssh_user, private_key_path)
|
||||
assert os.path.exists(private_key_path), (
|
||||
"Private key file {} not found for user {}"
|
||||
"".format(private_key_path, ssh_user))
|
||||
|
||||
print("Private key not specified in config, using {}"
|
||||
"".format(private_key_path))
|
||||
|
||||
config["auth"]["ssh_private_key"] = private_key_path
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_subnet(config):
|
||||
"""Pick a reasonable subnet if not specified by the config."""
|
||||
|
||||
subnets = _list_subnets(config)
|
||||
|
||||
if not subnets:
|
||||
raise NotImplementedError("Should be able to create subnet.")
|
||||
|
||||
# TODO: make sure that we have usable subnet. Maybe call
|
||||
# compute.subnetworks().listUsable? For some reason it didn't
|
||||
# work out-of-the-box
|
||||
default_subnet = subnets[0]
|
||||
|
||||
if "networkInterfaces" not in config["head_node"]:
|
||||
config["head_node"]["networkInterfaces"] = [{
|
||||
"subnetwork": default_subnet["selfLink"],
|
||||
"accessConfigs": [{
|
||||
"name": "External NAT",
|
||||
"type": "ONE_TO_ONE_NAT",
|
||||
}],
|
||||
}]
|
||||
|
||||
if "networkInterfaces" not in config["worker_nodes"]:
|
||||
config["worker_nodes"]["networkInterfaces"] = [{
|
||||
"subnetwork": default_subnet["selfLink"],
|
||||
"accessConfigs": [{
|
||||
"name": "External NAT",
|
||||
"type": "ONE_TO_ONE_NAT",
|
||||
}],
|
||||
}]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _list_subnets(config):
|
||||
response = compute.subnetworks().list(
|
||||
project=config["provider"]["project_id"],
|
||||
region=config["provider"]["region"]).execute()
|
||||
|
||||
return response["items"]
|
||||
|
||||
|
||||
def _get_subnet(config, subnet_id):
|
||||
subnet = compute.subnetworks().get(
|
||||
project=config["provider"]["project_id"],
|
||||
region=config["provider"]["region"],
|
||||
subnetwork=subnet_id,
|
||||
).execute()
|
||||
|
||||
return subnet
|
||||
|
||||
|
||||
def _get_project(project_id):
|
||||
try:
|
||||
project = crm.projects().get(projectId=project_id).execute()
|
||||
except errors.HttpError as e:
|
||||
if e.resp.status != 403:
|
||||
raise
|
||||
project = None
|
||||
|
||||
return project
|
||||
|
||||
|
||||
def _create_project(project_id):
|
||||
operation = crm.projects().create(body={
|
||||
"projectId": project_id,
|
||||
"name": project_id
|
||||
}).execute()
|
||||
|
||||
result = wait_for_crm_operation(operation)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _get_service_account(account, config):
|
||||
project_id = config["provider"]["project_id"]
|
||||
full_name = ("projects/{project_id}/serviceAccounts/{account}"
|
||||
"".format(project_id=project_id, account=account))
|
||||
try:
|
||||
service_account = iam.projects().serviceAccounts().get(
|
||||
name=full_name).execute()
|
||||
except errors.HttpError as e:
|
||||
if e.resp.status != 404:
|
||||
raise
|
||||
service_account = None
|
||||
|
||||
return service_account
|
||||
|
||||
|
||||
def _create_service_account(account_id, account_config, config):
|
||||
project_id = config["provider"]["project_id"]
|
||||
|
||||
service_account = iam.projects().serviceAccounts().create(
|
||||
name="projects/{project_id}".format(project_id=project_id),
|
||||
body={
|
||||
"accountId": account_id,
|
||||
"serviceAccount": account_config,
|
||||
}).execute()
|
||||
|
||||
return service_account
|
||||
|
||||
|
||||
def _add_iam_policy_binding(service_account, roles):
|
||||
"""Add new IAM roles for the service account."""
|
||||
project_id = service_account["projectId"]
|
||||
email = service_account["email"]
|
||||
member_id = "serviceAccount:" + email
|
||||
|
||||
policy = crm.projects().getIamPolicy(resource=project_id).execute()
|
||||
|
||||
for role in roles:
|
||||
role_exists = False
|
||||
for binding in policy["bindings"]:
|
||||
if binding["role"] == role:
|
||||
if member_id not in binding["members"]:
|
||||
binding["members"].append(member_id)
|
||||
role_exists = True
|
||||
|
||||
if not role_exists:
|
||||
policy["bindings"].append({
|
||||
"members": [member_id],
|
||||
"role": role,
|
||||
})
|
||||
|
||||
result = crm.projects().setIamPolicy(
|
||||
resource=project_id, body={
|
||||
"policy": policy,
|
||||
}).execute()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _create_project_ssh_key_pair(project, public_key, ssh_user):
|
||||
"""Inserts an ssh-key into project commonInstanceMetadata"""
|
||||
|
||||
key_parts = public_key.split(" ")
|
||||
|
||||
# Sanity checks to make sure that the generated key matches expectation
|
||||
assert len(key_parts) == 2, key_parts
|
||||
assert key_parts[0] == "ssh-rsa", key_parts
|
||||
|
||||
new_ssh_meta = "{ssh_user}:ssh-rsa {key_value} {ssh_user}".format(
|
||||
ssh_user=ssh_user, key_value=key_parts[1])
|
||||
|
||||
common_instance_metadata = project["commonInstanceMetadata"]
|
||||
items = common_instance_metadata.get("items", [])
|
||||
|
||||
ssh_keys_i = next(
|
||||
(i for i, item in enumerate(items) if item["key"] == "ssh-keys"), None)
|
||||
|
||||
if ssh_keys_i is None:
|
||||
items.append({"key": "ssh-keys", "value": new_ssh_meta})
|
||||
else:
|
||||
ssh_keys = items[ssh_keys_i]
|
||||
ssh_keys["value"] += "\n" + new_ssh_meta
|
||||
items[ssh_keys_i] = ssh_keys
|
||||
|
||||
common_instance_metadata["items"] = items
|
||||
|
||||
operation = compute.projects().setCommonInstanceMetadata(
|
||||
project=project["name"], body=common_instance_metadata).execute()
|
||||
|
||||
response = wait_for_compute_global_operation(project["name"], operation)
|
||||
|
||||
return response
|
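config.py above stores project-wide SSH keys in the "ssh-keys" entry of commonInstanceMetadata, one key per line in the "[USERNAME]:ssh-rsa [KEY_VALUE] [USERNAME]" format described in _configure_key_pair. A small sketch of what one generated entry looks like; the key generation mirrors generate_rsa_key_pair above and only assumes the cryptography package is installed, and the user name is hypothetical:

# Sketch: build one project-metadata "ssh-keys" entry in the format used above.
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

ssh_user = "ubuntu"  # matches the example configs; any user name works

key = rsa.generate_private_key(
    backend=default_backend(), public_exponent=65537, key_size=2048)
public_key = key.public_key().public_bytes(
    serialization.Encoding.OpenSSH,
    serialization.PublicFormat.OpenSSH).decode("utf-8")

# public_key looks like "ssh-rsa AAAAB3Nza..."; the metadata entry adds the user.
entry = "{ssh_user}:ssh-rsa {key_value} {ssh_user}".format(
    ssh_user=ssh_user, key_value=public_key.split(" ")[1])
print(entry)  # e.g. ubuntu:ssh-rsa AAAAB3Nza... ubuntu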
python/ray/autoscaler/gcp/example-full.yaml (new file, 161 lines)
|
@ -0,0 +1,161 @@
|
|||
# A unique identifier for the head node and workers of this cluster.
|
||||
cluster_name: default
|
||||
|
||||
# The minimum number of worker nodes to launch in addition to the head
|
||||
# node. This number should be >= 0.
|
||||
min_workers: 0
|
||||
|
||||
# The maximum number of worker nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers.
|
||||
max_workers: 2
|
||||
|
||||
# This executes all commands on all nodes in the docker container,
|
||||
# and opens all the necessary ports to support the Ray cluster.
|
||||
# Empty string means disabled.
|
||||
docker:
|
||||
image: "" # e.g., tensorflow/tensorflow:1.5.0-py3
|
||||
container_name: "" # e.g. ray_docker
|
||||
|
||||
|
||||
# The autoscaler will scale up the cluster to this target fraction of resource
|
||||
# usage. For example, if a cluster of 10 nodes is 100% busy and
|
||||
# target_utilization is 0.8, it would resize the cluster to 13. This fraction
|
||||
# can be decreased to increase the aggressiveness of upscaling.
|
||||
# This value must be less than 1.0 for scaling to happen.
|
||||
target_utilization_fraction: 0.8
|
||||
|
||||
# If a node is idle for this many minutes, it will be removed.
|
||||
idle_timeout_minutes: 5
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
provider:
|
||||
type: gcp
|
||||
region: us-west1
|
||||
availability_zone: us-west1-a
|
||||
project_id: null # Globally unique project id
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
auth:
|
||||
ssh_user: ubuntu
|
||||
# By default Ray creates a new private keypair, but you can also use your own.
|
||||
# If you do so, make sure to also set "KeyName" in the head and worker node
|
||||
# configurations below. This requires that you have added the key into the
|
||||
# project wide meta-data.
|
||||
# ssh_private_key: /path/to/your/key.pem
|
||||
|
||||
# Provider-specific config for the head node, e.g. instance type. By default
|
||||
# Ray will auto-configure unspecified fields such as subnets and ssh-keys.
|
||||
# For more documentation on available fields, see:
|
||||
# https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
|
||||
head_node:
|
||||
machineType: n1-standard-2
|
||||
disks:
|
||||
- boot: true
|
||||
autoDelete: true
|
||||
type: PERSISTENT
|
||||
initializeParams:
|
||||
diskSizeGb: 50
|
||||
# See https://cloud.google.com/compute/docs/images for more images
|
||||
sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-1604-lts # Ubuntu
|
||||
|
||||
# Additional options can be found in the compute docs at
|
||||
# https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
|
||||
|
||||
worker_nodes:
|
||||
machineType: n1-standard-2
|
||||
disks:
|
||||
- boot: true
|
||||
autoDelete: true
|
||||
type: PERSISTENT
|
||||
initializeParams:
|
||||
diskSizeGb: 50
|
||||
# See https://cloud.google.com/compute/docs/images for more images
|
||||
sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-1604-lts # Ubuntu
|
||||
# Run workers on preemptible instances by default.
|
||||
# Comment this out to use on-demand.
|
||||
scheduling:
|
||||
- preemptible: true
|
||||
|
||||
# Additional options can be found in the compute docs at
|
||||
# https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
|
||||
|
||||
# Files or directories to copy to the head and worker nodes. The format is a
|
||||
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
|
||||
file_mounts: {
|
||||
# "/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
}
|
||||
|
||||
# List of shell commands to run to set up nodes.
|
||||
setup_commands:
|
||||
# Consider uncommenting these if you also want to run apt-get commands during setup
|
||||
# - sudo pkill -9 apt-get || true
|
||||
# - sudo pkill -9 dpkg || true
|
||||
# - sudo dpkg --configure -a
|
||||
|
||||
# Install basics.
|
||||
- sudo apt-get update
|
||||
- >-
|
||||
sudo apt-get install -y
|
||||
cmake
|
||||
pkg-config
|
||||
build-essential
|
||||
autoconf
|
||||
curl
|
||||
libtool
|
||||
unzip
|
||||
flex
|
||||
bison
|
||||
python
|
||||
# Install Anaconda.
|
||||
- >-
|
||||
wget https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh -O ~/anaconda3.sh
|
||||
|| true
|
||||
- bash ~/anaconda3.sh -b -p ~/anaconda3 || true
|
||||
- rm ~/anaconda3.sh
|
||||
- echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc
|
||||
|
||||
# Build Ray.
|
||||
# Note: if you're developing Ray, you probably want to create a boot-disk
|
||||
# that has your Ray repo pre-cloned. Then, you can replace the pip installs
|
||||
# below with a git checkout <your_sha> (and possibly a recompile).
|
||||
- echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc
|
||||
- >-
|
||||
pip install
|
||||
google-api-python-client==1.6.7
|
||||
cython==0.27.3
|
||||
# - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.4.0-cp27-cp27mu-manylinux1_x86_64.whl
|
||||
# - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.4.0-cp35-cp35m-manylinux1_x86_64.whl
|
||||
# - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.4.0-cp36-cp36m-manylinux1_x86_64.whl
|
||||
- >-
|
||||
cd ~
|
||||
&& git clone https://github.com/ray-project/ray || true
|
||||
- >-
|
||||
cd ~/ray/python
|
||||
&& pip install -e . --verbose
|
||||
|
||||
# Custom commands that will be run on the head node after common setup.
|
||||
head_setup_commands: []
|
||||
|
||||
# Custom commands that will be run on worker nodes after common setup.
|
||||
worker_setup_commands: []
|
||||
|
||||
# Command to start ray on the head node. You don't need to change this.
|
||||
head_start_ray_commands:
|
||||
- ray stop
|
||||
- >-
|
||||
ulimit -n 65536;
|
||||
ray start
|
||||
--head
|
||||
--redis-port=6379
|
||||
--object-manager-port=8076
|
||||
--autoscaling-config=~/ray_bootstrap_config.yaml
|
||||
|
||||
# Command to start ray on worker nodes. You don't need to change this.
|
||||
worker_start_ray_commands:
|
||||
- ray stop
|
||||
- >-
|
||||
ulimit -n 65536;
|
||||
ray start
|
||||
--redis-address=$RAY_HEAD_IP:6379
|
||||
--object-manager-port=8076
|
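The full example above is the config that the commands in this PR consume. A hedged sketch of loading it by hand and letting bootstrap_gcp from config.py fill in the project, service account, ssh key, and subnet defaults; this talks to your GCP project and assumes credentials are set up, and the project id used here is hypothetical:

# Sketch: run the GCP bootstrap steps on the example config manually.
import yaml
from ray.autoscaler.gcp.config import bootstrap_gcp

with open("python/ray/autoscaler/gcp/example-full.yaml") as f:
    config = yaml.load(f)

config["provider"]["project_id"] = "my-ray-project"  # must be set; hypothetical id
config = bootstrap_gcp(config)

# bootstrap_gcp fills in serviceAccounts, ssh_private_key, and networkInterfaces.
print(config["auth"]["ssh_private_key"])
print(config["head_node"]["networkInterfaces"][0]["subnetwork"])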
python/ray/autoscaler/gcp/example-minimal.yaml (new file, 17 lines)
|
@ -0,0 +1,17 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: minimal

# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers, which defaults to 0.
max_workers: 1

# Cloud-provider specific configuration.
provider:
    type: gcp
    region: us-west1
    availability_zone: us-west1-a
    project_id: null # Globally unique project id

# How Ray will authenticate with newly launched nodes.
auth:
    ssh_user: ubuntu
python/ray/autoscaler/gcp/node_provider.py (new file, 213 lines)
|
@ -0,0 +1,213 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from uuid import uuid4
|
||||
import time
|
||||
|
||||
from googleapiclient import discovery
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
||||
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL
|
||||
|
||||
INSTANCE_NAME_MAX_LEN = 64
|
||||
INSTANCE_NAME_UUID_LEN = 8
|
||||
|
||||
|
||||
def wait_for_compute_zone_operation(compute, project_name, operation, zone):
|
||||
"""Poll for compute zone operation until finished."""
|
||||
print("Waiting for operation {} to finish...".format(operation["name"]))
|
||||
|
||||
for _ in range(MAX_POLLS):
|
||||
result = compute.zoneOperations().get(
|
||||
project=project_name, operation=operation["name"],
|
||||
zone=zone).execute()
|
||||
if "error" in result:
|
||||
raise Exception(result["error"])
|
||||
|
||||
if result["status"] == "DONE":
|
||||
print("Done.")
|
||||
break
|
||||
|
||||
time.sleep(POLL_INTERVAL)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class GCPNodeProvider(NodeProvider):
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
|
||||
self.compute = discovery.build("compute", "v1")
|
||||
|
||||
# Cache of node objects from the last nodes() call. This avoids
|
||||
# excessive DescribeInstances requests.
|
||||
self.cached_nodes = {}
|
||||
|
||||
# Cache of ip lookups. We assume IPs never change once assigned.
|
||||
self.internal_ip_cache = {}
|
||||
self.external_ip_cache = {}
|
||||
|
||||
def nodes(self, tag_filters):
|
||||
if tag_filters:
|
||||
label_filter_expr = "(" + " AND ".join([
|
||||
"(labels.{key} = {value})".format(key=key, value=value)
|
||||
for key, value in tag_filters.items()
|
||||
]) + ")"
|
||||
else:
|
||||
label_filter_expr = ""
|
||||
|
||||
instance_state_filter_expr = "(" + " OR ".join([
|
||||
"(status = {status})".format(status=status)
|
||||
for status in {"PROVISIONING", "STAGING", "RUNNING"}
|
||||
]) + ")"
|
||||
|
||||
cluster_name_filter_expr = ("(labels.{key} = {value})"
|
||||
"".format(
|
||||
key=TAG_RAY_CLUSTER_NAME,
|
||||
value=self.cluster_name))
|
||||
|
||||
not_empty_filters = [
|
||||
f for f in [
|
||||
label_filter_expr,
|
||||
instance_state_filter_expr,
|
||||
cluster_name_filter_expr,
|
||||
] if f
|
||||
]
|
||||
|
||||
filter_expr = " AND ".join(not_empty_filters)
|
||||
|
||||
response = self.compute.instances().list(
|
||||
project=self.provider_config["project_id"],
|
||||
zone=self.provider_config["availability_zone"],
|
||||
filter=filter_expr,
|
||||
).execute()
|
||||
|
||||
instances = response.get("items", [])
|
||||
# Note: All the operations use "name" as the unique instance identifier
|
||||
self.cached_nodes = {i["name"]: i for i in instances}
|
||||
|
||||
return [i["name"] for i in instances]
|
||||
|
||||
def is_running(self, node_id):
|
||||
node = self._node(node_id)
|
||||
return node["status"] == "RUNNING"
|
||||
|
||||
def is_terminated(self, node_id):
|
||||
node = self._node(node_id)
|
||||
return node["status"] not in {"PROVISIONING", "STAGING", "RUNNING"}
|
||||
|
||||
def node_tags(self, node_id):
|
||||
node = self._node(node_id)
|
||||
labels = node.get("labels", {})
|
||||
return labels
|
||||
|
||||
def set_node_tags(self, node_id, tags):
|
||||
labels = tags
|
||||
project_id = self.provider_config["project_id"]
|
||||
availability_zone = self.provider_config["availability_zone"]
|
||||
|
||||
node = self._node(node_id)
|
||||
operation = self.compute.instances().setLabels(
|
||||
project=project_id,
|
||||
zone=availability_zone,
|
||||
instance=node_id,
|
||||
body={
|
||||
"labels": dict(node["labels"], **labels),
|
||||
"labelFingerprint": node["labelFingerprint"]
|
||||
}).execute()
|
||||
|
||||
result = wait_for_compute_zone_operation(self.compute, project_id,
|
||||
operation, availability_zone)
|
||||
|
||||
return result
|
||||
|
||||
def external_ip(self, node_id):
|
||||
if node_id in self.external_ip_cache:
|
||||
return self.external_ip_cache[node_id]
|
||||
node = self._node(node_id)
|
||||
# TODO: Is there a better and more reliable way to do this?
|
||||
ip = (node.get("networkInterfaces", [{}])[0].get(
|
||||
"accessConfigs", [{}])[0].get("natIP", None))
|
||||
if ip:
|
||||
self.external_ip_cache[node_id] = ip
|
||||
return ip
|
||||
|
||||
def internal_ip(self, node_id):
|
||||
if node_id in self.internal_ip_cache:
|
||||
return self.internal_ip_cache[node_id]
|
||||
node = self._node(node_id)
|
||||
ip = node.get("networkInterfaces", [{}])[0].get("networkIP")
|
||||
if ip:
|
||||
self.internal_ip_cache[node_id] = ip
|
||||
return ip
|
||||
|
||||
def create_node(self, base_config, tags, count):
|
||||
labels = tags # gcp uses "labels" instead of aws "tags"
|
||||
project_id = self.provider_config["project_id"]
|
||||
availability_zone = self.provider_config["availability_zone"]
|
||||
|
||||
config = base_config.copy()
|
||||
|
||||
name_label = labels[TAG_RAY_NODE_NAME]
|
||||
assert (len(name_label) <=
|
||||
(INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1)), (
|
||||
name_label, len(name_label))
|
||||
|
||||
config.update({
|
||||
"machineType": ("zones/{zone}/machineTypes/{machine_type}"
|
||||
"".format(
|
||||
zone=availability_zone,
|
||||
machine_type=base_config["machineType"])),
|
||||
"labels": dict(
|
||||
config.get("labels", {}), **labels,
|
||||
**{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
|
||||
})
|
||||
|
||||
operations = [
|
||||
self.compute.instances().insert(
|
||||
project=project_id,
|
||||
zone=availability_zone,
|
||||
body=dict(
|
||||
config, **{
|
||||
"name": ("{name_label}-{uuid}".format(
|
||||
name_label=name_label,
|
||||
uuid=uuid4().hex[:INSTANCE_NAME_UUID_LEN]))
|
||||
})).execute() for i in range(count)
|
||||
]
|
||||
|
||||
results = [
|
||||
wait_for_compute_zone_operation(self.compute, project_id,
|
||||
operation, availability_zone)
|
||||
for operation in operations
|
||||
]
|
||||
|
||||
return results
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
project_id = self.provider_config["project_id"]
|
||||
availability_zone = self.provider_config["availability_zone"]
|
||||
|
||||
operation = self.compute.instances().delete(
|
||||
project=project_id,
|
||||
zone=availability_zone,
|
||||
instance=node_id,
|
||||
).execute()
|
||||
|
||||
result = wait_for_compute_zone_operation(self.compute, project_id,
|
||||
operation, availability_zone)
|
||||
|
||||
return result
|
||||
|
||||
def _node(self, node_id):
|
||||
if node_id in self.cached_nodes:
|
||||
return self.cached_nodes[node_id]
|
||||
|
||||
instance = self.compute.instances().get(
|
||||
project=self.provider_config["project_id"],
|
||||
zone=self.provider_config["availability_zone"],
|
||||
instance=node_id,
|
||||
).execute()
|
||||
|
||||
return instance
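GCPNodeProvider.nodes() above combines three expressions (labels, instance status, cluster name) into a single GCE list filter. A small sketch of the string it ends up sending, reproducing the same construction for a cluster named "default" and the worker tag from tags.py (the status clauses may appear in a different order, since the real code iterates over a set):

# Sketch: reproduce the filter expression built in GCPNodeProvider.nodes().
tag_filters = {"ray-node-type": "worker"}
cluster_name = "default"

label_filter_expr = "(" + " AND ".join(
    "(labels.{} = {})".format(k, v) for k, v in tag_filters.items()) + ")"
status_filter_expr = "(" + " OR ".join(
    "(status = {})".format(s)
    for s in ["PROVISIONING", "STAGING", "RUNNING"]) + ")"
cluster_filter_expr = "(labels.ray-cluster-name = {})".format(cluster_name)

filter_expr = " AND ".join(
    [label_filter_expr, status_filter_expr, cluster_filter_expr])
print(filter_expr)
# ((labels.ray-node-type = worker)) AND ((status = PROVISIONING) OR
# (status = STAGING) OR (status = RUNNING)) AND (labels.ray-cluster-name = default)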
|
|
@ -13,11 +13,22 @@ def import_aws():
    return bootstrap_aws, AWSNodeProvider


def load_aws_config():
def import_gcp():
    from ray.autoscaler.gcp.config import bootstrap_gcp
    from ray.autoscaler.gcp.node_provider import GCPNodeProvider
    return bootstrap_gcp, GCPNodeProvider


def load_aws_example_config():
    import ray.autoscaler.aws as ray_aws
    return os.path.join(os.path.dirname(ray_aws.__file__), "example-full.yaml")


def load_gcp_example_config():
    import ray.autoscaler.gcp as ray_gcp
    return os.path.join(os.path.dirname(ray_gcp.__file__), "example-full.yaml")


def import_external():
    """Mock a normal provider importer."""

@ -29,8 +40,8 @@ def import_external():

NODE_PROVIDERS = {
    "aws": import_aws,
    "gce": None, # TODO: support more node providers
    "azure": None,
    "gcp": import_gcp,
    "azure": None, # TODO: support more node providers
    "kubernetes": None,
    "docker": None,
    "local_cluster": None,

@ -38,9 +49,9 @@ NODE_PROVIDERS = {
}

DEFAULT_CONFIGS = {
    "aws": load_aws_config,
    "gce": None, # TODO: support more node providers
    "azure": None,
    "aws": load_aws_example_config,
    "gcp": load_gcp_example_config,
    "azure": None, # TODO: support more node providers
    "kubernetes": None,
    "docker": None,
    "local_cluster": None,

@ -115,7 +126,7 @@ class NodeProvider(object):
        nodes() must be called again to refresh results.

        Examples:
            >>> provider.nodes({TAG_RAY_NODE_TYPE: "Worker"})
            >>> provider.nodes({TAG_RAY_NODE_TYPE: "worker"})
            ["node-1", "node-2"]
        """
        raise NotImplementedError

@ -1,22 +1,23 @@
"""The Ray autoscaler uses tags/labels to associate metadata with instances."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""The Ray autoscaler uses tags to associate metadata with instances."""

# Tag for the name of the node
TAG_NAME = "Name"

# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = "ray:ClusterName"
TAG_RAY_NODE_NAME = "ray-node-name"

# Tag for the type of node (e.g. Head, Worker)
TAG_RAY_NODE_TYPE = "ray:NodeType"
TAG_RAY_NODE_TYPE = "ray-node-type"

# Tag that reports the current state of the node (e.g. Updating, Up-to-date)
TAG_RAY_NODE_STATUS = "ray:NodeStatus"
TAG_RAY_NODE_STATUS = "ray-node-status"

# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = "ray-cluster-name"

# Hash of the node launch config, used to identify out-of-date nodes
TAG_RAY_LAUNCH_CONFIG = "ray:LaunchConfig"
TAG_RAY_LAUNCH_CONFIG = "ray-launch-config"

# Hash of the node runtime config, used to determine if updates are needed
TAG_RAY_RUNTIME_CONFIG = "ray:RuntimeConfig"
TAG_RAY_RUNTIME_CONFIG = "ray-runtime-config"
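The tags.py hunk above renames every tag from the "ray:FooBar" style to lowercase dash-separated names because GCE labels only allow a restricted character set (the "[0-9a-z_-]" spec mentioned in the commit message). A small validation sketch of that constraint; this is an illustration of the character set, not the full GCE validation rule:

# Sketch: character set referenced in the commit message for tag/label names.
import re

LABEL_RE = re.compile(r"^[a-z0-9_-]+$")

assert LABEL_RE.match("ray-node-type")       # new tag name: accepted
assert not LABEL_RE.match("ray:NodeType")    # old tag name: ':' and uppercase rejected
assert LABEL_RE.match("up-to-date")          # new status value
assert not LABEL_RE.match("Up-to-date")      # old status value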
|
|
@ -2,7 +2,10 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import pipes
|
||||
try: # py3
|
||||
from shlex import quote
|
||||
except ImportError: # py2
|
||||
from pipes import quote
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -17,6 +20,7 @@ from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
|
|||
|
||||
# How long to wait for a node to start, in seconds
|
||||
NODE_START_WAIT_S = 300
|
||||
SSH_CHECK_INTERVAL = 5
|
||||
|
||||
|
||||
def pretty_cmd(cmd_str):
|
||||
|
@ -43,7 +47,10 @@ class NodeUpdater(object):
|
|||
self.ssh_user = auth_config["ssh_user"]
|
||||
self.ssh_ip = self.provider.external_ip(node_id)
|
||||
self.node_id = node_id
|
||||
self.file_mounts = file_mounts
|
||||
self.file_mounts = {
|
||||
remote: os.path.expanduser(local)
|
||||
for remote, local in file_mounts.items()
|
||||
}
|
||||
self.setup_cmds = setup_cmds
|
||||
self.runtime_hash = runtime_hash
|
||||
if redirect_output:
|
||||
|
@ -73,7 +80,7 @@ class NodeUpdater(object):
|
|||
"See {} for remote logs.".format(error_str, self.output_name),
|
||||
file=self.stdout)
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "UpdateFailed"})
|
||||
{TAG_RAY_NODE_STATUS: "update-failed"})
|
||||
if self.logfile is not None:
|
||||
print("----- BEGIN REMOTE LOGS -----\n" +
|
||||
open(self.logfile.name).read() +
|
||||
|
@ -81,7 +88,7 @@ class NodeUpdater(object):
|
|||
raise e
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {
|
||||
TAG_RAY_NODE_STATUS: "Up-to-date",
|
||||
TAG_RAY_NODE_STATUS: "up-to-date",
|
||||
TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
|
||||
})
|
||||
print(
|
||||
|
@ -91,7 +98,7 @@ class NodeUpdater(object):
|
|||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "WaitingForSSH"})
|
||||
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
|
||||
# Wait for external IP
|
||||
|
@ -130,20 +137,20 @@ class NodeUpdater(object):
|
|||
print(
|
||||
"NodeUpdater: SSH not up, retrying: {}".format(retry_str),
|
||||
file=self.stdout)
|
||||
time.sleep(5)
|
||||
time.sleep(SSH_CHECK_INTERVAL)
|
||||
else:
|
||||
break
|
||||
assert ssh_ok, "Unable to SSH to node"
|
||||
|
||||
# Rsync file mounts
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "SyncingFiles"})
|
||||
{TAG_RAY_NODE_STATUS: "syncing-files"})
|
||||
for remote_path, local_path in self.file_mounts.items():
|
||||
print(
|
||||
"NodeUpdater: Syncing {} to {}...".format(
|
||||
local_path, remote_path),
|
||||
file=self.stdout)
|
||||
assert os.path.exists(local_path)
|
||||
assert os.path.exists(local_path), local_path
|
||||
if os.path.isdir(local_path):
|
||||
if not local_path.endswith("/"):
|
||||
local_path += "/"
|
||||
|
@ -162,7 +169,7 @@ class NodeUpdater(object):
|
|||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "SettingUp"})
|
||||
{TAG_RAY_NODE_STATUS: "setting-up"})
|
||||
for cmd in self.setup_cmds:
|
||||
self.ssh_cmd(cmd, verbose=True)
|
||||
|
||||
|
@ -172,14 +179,13 @@ class NodeUpdater(object):
|
|||
"NodeUpdater: running {} on {}...".format(
|
||||
pretty_cmd(cmd), self.ssh_ip),
|
||||
file=self.stdout)
|
||||
force_interactive = "set -i && source ~/.bashrc && "
|
||||
force_interactive = "set -i || true && source ~/.bashrc && "
|
||||
self.process_runner.check_call(
|
||||
[
|
||||
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", self.ssh_private_key, "{}@{}".format(
|
||||
self.ssh_user, self.ssh_ip), "bash --login -c {}".format(
|
||||
pipes.quote(force_interactive + cmd))
|
||||
"-o", "StrictHostKeyChecking=no", "-i", self.ssh_private_key,
|
||||
"{}@{}".format(self.ssh_user, self.ssh_ip),
|
||||
"bash --login -c {}".format(quote(force_interactive + cmd))
|
||||
],
|
||||
stdout=redirect or self.stdout,
|
||||
stderr=redirect or self.stderr)
|
||||
|
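The updater hunk above switches from pipes.quote to a shlex/pipes-compatible quote() and prefixes each remote command with "set -i || true && source ~/.bashrc && " so the fallback works on shells where "set -i" fails. A sketch of the ssh argument list that results for one setup command; the host, user, key path, and timeout here are hypothetical stand-ins for the NodeUpdater fields:

# Sketch: assemble the ssh invocation the way NodeUpdater.ssh_cmd does above.
try:  # py3
    from shlex import quote
except ImportError:  # py2
    from pipes import quote

force_interactive = "set -i || true && source ~/.bashrc && "
cmd = "ray stop"                            # example setup command
ssh_user, ssh_ip = "ubuntu", "10.0.0.5"     # hypothetical values
ssh_private_key = "~/.ssh/ray-autoscaler_gcp_us-west1_my-project_ubuntu.pem"

ssh_args = [
    "ssh", "-o", "ConnectTimeout=120s",
    "-o", "StrictHostKeyChecking=no", "-i", ssh_private_key,
    "{}@{}".format(ssh_user, ssh_ip),
    "bash --login -c {}".format(quote(force_interactive + cmd)),
]
print(" ".join(ssh_args))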
|
|
@ -4,10 +4,14 @@ from __future__ import print_function
|
|||
|
||||
import distutils.spawn
|
||||
import os
|
||||
import pipes
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
try: # py3
|
||||
from shlex import quote
|
||||
except ImportError: # py2
|
||||
from pipes import quote
|
||||
|
||||
import ray
|
||||
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.error import TuneError
|
||||
|
@ -16,14 +20,29 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
|
|||
# Map from (logdir, remote_dir) -> syncer
|
||||
_syncers = {}
|
||||
|
||||
S3_PREFIX = "s3://"
|
||||
GCS_PREFIX = "gs://"
|
||||
ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GCS_PREFIX)
|
||||
|
||||
|
||||
def get_syncer(local_dir, remote_dir=None):
|
||||
if remote_dir:
|
||||
if not remote_dir.startswith("s3://"):
|
||||
raise TuneError("Upload uri must start with s3://")
|
||||
if not any(
|
||||
remote_dir.startswith(prefix)
|
||||
for prefix in ALLOWED_REMOTE_PREFIXES):
|
||||
raise TuneError("Upload uri must start with one of: {}"
|
||||
"".format(ALLOWED_REMOTE_PREFIXES))
|
||||
|
||||
if not distutils.spawn.find_executable("aws"):
|
||||
raise TuneError("Upload uri requires awscli tool to be installed")
|
||||
if (remote_dir.startswith(S3_PREFIX)
|
||||
and not distutils.spawn.find_executable("aws")):
|
||||
raise TuneError(
|
||||
"Upload uri starting with '{}' requires awscli tool"
|
||||
" to be installed".format(S3_PREFIX))
|
||||
elif (remote_dir.startswith(GCS_PREFIX)
|
||||
and not distutils.spawn.find_executable("gsutil")):
|
||||
raise TuneError(
|
||||
"Upload uri starting with '{}' requires gsutil tool"
|
||||
" to be installed".format(GCS_PREFIX))
|
||||
|
||||
if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
|
||||
rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
|
||||
|
@ -85,14 +104,18 @@ class _LogSyncer(object):
|
|||
print("Error: log sync requires rsync to be installed.")
|
||||
return
|
||||
worker_to_local_sync_cmd = ((
|
||||
"""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
|
||||
"""rsync -avz -e "ssh -i {} -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
|
||||
ssh_key, ssh_user, self.worker_ip,
|
||||
pipes.quote(self.local_dir), pipes.quote(self.local_dir)))
|
||||
quote(ssh_key), ssh_user, self.worker_ip,
|
||||
quote(self.local_dir), quote(self.local_dir)))
|
||||
|
||||
if self.remote_dir:
|
||||
local_to_remote_sync_cmd = ("aws s3 sync '{}' '{}'".format(
|
||||
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
|
||||
if self.remote_dir.startswith(S3_PREFIX):
|
||||
local_to_remote_sync_cmd = ("aws s3 sync {} {}".format(
|
||||
quote(self.local_dir), quote(self.remote_dir)))
|
||||
elif self.remote_dir.startswith(GCS_PREFIX):
|
||||
local_to_remote_sync_cmd = ("gsutil rsync -r {} {}".format(
|
||||
quote(self.local_dir), quote(self.remote_dir)))
|
||||
else:
|
||||
local_to_remote_sync_cmd = None
|
||||
|
||||
|
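The log_sync.py hunk above adds gs:// as an allowed upload prefix next to s3:// and builds a gsutil rsync command for it. A quick illustration of the command string produced for a hypothetical bucket, following the same quoting as the diff:

# Sketch: the gsutil command _LogSyncer would build for a gs:// remote_dir.
try:  # py3
    from shlex import quote
except ImportError:  # py2
    from pipes import quote

GCS_PREFIX = "gs://"
local_dir = "/home/ubuntu/ray_results/my_experiment"    # hypothetical paths
remote_dir = "gs://my-bucket/ray_results/my_experiment"

if remote_dir.startswith(GCS_PREFIX):
    sync_cmd = "gsutil rsync -r {} {}".format(quote(local_dir), quote(remote_dir))
    print(sync_cmd)
    # gsutil rsync -r /home/ubuntu/ray_results/my_experiment gs://my-bucket/ray_results/my_experiment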
|
|
@ -40,7 +40,10 @@ ray_ui_files = [
    "ray/core/src/catapult_files/trace_viewer_full.html"
]

ray_autoscaler_files = ["ray/autoscaler/aws/example-full.yaml"]
ray_autoscaler_files = [
    "ray/autoscaler/aws/example-full.yaml",
    "ray/autoscaler/gcp/example-full.yaml",
]

if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on":
    ray_files += [
|
|
|
@ -251,7 +251,7 @@ class AutoscalingTest(unittest.TestCase):
|
|||
config["max_workers"] = 5
|
||||
config_path = self.write_config(config)
|
||||
self.provider = MockProvider()
|
||||
self.provider.create_node({}, {TAG_RAY_NODE_TYPE: "Worker"}, 10)
|
||||
self.provider.create_node({}, {TAG_RAY_NODE_TYPE: "worker"}, 10)
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path, LoadMetrics(), max_failures=0, update_interval_s=0)
|
||||
self.assertEqual(len(self.provider.nodes({})), 10)
|
||||
|
@ -398,12 +398,12 @@ class AutoscalingTest(unittest.TestCase):
|
|||
node.state = "running"
|
||||
assert len(
|
||||
self.provider.nodes({
|
||||
TAG_RAY_NODE_STATUS: "Uninitialized"
|
||||
TAG_RAY_NODE_STATUS: "uninitialized"
|
||||
})) == 2
|
||||
autoscaler.update()
|
||||
self.waitFor(
|
||||
lambda: len(self.provider.nodes(
|
||||
{TAG_RAY_NODE_STATUS: "Up-to-date"})) == 2)
|
||||
{TAG_RAY_NODE_STATUS: "up-to-date"})) == 2)
|
||||
|
||||
def testReportsConfigFailures(self):
|
||||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
|
@ -424,12 +424,12 @@ class AutoscalingTest(unittest.TestCase):
|
|||
node.state = "running"
|
||||
assert len(
|
||||
self.provider.nodes({
|
||||
TAG_RAY_NODE_STATUS: "Uninitialized"
|
||||
TAG_RAY_NODE_STATUS: "uninitialized"
|
||||
})) == 2
|
||||
autoscaler.update()
|
||||
self.waitFor(
|
||||
lambda: len(self.provider.nodes(
|
||||
{TAG_RAY_NODE_STATUS: "UpdateFailed"})) == 2)
|
||||
{TAG_RAY_NODE_STATUS: "update-failed"})) == 2)
|
||||
|
||||
def testConfiguresOutdatedNodes(self):
|
||||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
|
@ -451,7 +451,7 @@ class AutoscalingTest(unittest.TestCase):
|
|||
autoscaler.update()
|
||||
self.waitFor(
|
||||
lambda: len(self.provider.nodes(
|
||||
{TAG_RAY_NODE_STATUS: "Up-to-date"})) == 2)
|
||||
{TAG_RAY_NODE_STATUS: "up-to-date"})) == 2)
|
||||
runner.calls = []
|
||||
new_config = SMALL_CLUSTER.copy()
|
||||
new_config["worker_setup_commands"] = ["cmdX", "cmdY"]
|
||||
|
@ -520,7 +520,7 @@ class AutoscalingTest(unittest.TestCase):
|
|||
autoscaler.update()
|
||||
self.waitFor(
|
||||
lambda: len(self.provider.nodes(
|
||||
{TAG_RAY_NODE_STATUS: "Up-to-date"})) == 2)
|
||||
{TAG_RAY_NODE_STATUS: "up-to-date"})) == 2)
|
||||
|
||||
# Mark a node as unhealthy
|
||||
lm.last_heartbeat_time_by_ip["172.0.0.0"] = 0
|
||||
|
|