mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[Autoscaler] Command Line Interface improvements (#9322)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
parent
456e012029
commit
908c0c630a
11 changed files with 1475 additions and 332 deletions
|
@ -13,7 +13,11 @@ import botocore
|
|||
|
||||
from ray.ray_constants import BOTO_MAX_RETRIES
|
||||
from ray.autoscaler.tags import NODE_TYPE_WORKER, NODE_TYPE_HEAD
|
||||
from ray.autoscaler.aws.utils import LazyDefaultDict
|
||||
from ray.autoscaler.aws.utils import LazyDefaultDict, handle_boto_error
|
||||
from ray.autoscaler.node_provider import PROVIDER_PRETTY_NAMES
|
||||
|
||||
from ray.autoscaler.cli_logger import cli_logger
|
||||
import colorful as cf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -45,7 +49,8 @@ DEFAULT_AMI = {
|
|||
"sa-east-1": "ami-0da2c49fe75e7e5ed", # SA (Sao Paulo)
|
||||
}
|
||||
|
||||
|
||||
# todo: cli_logger should handle this assert properly
|
||||
# this should probably also happens somewhere else
|
||||
assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
|
||||
"Boto3 version >= 1.4.8 required, try `pip install -U boto3`"
|
||||
|
||||
|
@ -70,6 +75,112 @@ def key_pair(i, region, key_name):
|
|||
# Suppress excessive connection dropped logs from boto
|
||||
logging.getLogger("botocore").setLevel(logging.WARNING)
|
||||
|
||||
# Module-level registry of "where did this config value come from"
# annotations (e.g. keypair_src="default"), consumed by log_to_cli().
_log_info = {}


def reload_log_state(override_log_info):
    """Merge a previously captured log-state dict into the registry."""
    _log_info.update(override_log_info)


def get_log_state():
    """Return a shallow copy of the current log-state registry."""
    return _log_info.copy()
|
||||
|
||||
|
||||
def _set_config_info(**kwargs):
    """Record configuration artifacts useful for logging."""

    # todo: this is technically fragile iff we ever use multiple configs

    _log_info.update(kwargs)
|
||||
|
||||
|
||||
def _arn_to_name(arn):
|
||||
return arn.split(":")[-1].split("/")[-1]
|
||||
|
||||
|
||||
def log_to_cli(config):
    """Pretty-print the resolved AWS node configuration via cli_logger.

    Reads the source annotations recorded by _set_config_info() to tag
    each printed value with where it came from (user config vs. the
    autoscaler-chosen default / DLAMI).
    """
    provider_name = PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(provider_name is not None,
                        "Could not find a pretty name for the AWS provider.")

    with cli_logger.group("{} config", provider_name):

        def same_everywhere(key):
            # True when head and worker nodes share the same value for key.
            return config["head_node"][key] == config["worker_nodes"][key]

        def print_info(resource_string,
                       key,
                       head_src_key,
                       workers_src_key,
                       allowed_tags=None,
                       list_value=False):
            # Bugfix: the original used a mutable default argument
            # (allowed_tags=["default"]); use a None sentinel instead.
            if allowed_tags is None:
                allowed_tags = ("default", )

            head_tags = {}
            workers_tags = {}

            if _log_info[head_src_key] in allowed_tags:
                head_tags[_log_info[head_src_key]] = True
            if _log_info[workers_src_key] in allowed_tags:
                workers_tags[_log_info[workers_src_key]] = True

            head_value_str = config["head_node"][key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if same_everywhere(key):
                cli_logger.labeled_value(  # todo: handle plural vs singular?
                    resource_string + " (head & workers)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
            else:
                workers_value_str = config["worker_nodes"][key]
                if list_value:
                    workers_value_str = cli_logger.render_list(
                        workers_value_str)

                cli_logger.labeled_value(
                    resource_string + " (head)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
                cli_logger.labeled_value(
                    resource_string + " (workers)",
                    "{}",
                    workers_value_str,
                    _tags=workers_tags)

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        cli_logger.labeled_value(
            "IAM Profile",
            "{}",
            _arn_to_name(config["head_node"]["IamInstanceProfile"]["Arn"]),
            _tags=tags)

        print_info("EC2 Key pair", "KeyName", "keypair_src", "keypair_src")
        print_info(
            "VPC Subnets",
            "SubnetIds",
            "head_subnet_src",
            "workers_subnet_src",
            list_value=True)
        print_info(
            "EC2 Security groups",
            "SecurityGroupIds",
            "head_security_group_src",
            "workers_security_group_src",
            list_value=True)
        print_info(
            "EC2 AMI",
            "ImageId",
            "head_ami_src",
            "workers_ami_src",
            allowed_tags=["dlami"])

    cli_logger.newline()
|
||||
|
||||
|
||||
def bootstrap_aws(config):
|
||||
# The head node needs to have an IAM role that allows it to create further
|
||||
|
@ -94,27 +205,38 @@ def bootstrap_aws(config):
|
|||
|
||||
def _configure_iam_role(config):
|
||||
if "IamInstanceProfile" in config["head_node"]:
|
||||
_set_config_info(head_instance_profile_src="config")
|
||||
return config
|
||||
_set_config_info(head_instance_profile_src="default")
|
||||
|
||||
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
|
||||
|
||||
if profile is None:
|
||||
logger.info("_configure_iam_role: "
|
||||
"Creating new instance profile {}".format(
|
||||
DEFAULT_RAY_INSTANCE_PROFILE))
|
||||
cli_logger.verbose(
|
||||
"Creating new IAM instance profile {} for use as the default.",
|
||||
cf.bold(DEFAULT_RAY_INSTANCE_PROFILE))
|
||||
cli_logger.old_info(
|
||||
logger, "_configure_iam_role: "
|
||||
"Creating new instance profile {}", DEFAULT_RAY_INSTANCE_PROFILE)
|
||||
client = _client("iam", config)
|
||||
client.create_instance_profile(
|
||||
InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
|
||||
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
|
||||
time.sleep(15) # wait for propagation
|
||||
|
||||
cli_logger.doassert(profile is not None,
|
||||
"Failed to create instance profile.") # todo: err msg
|
||||
assert profile is not None, "Failed to create instance profile"
|
||||
|
||||
if not profile.roles:
|
||||
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
|
||||
if role is None:
|
||||
logger.info("_configure_iam_role: "
|
||||
"Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
|
||||
cli_logger.verbose(
|
||||
"Creating new IAM role {} for "
|
||||
"use as the default instance role.",
|
||||
cf.bold(DEFAULT_RAY_IAM_ROLE))
|
||||
cli_logger.old_info(logger, "_configure_iam_role: "
|
||||
"Creating new role {}", DEFAULT_RAY_IAM_ROLE)
|
||||
iam = _resource("iam", config)
|
||||
iam.create_role(
|
||||
RoleName=DEFAULT_RAY_IAM_ROLE,
|
||||
|
@ -130,6 +252,9 @@ def _configure_iam_role(config):
|
|||
],
|
||||
}))
|
||||
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
|
||||
|
||||
cli_logger.doassert(role is not None,
|
||||
"Failed to create role.") # todo: err msg
|
||||
assert role is not None, "Failed to create role"
|
||||
role.attach_policy(
|
||||
PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess")
|
||||
|
@ -138,9 +263,9 @@ def _configure_iam_role(config):
|
|||
profile.add_role(RoleName=role.name)
|
||||
time.sleep(15) # wait for propagation
|
||||
|
||||
logger.info("_configure_iam_role: "
|
||||
"Role not specified for head node, using {}".format(
|
||||
profile.arn))
|
||||
cli_logger.old_info(
|
||||
logger, "_configure_iam_role: "
|
||||
"Role not specified for head node, using {}", profile.arn)
|
||||
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
|
||||
|
||||
return config
|
||||
|
@ -148,9 +273,19 @@ def _configure_iam_role(config):
|
|||
|
||||
def _configure_key_pair(config):
|
||||
if "ssh_private_key" in config["auth"]:
|
||||
_set_config_info(keypair_src="config")
|
||||
|
||||
cli_logger.doassert( # todo: verify schema beforehand?
|
||||
"KeyName" in config["head_node"],
|
||||
"`KeyName` missing for head node.") # todo: err msg
|
||||
cli_logger.doassert(
|
||||
"KeyName" in config["worker_nodes"],
|
||||
"`KeyName` missing for worker nodes.") # todo: err msg
|
||||
|
||||
assert "KeyName" in config["head_node"]
|
||||
assert "KeyName" in config["worker_nodes"]
|
||||
return config
|
||||
_set_config_info(keypair_src="default")
|
||||
|
||||
ec2 = _resource("ec2", config)
|
||||
|
||||
|
@ -170,8 +305,12 @@ def _configure_key_pair(config):
|
|||
|
||||
# We can safely create a new key.
|
||||
if not key and not os.path.exists(key_path):
|
||||
logger.info("_configure_key_pair: "
|
||||
"Creating new key pair {}".format(key_name))
|
||||
cli_logger.verbose(
|
||||
"Creating new key pair {} for use as the default.",
|
||||
cf.bold(key_name))
|
||||
cli_logger.old_info(
|
||||
logger, "_configure_key_pair: "
|
||||
"Creating new key pair {}", key_name)
|
||||
key = ec2.create_key_pair(KeyName=key_name)
|
||||
|
||||
# We need to make sure to _create_ the file with the right
|
||||
|
@ -182,16 +321,25 @@ def _configure_key_pair(config):
|
|||
break
|
||||
|
||||
if not key:
|
||||
cli_logger.abort(
|
||||
"No matching local key file for any of the key pairs in this "
|
||||
"account with ids from 0..{}. "
|
||||
"Consider deleting some unused keys pairs from your account.",
|
||||
key_name) # todo: err msg
|
||||
raise ValueError(
|
||||
"No matching local key file for any of the key pairs in this "
|
||||
"account with ids from 0..{}. ".format(key_name) +
|
||||
"Consider deleting some unused keys pairs from your account.")
|
||||
|
||||
cli_logger.doassert(
|
||||
os.path.exists(key_path), "Private key file " + cf.bold("{}") +
|
||||
" not found for " + cf.bold("{}"), key_path, key_name) # todo: err msg
|
||||
assert os.path.exists(key_path), \
|
||||
"Private key file {} not found for {}".format(key_path, key_name)
|
||||
|
||||
logger.info("_configure_key_pair: "
|
||||
"KeyName not specified for nodes, using {}".format(key_name))
|
||||
cli_logger.old_info(
|
||||
logger, "_configure_key_pair: "
|
||||
"KeyName not specified for nodes, using {}", key_name)
|
||||
|
||||
config["auth"]["ssh_private_key"] = key_path
|
||||
config["head_node"]["KeyName"] = key_name
|
||||
|
@ -203,12 +351,25 @@ def _configure_key_pair(config):
|
|||
def _configure_subnet(config):
|
||||
ec2 = _resource("ec2", config)
|
||||
use_internal_ips = config["provider"].get("use_internal_ips", False)
|
||||
subnets = sorted(
|
||||
(s for s in ec2.subnets.all() if s.state == "available" and (
|
||||
use_internal_ips or s.map_public_ip_on_launch)),
|
||||
reverse=True, # sort from Z-A
|
||||
key=lambda subnet: subnet.availability_zone)
|
||||
|
||||
try:
|
||||
subnets = sorted(
|
||||
(s for s in ec2.subnets.all() if s.state == "available" and (
|
||||
use_internal_ips or s.map_public_ip_on_launch)),
|
||||
reverse=True, # sort from Z-A
|
||||
key=lambda subnet: subnet.availability_zone)
|
||||
except botocore.exceptions.ClientError as exc:
|
||||
handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
|
||||
raise exc
|
||||
|
||||
if not subnets:
|
||||
cli_logger.abort(
|
||||
"No usable subnets found, try manually creating an instance in "
|
||||
"your specified region to populate the list of subnets "
|
||||
"and trying this again.\n"
|
||||
"Note that the subnet must map public IPs "
|
||||
"on instance launch unless you set `use_internal_ips: true` in "
|
||||
"the `provider` config.") # todo: err msg
|
||||
raise Exception(
|
||||
"No usable subnets found, try manually creating an instance in "
|
||||
"your specified region to populate the list of subnets "
|
||||
|
@ -219,6 +380,12 @@ def _configure_subnet(config):
|
|||
azs = config["provider"]["availability_zone"].split(",")
|
||||
subnets = [s for s in subnets if s.availability_zone in azs]
|
||||
if not subnets:
|
||||
cli_logger.abort(
|
||||
"No usable subnets matching availability zone {} found.\n"
|
||||
"Choose a different availability zone or try "
|
||||
"manually creating an instance in your specified region "
|
||||
"to populate the list of subnets and trying this again.",
|
||||
config["provider"]["availability_zone"]) # todo: err msg
|
||||
raise Exception(
|
||||
"No usable subnets matching availability zone {} "
|
||||
"found. Choose a different availability zone or try "
|
||||
|
@ -229,21 +396,31 @@ def _configure_subnet(config):
|
|||
subnet_ids = [s.subnet_id for s in subnets]
|
||||
subnet_descr = [(s.subnet_id, s.availability_zone) for s in subnets]
|
||||
if "SubnetIds" not in config["head_node"]:
|
||||
_set_config_info(head_subnet_src="default")
|
||||
config["head_node"]["SubnetIds"] = subnet_ids
|
||||
logger.info("_configure_subnet: "
|
||||
"SubnetIds not specified for head node, using {}".format(
|
||||
subnet_descr))
|
||||
cli_logger.old_info(
|
||||
logger, "_configure_subnet: "
|
||||
"SubnetIds not specified for head node, using {}", subnet_descr)
|
||||
else:
|
||||
_set_config_info(head_subnet_src="config")
|
||||
|
||||
if "SubnetIds" not in config["worker_nodes"]:
|
||||
_set_config_info(workers_subnet_src="default")
|
||||
config["worker_nodes"]["SubnetIds"] = subnet_ids
|
||||
logger.info("_configure_subnet: "
|
||||
"SubnetId not specified for workers,"
|
||||
" using {}".format(subnet_descr))
|
||||
cli_logger.old_info(
|
||||
logger, "_configure_subnet: "
|
||||
"SubnetId not specified for workers,"
|
||||
" using {}", subnet_descr)
|
||||
else:
|
||||
_set_config_info(workers_subnet_src="config")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_security_group(config):
|
||||
_set_config_info(
|
||||
head_security_group_src="config", workers_security_group_src="config")
|
||||
|
||||
node_types_to_configure = [
|
||||
node_type for node_type, config_key in NODE_TYPE_CONFIG_KEYS.items()
|
||||
if "SecurityGroupIds" not in config[NODE_TYPE_CONFIG_KEYS[node_type]]
|
||||
|
@ -255,17 +432,22 @@ def _configure_security_group(config):
|
|||
|
||||
if NODE_TYPE_HEAD in node_types_to_configure:
|
||||
head_sg = security_groups[NODE_TYPE_HEAD]
|
||||
logger.info(
|
||||
"_configure_security_group: "
|
||||
"SecurityGroupIds not specified for head node, using {} ({})"
|
||||
.format(head_sg.group_name, head_sg.id))
|
||||
|
||||
_set_config_info(head_security_group_src="default")
|
||||
cli_logger.old_info(
|
||||
logger, "_configure_security_group: "
|
||||
"SecurityGroupIds not specified for head node, using {} ({})",
|
||||
head_sg.group_name, head_sg.id)
|
||||
config["head_node"]["SecurityGroupIds"] = [head_sg.id]
|
||||
|
||||
if NODE_TYPE_WORKER in node_types_to_configure:
|
||||
workers_sg = security_groups[NODE_TYPE_WORKER]
|
||||
logger.info("_configure_security_group: "
|
||||
"SecurityGroupIds not specified for workers, using {} ({})"
|
||||
.format(workers_sg.group_name, workers_sg.id))
|
||||
|
||||
_set_config_info(workers_security_group_src="default")
|
||||
cli_logger.old_info(
|
||||
logger, "_configure_security_group: "
|
||||
"SecurityGroupIds not specified for workers, using {} ({})",
|
||||
workers_sg.group_name, workers_sg.id)
|
||||
config["worker_nodes"]["SecurityGroupIds"] = [workers_sg.id]
|
||||
|
||||
return config
|
||||
|
@ -274,6 +456,8 @@ def _configure_security_group(config):
|
|||
def _check_ami(config):
|
||||
"""Provide helpful message for missing ImageId for node configuration."""
|
||||
|
||||
_set_config_info(head_ami_src="config", workers_ami_src="config")
|
||||
|
||||
region = config["provider"]["region"]
|
||||
default_ami = DEFAULT_AMI.get(region)
|
||||
if not default_ami:
|
||||
|
@ -282,21 +466,27 @@ def _check_ami(config):
|
|||
|
||||
if config["head_node"].get("ImageId", "").lower() == "latest_dlami":
|
||||
config["head_node"]["ImageId"] = default_ami
|
||||
logger.info("_check_ami: head node ImageId is 'latest_dlami'. "
|
||||
"Using '{ami_id}', which is the default {ami_name} "
|
||||
"for your region ({region}).".format(
|
||||
ami_id=default_ami,
|
||||
ami_name=DEFAULT_AMI_NAME,
|
||||
region=region))
|
||||
_set_config_info(head_ami_src="dlami")
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"_check_ami: head node ImageId is 'latest_dlami'. "
|
||||
"Using '{ami_id}', which is the default {ami_name} "
|
||||
"for your region ({region}).",
|
||||
ami_id=default_ami,
|
||||
ami_name=DEFAULT_AMI_NAME,
|
||||
region=region)
|
||||
|
||||
if config["worker_nodes"].get("ImageId", "").lower() == "latest_dlami":
|
||||
config["worker_nodes"]["ImageId"] = default_ami
|
||||
logger.info("_check_ami: worker nodes ImageId is 'latest_dlami'. "
|
||||
"Using '{ami_id}', which is the default {ami_name} "
|
||||
"for your region ({region}).".format(
|
||||
ami_id=default_ami,
|
||||
ami_name=DEFAULT_AMI_NAME,
|
||||
region=region))
|
||||
_set_config_info(workers_ami_src="dlami")
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"_check_ami: worker nodes ImageId is 'latest_dlami'. "
|
||||
"Using '{ami_id}', which is the default {ami_name} "
|
||||
"for your region ({region}).",
|
||||
ami_id=default_ami,
|
||||
ami_name=DEFAULT_AMI_NAME,
|
||||
region=region)
|
||||
|
||||
|
||||
def _upsert_security_groups(config, node_types):
|
||||
|
@ -350,6 +540,9 @@ def _get_vpc_id_or_die(ec2, subnet_id):
|
|||
"Name": "subnet-id",
|
||||
"Values": [subnet_id]
|
||||
}]))
|
||||
|
||||
# TODO: better error message
|
||||
cli_logger.doassert(len(subnet) == 1, "Subnet ID not found: {}", subnet_id)
|
||||
assert len(subnet) == 1, "Subnet ID not found: {}".format(subnet_id)
|
||||
subnet = subnet[0]
|
||||
return subnet.vpc_id
|
||||
|
@ -383,8 +576,17 @@ def _create_security_group(config, vpc_id, group_name):
|
|||
GroupName=group_name,
|
||||
VpcId=vpc_id)
|
||||
security_group = _get_security_group(config, vpc_id, group_name)
|
||||
logger.info("_create_security_group: Created new security group {} ({})"
|
||||
.format(security_group.group_name, security_group.id))
|
||||
|
||||
cli_logger.verbose(
|
||||
"Created new security group {}",
|
||||
cf.bold(security_group.group_name),
|
||||
_tags=dict(id=security_group.id))
|
||||
cli_logger.old_info(
|
||||
logger, "_create_security_group: Created new security group {} ({})",
|
||||
security_group.group_name, security_group.id)
|
||||
|
||||
cli_logger.doassert(security_group,
|
||||
"Failed to create security group") # err msg
|
||||
assert security_group, "Failed to create security group"
|
||||
return security_group
|
||||
|
||||
|
@ -454,6 +656,9 @@ def _get_role(role_name, config):
|
|||
if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
|
||||
return None
|
||||
else:
|
||||
handle_boto_error(
|
||||
exc, "Failed to fetch IAM role data for {} from AWS.",
|
||||
cf.bold(role_name))
|
||||
raise exc
|
||||
|
||||
|
||||
|
@ -467,17 +672,26 @@ def _get_instance_profile(profile_name, config):
|
|||
if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
|
||||
return None
|
||||
else:
|
||||
handle_boto_error(
|
||||
exc,
|
||||
"Failed to fetch IAM instance profile data for {} from AWS.",
|
||||
cf.bold(profile_name))
|
||||
raise exc
|
||||
|
||||
|
||||
def _get_key(key_name, config):
|
||||
ec2 = _resource("ec2", config)
|
||||
for key in ec2.key_pairs.filter(Filters=[{
|
||||
"Name": "key-name",
|
||||
"Values": [key_name]
|
||||
}]):
|
||||
if key.name == key_name:
|
||||
return key
|
||||
try:
|
||||
for key in ec2.key_pairs.filter(Filters=[{
|
||||
"Name": "key-name",
|
||||
"Values": [key_name]
|
||||
}]):
|
||||
if key.name == key_name:
|
||||
return key
|
||||
except botocore.exceptions.ClientError as exc:
|
||||
handle_boto_error(exc, "Failed to fetch EC2 key pair {} from AWS.",
|
||||
cf.bold(key_name))
|
||||
raise exc
|
||||
|
||||
|
||||
def _client(name, config):
|
||||
|
|
|
@ -15,6 +15,10 @@ from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \
|
|||
from ray.ray_constants import BOTO_MAX_RETRIES, BOTO_CREATE_MAX_RETRIES
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
|
||||
from ray.autoscaler.aws.utils import boto_exception_handler
|
||||
from ray.autoscaler.cli_logger import cli_logger
|
||||
import colorful as cf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -133,7 +137,10 @@ class AWSNodeProvider(NodeProvider):
|
|||
"Values": [v],
|
||||
})
|
||||
|
||||
nodes = list(self.ec2.instances.filter(Filters=filters))
|
||||
with boto_exception_handler(
|
||||
"Failed to fetch running instances from AWS."):
|
||||
nodes = list(self.ec2.instances.filter(Filters=filters))
|
||||
|
||||
# Populate the tag cache with initial information if necessary
|
||||
for node in nodes:
|
||||
if node.id in self.tag_cache:
|
||||
|
@ -228,19 +235,32 @@ class AWSNodeProvider(NodeProvider):
|
|||
self.ec2.instances.filter(Filters=filters))[:count]
|
||||
reuse_node_ids = [n.id for n in reuse_nodes]
|
||||
if reuse_nodes:
|
||||
logger.info("AWSNodeProvider: reusing instances {}. "
|
||||
"To disable reuse, set "
|
||||
"'cache_stopped_nodes: False' in the provider "
|
||||
"config.".format(reuse_node_ids))
|
||||
cli_logger.print(
|
||||
# todo: handle plural vs singular?
|
||||
"Reusing nodes {}. "
|
||||
"To disable reuse, set `cache_stopped_nodes: False` "
|
||||
"under `provider` in the cluster configuration.",
|
||||
cli_logger.render_list(reuse_node_ids))
|
||||
cli_logger.old_info(
|
||||
logger, "AWSNodeProvider: reusing instances {}. "
|
||||
"To disable reuse, set "
|
||||
"'cache_stopped_nodes: False' in the provider "
|
||||
"config.", reuse_node_ids)
|
||||
|
||||
for node in reuse_nodes:
|
||||
self.tag_cache[node.id] = from_aws_format(
|
||||
{x["Key"]: x["Value"]
|
||||
for x in node.tags})
|
||||
if node.state["Name"] == "stopping":
|
||||
logger.info("AWSNodeProvider: waiting for instance "
|
||||
"{} to fully stop...".format(node.id))
|
||||
node.wait_until_stopped()
|
||||
# todo: timed?
|
||||
with cli_logger.group("Stopping instances to reuse"):
|
||||
for node in reuse_nodes:
|
||||
self.tag_cache[node.id] = from_aws_format(
|
||||
{x["Key"]: x["Value"]
|
||||
for x in node.tags})
|
||||
if node.state["Name"] == "stopping":
|
||||
cli_logger.print("Waiting for instance {} to stop",
|
||||
node.id)
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"AWSNodeProvider: waiting for instance "
|
||||
"{} to fully stop...", node.id)
|
||||
node.wait_until_stopped()
|
||||
|
||||
self.ec2.meta.client.start_instances(
|
||||
InstanceIds=reuse_node_ids)
|
||||
|
@ -300,8 +320,11 @@ class AWSNodeProvider(NodeProvider):
|
|||
for attempt in range(1, BOTO_CREATE_MAX_RETRIES + 1):
|
||||
try:
|
||||
subnet_id = subnet_ids[self.subnet_idx % len(subnet_ids)]
|
||||
logger.info("NodeProvider: calling create_instances "
|
||||
"with {} (count={}).".format(subnet_id, count))
|
||||
|
||||
cli_logger.old_info(
|
||||
logger, "NodeProvider: calling create_instances "
|
||||
"with {} (count={}).", subnet_id, count)
|
||||
|
||||
self.subnet_idx += 1
|
||||
conf.update({
|
||||
"MinCount": 1,
|
||||
|
@ -310,32 +333,64 @@ class AWSNodeProvider(NodeProvider):
|
|||
"TagSpecifications": tag_specs
|
||||
})
|
||||
created = self.ec2_fail_fast.create_instances(**conf)
|
||||
for instance in created:
|
||||
logger.info("NodeProvider: Created instance "
|
||||
"[id={}, name={}, info={}]".format(
|
||||
instance.instance_id,
|
||||
instance.state["Name"],
|
||||
instance.state_reason["Message"]))
|
||||
|
||||
# todo: timed?
|
||||
# todo: handle plurality?
|
||||
with cli_logger.group(
|
||||
"Launching {} nodes",
|
||||
count,
|
||||
_tags=dict(subnet_id=subnet_id)):
|
||||
for instance in created:
|
||||
cli_logger.print(
|
||||
"Launched instance {}",
|
||||
instance.instance_id,
|
||||
_tags=dict(
|
||||
state=instance.state["Name"],
|
||||
info=instance.state_reason["Message"]))
|
||||
cli_logger.old_info(
|
||||
logger, "NodeProvider: Created instance "
|
||||
"[id={}, name={}, info={}]", instance.instance_id,
|
||||
instance.state["Name"],
|
||||
instance.state_reason["Message"])
|
||||
break
|
||||
except botocore.exceptions.ClientError as exc:
|
||||
if attempt == BOTO_CREATE_MAX_RETRIES:
|
||||
logger.error(
|
||||
"create_instances: Max attempts ({}) exceeded.".format(
|
||||
BOTO_CREATE_MAX_RETRIES))
|
||||
# todo: err msg
|
||||
cli_logger.abort(
|
||||
"Failed to launch instances. Max attempts exceeded.")
|
||||
cli_logger.old_error(
|
||||
logger,
|
||||
"create_instances: Max attempts ({}) exceeded.",
|
||||
BOTO_CREATE_MAX_RETRIES)
|
||||
raise exc
|
||||
else:
|
||||
logger.error(exc)
|
||||
# todo: err msg
|
||||
cli_logger.abort(exc)
|
||||
cli_logger.old_error(logger, exc)
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
node = self._get_cached_node(node_id)
|
||||
if self.cache_stopped_nodes:
|
||||
if node.spot_instance_request_id:
|
||||
logger.info(
|
||||
cli_logger.print(
|
||||
"Terminating instance {} " +
|
||||
cf.gray("(cannot stop spot instances, only terminate)"),
|
||||
node_id) # todo: show node name?
|
||||
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"AWSNodeProvider: terminating node {} (spot nodes cannot "
|
||||
"be stopped, only terminated)".format(node_id))
|
||||
"be stopped, only terminated)", node_id)
|
||||
node.terminate()
|
||||
else:
|
||||
logger.info(
|
||||
cli_logger.print("Stopping instance {} " + cf.gray(
|
||||
"(to terminate instead, "
|
||||
"set `cache_stopped_nodes: False` "
|
||||
"under `provider` in the cluster configuration)"),
|
||||
node_id) # todo: show node name?
|
||||
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"AWSNodeProvider: stopping node {}. To terminate nodes "
|
||||
"on stop, set 'cache_stopped_nodes: False' in the "
|
||||
"provider config.".format(node_id))
|
||||
|
@ -360,15 +415,30 @@ class AWSNodeProvider(NodeProvider):
|
|||
on_demand_ids += [node_id]
|
||||
|
||||
if on_demand_ids:
|
||||
logger.info(
|
||||
# todo: show node names?
|
||||
cli_logger.print(
|
||||
"Stopping instances {} " + cf.gray(
|
||||
"(to terminate instead, "
|
||||
"set `cache_stopped_nodes: False` "
|
||||
"under `provider` in the cluster configuration)"),
|
||||
cli_logger.render_list(on_demand_ids))
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"AWSNodeProvider: stopping nodes {}. To terminate nodes "
|
||||
"on stop, set 'cache_stopped_nodes: False' in the "
|
||||
"provider config.".format(on_demand_ids))
|
||||
"provider config.", on_demand_ids)
|
||||
|
||||
self.ec2.meta.client.stop_instances(InstanceIds=on_demand_ids)
|
||||
if spot_ids:
|
||||
logger.info(
|
||||
cli_logger.print(
|
||||
"Terminating instances {} " +
|
||||
cf.gray("(cannot stop spot instances, only terminate)"),
|
||||
cli_logger.render_list(spot_ids))
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"AWSNodeProvider: terminating nodes {} (spot nodes cannot "
|
||||
"be stopped, only terminated)".format(spot_ids))
|
||||
"be stopped, only terminated)", spot_ids)
|
||||
|
||||
self.ec2.meta.client.terminate_instances(InstanceIds=spot_ids)
|
||||
else:
|
||||
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
from collections import defaultdict
|
||||
|
||||
from ray.autoscaler.cli_logger import cli_logger
|
||||
import colorful as cf
|
||||
|
||||
|
||||
class LazyDefaultDict(defaultdict):
|
||||
"""
|
||||
|
@ -21,3 +24,99 @@ class LazyDefaultDict(defaultdict):
|
|||
"""
|
||||
self[key] = self.default_factory(key)
|
||||
return self[key]
|
||||
|
||||
|
||||
def handle_boto_error(exc, msg, *args, **kwargs):
    """Pretty-print a botocore error through cli_logger and abort.

    Under old-style logging this is a no-op (the caller is expected to
    re-raise `exc` afterwards). Expired-credential error codes get a
    special message with instructions for refreshing the AWS session.
    """
    if cli_logger.old_style:
        # old-style logging doesn't do anything here
        # so we exit early
        return

    # todo: not sure if these exceptions always have response
    error_info = getattr(exc, "response", {}).get("Error", None)
    error_code = None
    if error_info is not None:
        error_code = error_info.get("Code", None)

    generic_message_args = [
        "{}\n"
        "Error code: {}",
        msg.format(*args, **kwargs),
        cf.bold(error_code)
    ]

    # apparently
    # ExpiredTokenException
    # ExpiredToken
    # RequestExpired
    # are all the same pretty much
    credentials_expiration_codes = (
        "ExpiredTokenException",
        "ExpiredToken",
        "RequestExpired",
    )

    if error_code in credentials_expiration_codes:
        # "An error occurred (ExpiredToken) when calling the
        # GetInstanceProfile operation: The security token
        # included in the request is expired"

        # "An error occurred (RequestExpired) when calling the
        # DescribeKeyPairs operation: Request has expired."

        token_command = (
            "aws sts get-session-token "
            "--serial-number arn:aws:iam::" + cf.underlined("ROOT_ACCOUNT_ID")
            + ":mfa/" + cf.underlined("AWS_USERNAME") + " --token-code " +
            cf.underlined("TWO_FACTOR_AUTH_CODE"))

        secret_key_var = (
            "export AWS_SECRET_ACCESS_KEY = " + cf.underlined("REPLACE_ME") +
            " # found at Credentials.SecretAccessKey")
        session_token_var = (
            "export AWS_SESSION_TOKEN = " + cf.underlined("REPLACE_ME") +
            " # found at Credentials.SessionToken")
        access_key_id_var = (
            "export AWS_ACCESS_KEY_ID = " + cf.underlined("REPLACE_ME") +
            " # found at Credentials.AccessKeyId")

        # fixme: replace with a Github URL that points
        # to our repo
        aws_session_script_url = ("https://gist.github.com/maximsmol/"
                                  "a0284e1d97b25d417bd9ae02e5f450cf")

        cli_logger.verbose_error(*generic_message_args)
        cli_logger.verbose(vars(exc))

        cli_logger.abort(
            "Your AWS session has expired.\n\n"
            "You can request a new one using\n{}\n"
            "then expose it to Ray by setting\n{}\n{}\n{}\n\n"
            "You can find a script that automates this at:\n{}",
            cf.bold(token_command), cf.bold(secret_key_var),
            cf.bold(session_token_var), cf.bold(access_key_id_var),
            cf.underlined(aws_session_script_url))

    # todo: any other errors that we should catch separately?

    cli_logger.error(*generic_message_args)
    cli_logger.newline()
    with cli_logger.verbatim_error_ctx("Boto3 error:"):
        cli_logger.verbose(vars(exc))
        cli_logger.error(exc)
    cli_logger.abort()
|
||||
|
||||
|
||||
def boto_exception_handler(msg, *args, **kwargs):
    """Return a context manager that reports botocore ClientErrors.

    On exit with a botocore ClientError, the error is routed through
    handle_boto_error(); the exception still propagates to the caller.
    """
    # todo: implement timer
    class _BotoErrorHandler():
        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc_value, tb):
            # imported lazily to keep module import cheap
            import botocore

            if exc_type is botocore.exceptions.ClientError:
                handle_boto_error(exc_value, msg, *args, **kwargs)
            # falsy return value: never suppress the exception
            return False

    return _BotoErrorHandler()
|
||||
|
|
299
python/ray/autoscaler/cli_logger.py
Normal file
299
python/ray/autoscaler/cli_logger.py
Normal file
|
@ -0,0 +1,299 @@
|
|||
import sys
|
||||
|
||||
import click
|
||||
|
||||
import colorama
|
||||
from colorful.core import ColorfulString
|
||||
import colorful as cf
|
||||
colorama.init()
|
||||
|
||||
|
||||
def _strip_codes(msg):
|
||||
return msg # todo
|
||||
|
||||
|
||||
# we could bold "{}" strings automatically but do we want that?
|
||||
# todo:
|
||||
# we could bold "{}" strings automatically but do we want that?
# todo:
def _format_msg(msg,
                *args,
                _tags=None,
                _numbered=None,
                _no_format=None,
                **kwargs):
    """Render a message with optional gray [k=v] tags and (i/n) numbering.

    Non-string messages are stringified and joined with their args as a
    comma-separated string; kwargs are not supported in that case.
    """
    if isinstance(msg, (str, ColorfulString)):
        tags_str = ""
        if _tags is not None:
            rendered_tags = []
            for tag_key, tag_val in _tags.items():
                if tag_val is True:
                    # bare flag tag, rendered as just the key
                    rendered_tags.append(tag_key)
                elif tag_val is not False:
                    # False means "omit this tag entirely"
                    rendered_tags.append(tag_key + "=" + tag_val)
            if rendered_tags:
                tags_str = cf.reset(
                    cf.gray(" [{}]".format(", ".join(rendered_tags))))

        numbering_str = ""
        if _numbered is not None:
            # _numbered = (bracket_chars, current_index, total)
            chars, index, total = _numbered
            numbering_str = cf.gray(
                chars[0] + str(index) + "/" + str(total) + chars[1]) + " "

        if _no_format:
            # todo: throw if given args/kwargs?
            return numbering_str + msg + tags_str
        return numbering_str + msg.format(*args, **kwargs) + tags_str

    if kwargs:
        raise ValueError("We do not support printing kwargs yet.")

    return ", ".join(str(part) for part in [msg, *args])
|
||||
|
||||
|
||||
class _CliLogger():
    """Singleton CLI logger producing colored, indented, structured output.

    All messages are rendered through ``_format_msg`` and written to stdout.
    When ``old_style`` is True every new-style output method is a no-op and
    callers are expected to use the ``old_*`` wrappers, which forward to a
    standard ``logging.Logger`` instead.
    """

    def __init__(self):
        # When True, ANSI codes are removed before printing
        # (resolved from color_mode by detect_colors()).
        self.strip = False
        # When True, the new-style logger is disabled entirely.
        self.old_style = True
        # "true" / "false" / "auto" — see detect_colors().
        self.color_mode = "auto"
        # Current indentation depth, managed by indented()/group().
        self.indent_level = 0
        # 0 = normal, 1 = verbose, 2 = very verbose.
        self.verbosity = 0

        self.dump_command_output = False

        # Free-form info dict for callers to stash state in.
        self.info = {}

    def detect_colors(self):
        """Resolve ``color_mode`` into the ``strip`` flag.

        Raises:
            ValueError: if ``color_mode`` is not "true", "false" or "auto".
        """
        if self.color_mode == "true":
            self.strip = False
            return
        if self.color_mode == "false":
            self.strip = True
            return
        if self.color_mode == "auto":
            # BUG FIX: this previously set `strip = sys.stdout.isatty()`,
            # i.e. colors were stripped exactly when stdout WAS a terminal
            # and kept when output was piped — backwards. Strip only when
            # the output is not an interactive terminal.
            self.strip = not sys.stdout.isatty()
            return

        raise ValueError("Invalid log color setting: " + self.color_mode)

    def newline(self):
        """Print an empty line."""
        self._print("")

    def _print(self, msg, linefeed=True):
        """Low-level output: indent, optionally strip codes, write to stdout.

        No-op in old_style mode. With ``linefeed=False`` the message is
        written without a trailing newline (used for inline prompts).
        """
        if self.old_style:
            return

        if self.strip:
            msg = _strip_codes(msg)

        rendered_message = " " * self.indent_level + msg

        if not linefeed:
            sys.stdout.write(rendered_message)
            sys.stdout.flush()
            return

        print(rendered_message)

    def indented(self, cls=False):
        """Context manager that increases the indent level inside its block.

        With ``cls=True`` returns the manager class itself instead of an
        instance (known-broken — see fixme below).
        """
        cli_logger = self

        class IndentedContextManager():
            def __enter__(self):
                cli_logger.indent_level += 1

            def __exit__(self, type, value, tb):
                cli_logger.indent_level -= 1

        if cls:
            # fixme: this does not work :()
            return IndentedContextManager
        return IndentedContextManager()

    def timed(self, msg, *args, **kwargs):
        """Alias for group() — timing is not implemented yet."""
        return self.group(msg, *args, **kwargs)

    def group(self, msg, *args, **kwargs):
        """Print a blue group header and return an indenting context."""
        self._print(_format_msg(cf.cornflowerBlue(msg), *args, **kwargs))

        return self.indented()

    def verbatim_error_ctx(self, msg, *args, **kwargs):
        """Context manager that brackets verbatim error output with "!!!"."""
        cli_logger = self

        class VerbatimErorContextManager():
            def __enter__(self):
                cli_logger.error(cf.bold("!!! ") + msg, *args, **kwargs)

            def __exit__(self, type, value, tb):
                cli_logger.error(cf.bold("!!!"))

        return VerbatimErorContextManager()

    def labeled_value(self, key, msg, *args, **kwargs):
        """Print a "key: value" pair (cyan key, bold value)."""
        self._print(
            cf.cyan(key) + ": " + _format_msg(cf.bold(msg), *args, **kwargs))

    def verbose(self, msg, *args, **kwargs):
        """Print only when verbosity > 0."""
        if self.verbosity > 0:
            self.print(msg, *args, **kwargs)

    def verbose_error(self, msg, *args, **kwargs):
        """Print an error only when verbosity > 0."""
        if self.verbosity > 0:
            self.error(msg, *args, **kwargs)

    def very_verbose(self, msg, *args, **kwargs):
        """Print only when verbosity > 1."""
        if self.verbosity > 1:
            self.print(msg, *args, **kwargs)

    def success(self, msg, *args, **kwargs):
        """Print a green success message."""
        self._print(_format_msg(cf.green(msg), *args, **kwargs))

    def warning(self, msg, *args, **kwargs):
        """Print a yellow warning message."""
        self._print(_format_msg(cf.yellow(msg), *args, **kwargs))

    def error(self, msg, *args, **kwargs):
        """Print a red error message."""
        self._print(_format_msg(cf.red(msg), *args, **kwargs))

    def print(self, msg, *args, **kwargs):
        """Print a plain (uncolored) message."""
        self._print(_format_msg(msg, *args, **kwargs))

    def abort(self, msg=None, *args, **kwargs):
        """Print an optional error and raise a SilentClickException."""
        if msg is not None:
            self.error(msg, *args, **kwargs)

        raise SilentClickException("Exiting due to cli_logger.abort()")

    def doassert(self, val, msg, *args, **kwargs):
        """Abort with ``msg`` when ``val`` is falsy (no-op in old_style,
        where callers are expected to use a plain assert instead)."""
        if self.old_style:
            return

        if not val:
            self.abort(msg, *args, **kwargs)

    def old_debug(self, logger, msg, *args, **kwargs):
        """Forward to logger.debug, but only in old_style mode."""
        if self.old_style:
            logger.debug(_format_msg(msg, *args, **kwargs))
            return

    def old_info(self, logger, msg, *args, **kwargs):
        """Forward to logger.info, but only in old_style mode."""
        if self.old_style:
            logger.info(_format_msg(msg, *args, **kwargs))
            return

    def old_warning(self, logger, msg, *args, **kwargs):
        """Forward to logger.warning, but only in old_style mode."""
        if self.old_style:
            logger.warning(_format_msg(msg, *args, **kwargs))
            return

    def old_error(self, logger, msg, *args, **kwargs):
        """Forward to logger.error, but only in old_style mode."""
        if self.old_style:
            logger.error(_format_msg(msg, *args, **kwargs))
            return

    def old_exception(self, logger, msg, *args, **kwargs):
        """Forward to logger.exception, but only in old_style mode."""
        if self.old_style:
            logger.exception(_format_msg(msg, *args, **kwargs))
            return

    def render_list(self, xs, separator=cf.reset(", ")):
        """Render a list of values as a separator-joined bold string."""
        return separator.join([str(cf.bold(x)) for x in xs])

    def confirm(self, yes, msg, *args, _abort=False, _default=False, **kwargs):
        """Interactive y/n confirmation prompt.

        Args:
            yes: if True, auto-confirm without prompting (--yes flag).
            msg: prompt message (format string for args/kwargs).
            _abort: raise SilentClickException when the answer is negative.
            _default: answer assumed on empty input or Ctrl-C.

        Returns:
            The boolean answer, or None in old_style mode.
        """
        if self.old_style:
            return

        should_abort = _abort
        default = _default

        # Capitalize the default choice, shell-convention style.
        if default:
            yn_str = cf.green("Y") + "/" + cf.red("n")
        else:
            yn_str = cf.green("y") + "/" + cf.red("N")

        confirm_str = cf.underlined("Confirm [" + yn_str + "]:") + " "

        rendered_message = _format_msg(msg, *args, **kwargs)
        if rendered_message and rendered_message[-1] != "\n":
            rendered_message += " "

        # Width of the prompt's last line, used to align retry prompts.
        msg_len = len(rendered_message.split("\n")[-1])
        complete_str = rendered_message + confirm_str

        if yes:
            self._print(complete_str + "y " +
                        cf.gray("[automatic, due to --yes]"))
            return True

        self._print(complete_str, linefeed=False)

        res = None
        yes_answers = ["y", "yes", "true", "1"]
        no_answers = ["n", "no", "false", "0"]
        try:
            while True:
                ans = sys.stdin.readline()
                ans = ans.lower()

                if ans == "\n":
                    # Empty input selects the default answer.
                    res = default
                    break

                ans = ans.strip()
                if ans in yes_answers:
                    res = True
                    break
                if ans in no_answers:
                    res = False
                    break

                indent = " " * msg_len
                self.error("{}Invalid answer: {}. "
                           "Expected {} or {}", indent, cf.bold(ans.strip()),
                           self.render_list(yes_answers, "/"),
                           self.render_list(no_answers, "/"))
                self._print(indent + confirm_str, linefeed=False)
        except KeyboardInterrupt:
            self.newline()
            res = default

        if not res and should_abort:
            # todo: make sure we tell the user if they
            # need to do cleanup
            self._print("Exiting...")
            raise SilentClickException(
                "Exiting due to the response to confirm(should_abort=True).")

        return res

    def old_confirm(self, msg, yes):
        """click.confirm-based prompt, used only in old_style mode."""
        if not self.old_style:
            return

        return None if yes else click.confirm(msg, abort=True)
|
||||
|
||||
|
||||
class SilentClickException(click.ClickException):
    """A ClickException whose ``show`` produces no output.

    Some of our tooling relies on catching ClickException in particular.

    However the default prints a message, which is undesirable since we expect
    our code to log errors manually using `cli_logger.error()` to allow for
    colors and other formatting.
    """

    def __init__(self, message):
        super().__init__(message)

    def show(self, file=None):
        # Intentionally silent — the caller has already logged the error.
        pass
|
||||
|
||||
|
||||
cli_logger = _CliLogger()
|
|
@ -12,6 +12,9 @@ import time
|
|||
from ray.autoscaler.docker import check_docker_running_cmd, with_docker_exec
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
|
||||
from ray.autoscaler.cli_logger import cli_logger
|
||||
import colorful as cf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# How long to wait for a node to start, in seconds
|
||||
|
@ -21,6 +24,24 @@ KUBECTL_RSYNC = os.path.join(
|
|||
os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh")
|
||||
|
||||
|
||||
class ProcessRunnerError(Exception):
    """Raised when a command executed through a process runner fails.

    Carries structured details (type, exit code, command line, and any
    message discovered in the output) alongside the human-readable summary.
    """

    def __init__(self,
                 msg,
                 msg_type,
                 code=None,
                 command=None,
                 message_discovered=None):
        # Summary string format is load-bearing (matched by tooling) —
        # keep it byte-identical to the historical layout.
        summary = "{} (discovered={}): type={}, code={}, command={}".format(
            msg, message_discovered, msg_type, code, command)
        super().__init__(summary)

        self.msg_type = msg_type
        self.code = code
        self.command = command

        self.message_discovered = message_discovered
|
||||
|
||||
|
||||
def _with_interactive(cmd):
|
||||
force_interactive = ("true && source ~/.bashrc && "
|
||||
"export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ")
|
||||
|
@ -256,14 +277,27 @@ class SSHCommandRunner(CommandRunnerInterface):
|
|||
else:
|
||||
return self.provider.external_ip(self.node_id)
|
||||
|
||||
def _wait_for_ip(self, deadline):
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
logger.info(self.log_prefix + "Waiting for IP...")
|
||||
ip = self._get_node_ip()
|
||||
if ip is not None:
|
||||
return ip
|
||||
time.sleep(10)
|
||||
def wait_for_ip(self, deadline):
|
||||
# if we have IP do not print waiting info
|
||||
ip = self._get_node_ip()
|
||||
if ip is not None:
|
||||
cli_logger.labeled_value("Fetched IP", ip)
|
||||
return ip
|
||||
|
||||
interval = 10
|
||||
with cli_logger.timed("Waiting for IP"):
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
cli_logger.old_info(logger, "{}Waiting for IP...",
|
||||
self.log_prefix)
|
||||
|
||||
ip = self._get_node_ip()
|
||||
if ip is not None:
|
||||
cli_logger.labeled_value("Received", ip)
|
||||
return ip
|
||||
cli_logger.print("Not yet available, retrying in {} seconds",
|
||||
cf.bold(str(interval)))
|
||||
time.sleep(interval)
|
||||
|
||||
return None
|
||||
|
||||
|
@ -275,7 +309,10 @@ class SSHCommandRunner(CommandRunnerInterface):
|
|||
# I think that's reasonable.
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
with LogTimer(self.log_prefix + "Got IP"):
|
||||
ip = self._wait_for_ip(deadline)
|
||||
ip = self.wait_for_ip(deadline)
|
||||
|
||||
cli_logger.doassert(ip is not None,
|
||||
"Could not get node IP.") # todo: msg
|
||||
assert ip is not None, "Unable to find IP of node"
|
||||
|
||||
self.ssh_ip = ip
|
||||
|
@ -286,7 +323,8 @@ class SSHCommandRunner(CommandRunnerInterface):
|
|||
try:
|
||||
os.makedirs(self.ssh_control_path, mode=0o700, exist_ok=True)
|
||||
except OSError as e:
|
||||
logger.warning(e)
|
||||
cli_logger.warning(e) # todo: msg
|
||||
cli_logger.old_warning(logger, e)
|
||||
|
||||
def run(self,
|
||||
cmd,
|
||||
|
@ -308,59 +346,91 @@ class SSHCommandRunner(CommandRunnerInterface):
|
|||
ssh = ["ssh", "-tt"]
|
||||
|
||||
if port_forward:
|
||||
if not isinstance(port_forward, list):
|
||||
port_forward = [port_forward]
|
||||
for local, remote in port_forward:
|
||||
logger.info(self.log_prefix + "Forwarding " +
|
||||
"{} -> localhost:{}".format(local, remote))
|
||||
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
|
||||
with cli_logger.group("Forwarding ports"):
|
||||
if not isinstance(port_forward, list):
|
||||
port_forward = [port_forward]
|
||||
for local, remote in port_forward:
|
||||
cli_logger.verbose(
|
||||
"Forwarding port {} to port {} on localhost.",
|
||||
cf.bold(local), cf.bold(remote)) # todo: msg
|
||||
cli_logger.old_info(logger,
|
||||
"{}Forwarding {} -> localhost:{}",
|
||||
self.log_prefix, local, remote)
|
||||
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
|
||||
|
||||
final_cmd = ssh + ssh_options.to_ssh_options_list(timeout=timeout) + [
|
||||
"{}@{}".format(self.ssh_user, self.ssh_ip)
|
||||
]
|
||||
if cmd:
|
||||
final_cmd += _with_interactive(cmd)
|
||||
logger.info(self.log_prefix +
|
||||
"Running {}".format(" ".join(final_cmd)))
|
||||
cli_logger.old_info(logger, "{}Running {}", self.log_prefix,
|
||||
" ".join(final_cmd))
|
||||
else:
|
||||
# We do this because `-o ControlMaster` causes the `-N` flag to
|
||||
# still create an interactive shell in some ssh versions.
|
||||
final_cmd.append(quote("while true; do sleep 86400; done"))
|
||||
|
||||
try:
|
||||
if with_output:
|
||||
return self.process_runner.check_output(final_cmd)
|
||||
else:
|
||||
self.process_runner.check_call(final_cmd)
|
||||
except subprocess.CalledProcessError:
|
||||
if exit_on_fail:
|
||||
# todo: add a flag for this, we might
|
||||
# wanna log commands with print sometimes
|
||||
cli_logger.verbose("Running `{}`", cf.bold(cmd))
|
||||
with cli_logger.indented():
|
||||
cli_logger.very_verbose("Full command is `{}`",
|
||||
cf.bold(" ".join(final_cmd)))
|
||||
|
||||
def start_process():
|
||||
try:
|
||||
if with_output:
|
||||
return self.process_runner.check_output(final_cmd)
|
||||
else:
|
||||
self.process_runner.check_call(final_cmd)
|
||||
except subprocess.CalledProcessError as e:
|
||||
quoted_cmd = " ".join(final_cmd[:-1] + [quote(final_cmd[-1])])
|
||||
raise click.ClickException(
|
||||
"Command failed: \n\n {}\n".format(quoted_cmd)) from None
|
||||
else:
|
||||
raise click.ClickException(
|
||||
"SSH command Failed. See above for the output from the"
|
||||
" failure.") from None
|
||||
if not cli_logger.old_style:
|
||||
raise ProcessRunnerError(
|
||||
"Command failed",
|
||||
"ssh_command_failed",
|
||||
code=e.returncode,
|
||||
command=quoted_cmd)
|
||||
|
||||
if exit_on_fail:
|
||||
raise click.ClickException(
|
||||
"Command failed: \n\n {}\n".format(quoted_cmd)) \
|
||||
from None
|
||||
else:
|
||||
raise click.ClickException(
|
||||
"SSH command Failed. See above for the output from the"
|
||||
" failure.") from None
|
||||
|
||||
if cli_logger.verbosity > 0:
|
||||
with cli_logger.indented():
|
||||
start_process()
|
||||
else:
|
||||
start_process()
|
||||
|
||||
def run_rsync_up(self, source, target):
|
||||
self._set_ssh_ip_if_required()
|
||||
self.process_runner.check_call([
|
||||
command = [
|
||||
"rsync", "--rsh",
|
||||
subprocess.list2cmdline(
|
||||
["ssh"] + self.ssh_options.to_ssh_options_list(timeout=120)),
|
||||
"-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
|
||||
target)
|
||||
])
|
||||
]
|
||||
cli_logger.verbose("Running `{}`", cf.bold(" ".join(command)))
|
||||
self.process_runner.check_call(command)
|
||||
|
||||
def run_rsync_down(self, source, target):
|
||||
self._set_ssh_ip_if_required()
|
||||
self.process_runner.check_call([
|
||||
|
||||
command = [
|
||||
"rsync", "--rsh",
|
||||
subprocess.list2cmdline(
|
||||
["ssh"] + self.ssh_options.to_ssh_options_list(timeout=120)),
|
||||
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
|
||||
source), target
|
||||
])
|
||||
]
|
||||
cli_logger.verbose("Running `{}`", cf.bold(" ".join(command)))
|
||||
self.process_runner.check_call(command)
|
||||
|
||||
def remote_shell_command_str(self):
|
||||
return "ssh -o IdentitiesOnly=yes -i {} {}@{}\n".format(
|
||||
|
|
|
@ -21,7 +21,9 @@ import ray.services as services
|
|||
from ray.autoscaler.util import validate_config, hash_runtime_conf, \
|
||||
hash_launch_conf, prepare_config, DEBUG_AUTOSCALING_ERROR, \
|
||||
DEBUG_AUTOSCALING_STATUS
|
||||
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
|
||||
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS, \
|
||||
PROVIDER_PRETTY_NAMES, try_get_log_state, try_logging_config, \
|
||||
try_reload_log_state
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
|
||||
TAG_RAY_NODE_NAME, NODE_TYPE_WORKER, NODE_TYPE_HEAD
|
||||
|
||||
|
@ -31,6 +33,9 @@ from ray.autoscaler.command_runner import DockerCommandRunner
|
|||
from ray.autoscaler.log_timer import LogTimer
|
||||
from ray.worker import global_worker
|
||||
|
||||
from ray.autoscaler.cli_logger import cli_logger
|
||||
import colorful as cf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
redis_client = None
|
||||
|
@ -89,20 +94,93 @@ def create_or_update_cluster(
|
|||
config_file: str, override_min_workers: Optional[int],
|
||||
override_max_workers: Optional[int], no_restart: bool,
|
||||
restart_only: bool, yes: bool, override_cluster_name: Optional[str],
|
||||
no_config_cache: bool) -> None:
|
||||
no_config_cache: bool, log_old_style: bool, log_color: str,
|
||||
verbose: int) -> None:
|
||||
"""Create or updates an autoscaling Ray cluster from a config json."""
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
if override_min_workers is not None:
|
||||
config["min_workers"] = override_min_workers
|
||||
if override_max_workers is not None:
|
||||
config["max_workers"] = override_max_workers
|
||||
if override_cluster_name is not None:
|
||||
config["cluster_name"] = override_cluster_name
|
||||
cli_logger.old_style = log_old_style
|
||||
cli_logger.color_mode = log_color
|
||||
cli_logger.verbosity = verbose
|
||||
|
||||
# todo: disable by default when the command output handling PR makes it in
|
||||
cli_logger.dump_command_output = True
|
||||
|
||||
cli_logger.detect_colors()
|
||||
|
||||
def handle_yaml_error(e):
|
||||
cli_logger.error(
|
||||
"Cluster config invalid.\n"
|
||||
"Failed to load YAML file " + cf.bold("{}"), config_file)
|
||||
cli_logger.newline()
|
||||
with cli_logger.verbatim_error_ctx("PyYAML error:"):
|
||||
cli_logger.error(e)
|
||||
cli_logger.abort()
|
||||
|
||||
try:
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
except FileNotFoundError:
|
||||
cli_logger.abort(
|
||||
"Provided cluster configuration file ({}) does not exist.",
|
||||
cf.bold(config_file))
|
||||
except yaml.parser.ParserError as e:
|
||||
handle_yaml_error(e)
|
||||
except yaml.scanner.ScannerError as e:
|
||||
handle_yaml_error(e)
|
||||
|
||||
# todo: validate file_mounts, ssh keys, etc.
|
||||
|
||||
importer = NODE_PROVIDERS.get(config["provider"]["type"])
|
||||
if not importer:
|
||||
cli_logger.abort(
|
||||
"Unknown provider type " + cf.bold("{}") + "\n"
|
||||
"Available providers are: {}", config["provider"]["type"],
|
||||
cli_logger.render_list([
|
||||
k for k in NODE_PROVIDERS.keys()
|
||||
if NODE_PROVIDERS[k] is not None
|
||||
]))
|
||||
raise NotImplementedError("Unsupported provider {}".format(
|
||||
config["provider"]))
|
||||
|
||||
cli_logger.success("Cluster configuration valid.\n")
|
||||
|
||||
printed_overrides = False
|
||||
|
||||
def handle_cli_override(key, override):
|
||||
if override is not None:
|
||||
if key in config:
|
||||
nonlocal printed_overrides
|
||||
printed_overrides = True
|
||||
cli_logger.warning(
|
||||
"`{}` override provided on the command line.\n"
|
||||
" Using " + cf.bold("{}") + cf.dimmed(
|
||||
" [configuration file has " + cf.bold("{}") + "]"),
|
||||
key, override, config[key])
|
||||
config[key] = override
|
||||
|
||||
handle_cli_override("min_workers", override_min_workers)
|
||||
handle_cli_override("max_workers", override_max_workers)
|
||||
handle_cli_override("cluster_name", override_cluster_name)
|
||||
|
||||
if printed_overrides:
|
||||
cli_logger.newline()
|
||||
|
||||
cli_logger.labeled_value("Cluster", config["cluster_name"])
|
||||
|
||||
# disable the cli_logger here if needed
|
||||
# because it only supports aws
|
||||
if config["provider"]["type"] != "aws":
|
||||
cli_logger.old_style = True
|
||||
config = _bootstrap_config(config, no_config_cache)
|
||||
if config["provider"]["type"] != "aws":
|
||||
cli_logger.old_style = False
|
||||
|
||||
try_logging_config(config)
|
||||
get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
override_cluster_name)
|
||||
|
||||
|
||||
CONFIG_CACHE_VERSION = 1
|
||||
|
||||
|
||||
def _bootstrap_config(config: Dict[str, Any],
|
||||
no_config_cache: bool = False) -> Dict[str, Any]:
|
||||
config = prepare_config(config)
|
||||
|
@ -111,9 +189,30 @@ def _bootstrap_config(config: Dict[str, Any],
|
|||
hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
|
||||
cache_key = os.path.join(tempfile.gettempdir(),
|
||||
"ray-config-{}".format(hasher.hexdigest()))
|
||||
|
||||
if os.path.exists(cache_key) and not no_config_cache:
|
||||
logger.info("Using cached config at {}".format(cache_key))
|
||||
return json.loads(open(cache_key).read())
|
||||
cli_logger.old_info(logger, "Using cached config at {}", cache_key)
|
||||
|
||||
config_cache = json.loads(open(cache_key).read())
|
||||
if config_cache.get("_version", -1) == CONFIG_CACHE_VERSION:
|
||||
# todo: is it fine to re-resolve? afaik it should be.
|
||||
# we can have migrations otherwise or something
|
||||
# but this seems overcomplicated given that resolving is
|
||||
# relatively cheap
|
||||
try_reload_log_state(config_cache["config"]["provider"],
|
||||
config_cache.get("provider_log_info"))
|
||||
cli_logger.verbose("Loaded cached config from " + cf.bold("{}"),
|
||||
cache_key)
|
||||
|
||||
return config_cache["config"]
|
||||
else:
|
||||
cli_logger.warning(
|
||||
"Found cached cluster config "
|
||||
"but the version " + cf.bold("{}") + " "
|
||||
"(expected " + cf.bold("{}") + ") does not match.\n"
|
||||
"This is normal if cluster launcher was updated.\n"
|
||||
"Config will be re-resolved.",
|
||||
config_cache.get("_version", "none"), CONFIG_CACHE_VERSION)
|
||||
validate_config(config)
|
||||
|
||||
importer = NODE_PROVIDERS.get(config["provider"]["type"])
|
||||
|
@ -122,17 +221,32 @@ def _bootstrap_config(config: Dict[str, Any],
|
|||
config["provider"]))
|
||||
|
||||
provider_cls = importer(config["provider"])
|
||||
resolved_config = provider_cls.bootstrap_config(config)
|
||||
|
||||
with cli_logger.timed( # todo: better message
|
||||
"Bootstraping {} config",
|
||||
PROVIDER_PRETTY_NAMES.get(config["provider"]["type"])):
|
||||
resolved_config = provider_cls.bootstrap_config(config)
|
||||
|
||||
if not no_config_cache:
|
||||
with open(cache_key, "w") as f:
|
||||
f.write(json.dumps(resolved_config))
|
||||
config_cache = {
|
||||
"_version": CONFIG_CACHE_VERSION,
|
||||
"provider_log_info": try_get_log_state(config["provider"]),
|
||||
"config": resolved_config
|
||||
}
|
||||
f.write(json.dumps(config_cache))
|
||||
return resolved_config
|
||||
|
||||
|
||||
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
|
||||
override_cluster_name: Optional[str],
|
||||
keep_min_workers: bool):
|
||||
keep_min_workers: bool, log_old_style: bool,
|
||||
log_color: str, verbose: int):
|
||||
"""Destroys all nodes of a Ray cluster described by a config json."""
|
||||
cli_logger.old_style = log_old_style
|
||||
cli_logger.color_mode = log_color
|
||||
cli_logger.verbosity = verbose
|
||||
cli_logger.dump_command_output = verbose == 3 # todo: add a separate flag?
|
||||
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
if override_cluster_name is not None:
|
||||
|
@ -140,7 +254,8 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
|
|||
config = prepare_config(config)
|
||||
validate_config(config)
|
||||
|
||||
confirm("This will destroy your cluster", yes)
|
||||
cli_logger.confirm(yes, "Destroying cluster.", _abort=True)
|
||||
cli_logger.old_confirm("This will destroy your cluster", yes)
|
||||
|
||||
if not workers_only:
|
||||
try:
|
||||
|
@ -155,8 +270,17 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
|
|||
override_cluster_name=override_cluster_name,
|
||||
port_forward=None,
|
||||
with_output=False)
|
||||
except Exception:
|
||||
logger.exception("Ignoring error attempting a clean shutdown.")
|
||||
except Exception as e:
|
||||
cli_logger.verbose_error(e) # todo: add better exception info
|
||||
cli_logger.warning(
|
||||
"Exception occured when stopping the cluster Ray runtime "
|
||||
"(use -v to dump teardown exceptions).")
|
||||
cli_logger.warning(
|
||||
"Ignoring the exception and "
|
||||
"attempting to shut down the cluster nodes anyway.")
|
||||
|
||||
cli_logger.old_exception(
|
||||
logger, "Ignoring error attempting a clean shutdown.")
|
||||
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
try:
|
||||
|
@ -169,11 +293,23 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
|
|||
|
||||
if keep_min_workers:
|
||||
min_workers = config.get("min_workers", 0)
|
||||
logger.info("teardown_cluster: "
|
||||
"Keeping {} nodes...".format(min_workers))
|
||||
|
||||
cli_logger.print(
|
||||
"{} random worker nodes will not be shut down. " +
|
||||
cf.gray("(due to {})"), cf.bold(min_workers),
|
||||
cf.bold("--keep-min-workers"))
|
||||
cli_logger.old_info(logger,
|
||||
"teardown_cluster: Keeping {} nodes...",
|
||||
min_workers)
|
||||
|
||||
workers = random.sample(workers, len(workers) - min_workers)
|
||||
|
||||
# todo: it's weird to kill the head node but not all workers
|
||||
if workers_only:
|
||||
cli_logger.print(
|
||||
"The head node will not be shut down. " +
|
||||
cf.gray("(due to {})"), cf.bold("--workers-only"))
|
||||
|
||||
return workers
|
||||
|
||||
head = provider.non_terminated_nodes({
|
||||
|
@ -187,11 +323,21 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
|
|||
A = remaining_nodes()
|
||||
with LogTimer("teardown_cluster: done."):
|
||||
while A:
|
||||
logger.info("teardown_cluster: "
|
||||
"Shutting down {} nodes...".format(len(A)))
|
||||
cli_logger.old_info(
|
||||
logger, "teardown_cluster: "
|
||||
"Shutting down {} nodes...", len(A))
|
||||
|
||||
provider.terminate_nodes(A)
|
||||
time.sleep(1)
|
||||
|
||||
cli_logger.print(
|
||||
"Requested {} nodes to shut down.",
|
||||
cf.bold(len(A)),
|
||||
_tags=dict(interval="1s"))
|
||||
|
||||
time.sleep(1) # todo: interval should be a variable
|
||||
A = remaining_nodes()
|
||||
cli_logger.print("{} nodes remaining after 1 second.",
|
||||
cf.bold(len(A)))
|
||||
finally:
|
||||
provider.cleanup()
|
||||
|
||||
|
@ -261,12 +407,23 @@ def monitor_cluster(cluster_config_file, num_lines, override_cluster_name):
|
|||
def warn_about_bad_start_command(start_commands):
    """Warn when ``head_start_ray_commands`` will not start Ray correctly.

    Checks that the commands include a ``ray start`` invocation and that the
    invocation passes ``--autoscaling-config`` (without it the head node
    cannot launch workers). Warnings go through both the new-style
    cli_logger and, in old_style mode, the standard logger.
    """
    ray_start_cmd = list(filter(lambda x: "ray start" in x, start_commands))
    if len(ray_start_cmd) == 0:
        cli_logger.warning(
            "Ray runtime will not be started because `{}` is not in `{}`.",
            cf.bold("ray start"), cf.bold("head_start_ray_commands"))
        cli_logger.old_warning(
            logger,
            "Ray start is not included in the head_start_ray_commands section."
        )
    if not any("autoscaling-config" in x for x in ray_start_cmd):
        cli_logger.warning(
            "The head node will not launch any workers because "
            "`{}` does not have `{}` set.\n"
            "Potential fix: add `{}` to the `{}` command under `{}`.",
            cf.bold("ray start"), cf.bold("--autoscaling-config"),
            cf.bold("--autoscaling-config=~/ray_bootstrap_config.yaml"),
            cf.bold("ray start"), cf.bold("head_start_ray_commands"))
        # BUG FIX: this was `logger.old_warning(...)` — a standard
        # logging.Logger has no `old_warning` attribute, so the old-style
        # path raised AttributeError instead of warning. Route through
        # cli_logger like the other old-style messages. Also fixed the
        # missing spaces between the concatenated message fragments
        # ("flag--autoscaling-config", "launchworkers", ...).
        cli_logger.old_warning(
            logger, "Ray start on the head node does not have the flag "
            "--autoscaling-config set. The head node will not launch "
            "workers. Add --autoscaling-config=~/ray_bootstrap_config.yaml "
            "to ray start in the head_start_ray_commands section.")
|
||||
|
@ -288,129 +445,205 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
|||
head_node = None
|
||||
|
||||
if not head_node:
|
||||
confirm("This will create a new cluster", yes)
|
||||
cli_logger.confirm(
|
||||
yes,
|
||||
"No head node found. "
|
||||
"Launching a new cluster.",
|
||||
_abort=True)
|
||||
cli_logger.old_confirm("This will create a new cluster", yes)
|
||||
elif not no_restart:
|
||||
confirm("This will restart cluster services", yes)
|
||||
cli_logger.old_confirm("This will restart cluster services", yes)
|
||||
|
||||
if head_node:
|
||||
if restart_only:
|
||||
cli_logger.confirm(
|
||||
yes,
|
||||
"Updating cluster configuration and "
|
||||
"restarting the cluster Ray runtime. "
|
||||
"Setup commands will not be run due to `{}`.\n",
|
||||
cf.bold("--restart-only"),
|
||||
_abort=True)
|
||||
elif no_restart:
|
||||
cli_logger.print(
|
||||
"Cluster Ray runtime will not be restarted due "
|
||||
"to `{}`.", cf.bold("--no-restart"))
|
||||
cli_logger.confirm(
|
||||
yes,
|
||||
"Updating cluster configuration and "
|
||||
"running setup commands.",
|
||||
_abort=True)
|
||||
else:
|
||||
cli_logger.print(
|
||||
"Updating cluster configuration and running full setup.")
|
||||
cli_logger.confirm(
|
||||
yes,
|
||||
cf.bold("Cluster Ray runtime will be restarted."),
|
||||
_abort=True)
|
||||
cli_logger.newline()
|
||||
|
||||
launch_hash = hash_launch_conf(config["head_node"], config["auth"])
|
||||
if head_node is None or provider.node_tags(head_node).get(
|
||||
TAG_RAY_LAUNCH_CONFIG) != launch_hash:
|
||||
if head_node is not None:
|
||||
confirm("Head node config out-of-date. It will be terminated",
|
||||
with cli_logger.group("Acquiring an up-to-date head node"):
|
||||
if head_node is not None:
|
||||
cli_logger.print(
|
||||
"Currently running head node is out-of-date with "
|
||||
"cluster configuration")
|
||||
cli_logger.print(
|
||||
"hash is {}, expected {}",
|
||||
cf.bold(
|
||||
provider.node_tags(head_node)
|
||||
.get(TAG_RAY_LAUNCH_CONFIG)), cf.bold(launch_hash))
|
||||
cli_logger.confirm(yes, "Relaunching it.", _abort=True)
|
||||
cli_logger.old_confirm(
|
||||
"Head node config out-of-date. It will be terminated",
|
||||
yes)
|
||||
logger.info(
|
||||
"get_or_create_head_node: "
|
||||
"Shutting down outdated head node {}".format(head_node))
|
||||
provider.terminate_node(head_node)
|
||||
logger.info("get_or_create_head_node: Launching new head node...")
|
||||
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
|
||||
head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
|
||||
config["cluster_name"])
|
||||
provider.create_node(config["head_node"], head_node_tags, 1)
|
||||
|
||||
start = time.time()
|
||||
head_node = None
|
||||
while True:
|
||||
if time.time() - start > 50:
|
||||
raise RuntimeError("Failed to create head node.")
|
||||
nodes = provider.non_terminated_nodes(head_node_tags)
|
||||
if len(nodes) == 1:
|
||||
head_node = nodes[0]
|
||||
break
|
||||
time.sleep(1)
|
||||
cli_logger.old_info(
|
||||
logger, "get_or_create_head_node: "
|
||||
"Shutting down outdated head node {}", head_node)
|
||||
|
||||
# TODO(ekl) right now we always update the head node even if the hash
|
||||
# matches. We could prompt the user for what they want to do here.
|
||||
runtime_hash = hash_runtime_conf(config["file_mounts"], config)
|
||||
logger.info("get_or_create_head_node: Updating files on head node...")
|
||||
provider.terminate_node(head_node)
|
||||
cli_logger.print("Terminated head node {}", head_node)
|
||||
|
||||
# Rewrite the auth config so that the head node can update the workers
|
||||
remote_config = copy.deepcopy(config)
|
||||
# drop proxy options if they exist, otherwise
|
||||
# head node won't be able to connect to workers
|
||||
remote_config["auth"].pop("ssh_proxy_command", None)
|
||||
if config["provider"]["type"] != "kubernetes":
|
||||
remote_key_path = "~/ray_bootstrap_key.pem"
|
||||
remote_config["auth"]["ssh_private_key"] = remote_key_path
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"get_or_create_head_node: Launching new head node...")
|
||||
|
||||
# Adjust for new file locations
|
||||
new_mounts = {}
|
||||
for remote_path in config["file_mounts"]:
|
||||
new_mounts[remote_path] = remote_path
|
||||
remote_config["file_mounts"] = new_mounts
|
||||
remote_config["no_restart"] = no_restart
|
||||
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
|
||||
head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
|
||||
config["cluster_name"])
|
||||
provider.create_node(config["head_node"], head_node_tags, 1)
|
||||
cli_logger.print("Launched a new head node")
|
||||
|
||||
# Now inject the rewritten config and SSH key into the head node
|
||||
remote_config_file = tempfile.NamedTemporaryFile(
|
||||
"w", prefix="ray-bootstrap-")
|
||||
remote_config_file.write(json.dumps(remote_config))
|
||||
remote_config_file.flush()
|
||||
config["file_mounts"].update({
|
||||
"~/ray_bootstrap_config.yaml": remote_config_file.name
|
||||
})
|
||||
if config["provider"]["type"] != "kubernetes":
|
||||
start = time.time()
|
||||
head_node = None
|
||||
with cli_logger.timed("Fetching the new head node"):
|
||||
while True:
|
||||
if time.time() - start > 50:
|
||||
cli_logger.abort(
|
||||
"Head node fetch timed out.") # todo: msg
|
||||
raise RuntimeError("Failed to create head node.")
|
||||
nodes = provider.non_terminated_nodes(head_node_tags)
|
||||
if len(nodes) == 1:
|
||||
head_node = nodes[0]
|
||||
break
|
||||
time.sleep(1)
|
||||
cli_logger.newline()
|
||||
|
||||
with cli_logger.group(
|
||||
"Setting up head node",
|
||||
_numbered=("<>", 1, 1),
|
||||
# cf.bold(provider.node_tags(head_node)[TAG_RAY_NODE_NAME]),
|
||||
_tags=dict()): # add id, ARN to tags?
|
||||
|
||||
# TODO(ekl) right now we always update the head node even if the
|
||||
# hash matches.
|
||||
# We could prompt the user for what they want to do here.
|
||||
runtime_hash = hash_runtime_conf(config["file_mounts"], config)
|
||||
|
||||
cli_logger.old_info(
|
||||
logger,
|
||||
"get_or_create_head_node: Updating files on head node...")
|
||||
|
||||
# Rewrite the auth config so that the head
|
||||
# node can update the workers
|
||||
remote_config = copy.deepcopy(config)
|
||||
|
||||
# drop proxy options if they exist, otherwise
|
||||
# head node won't be able to connect to workers
|
||||
remote_config["auth"].pop("ssh_proxy_command", None)
|
||||
|
||||
if config["provider"]["type"] != "kubernetes":
|
||||
remote_key_path = "~/ray_bootstrap_key.pem"
|
||||
remote_config["auth"]["ssh_private_key"] = remote_key_path
|
||||
|
||||
# Adjust for new file locations
|
||||
new_mounts = {}
|
||||
for remote_path in config["file_mounts"]:
|
||||
new_mounts[remote_path] = remote_path
|
||||
remote_config["file_mounts"] = new_mounts
|
||||
remote_config["no_restart"] = no_restart
|
||||
|
||||
# Now inject the rewritten config and SSH key into the head node
|
||||
remote_config_file = tempfile.NamedTemporaryFile(
|
||||
"w", prefix="ray-bootstrap-")
|
||||
remote_config_file.write(json.dumps(remote_config))
|
||||
remote_config_file.flush()
|
||||
config["file_mounts"].update({
|
||||
remote_key_path: config["auth"]["ssh_private_key"],
|
||||
"~/ray_bootstrap_config.yaml": remote_config_file.name
|
||||
})
|
||||
|
||||
if restart_only:
|
||||
init_commands = []
|
||||
ray_start_commands = config["head_start_ray_commands"]
|
||||
elif no_restart:
|
||||
init_commands = config["head_setup_commands"]
|
||||
ray_start_commands = []
|
||||
else:
|
||||
init_commands = config["head_setup_commands"]
|
||||
ray_start_commands = config["head_start_ray_commands"]
|
||||
if config["provider"]["type"] != "kubernetes":
|
||||
config["file_mounts"].update({
|
||||
remote_key_path: config["auth"]["ssh_private_key"],
|
||||
})
|
||||
cli_logger.print("Prepared bootstrap config")
|
||||
|
||||
if not no_restart:
|
||||
warn_about_bad_start_command(ray_start_commands)
|
||||
if restart_only:
|
||||
init_commands = []
|
||||
ray_start_commands = config["head_start_ray_commands"]
|
||||
elif no_restart:
|
||||
init_commands = config["head_setup_commands"]
|
||||
ray_start_commands = []
|
||||
else:
|
||||
init_commands = config["head_setup_commands"]
|
||||
ray_start_commands = config["head_start_ray_commands"]
|
||||
|
||||
updater = NodeUpdaterThread(
|
||||
node_id=head_node,
|
||||
provider_config=config["provider"],
|
||||
provider=provider,
|
||||
auth_config=config["auth"],
|
||||
cluster_name=config["cluster_name"],
|
||||
file_mounts=config["file_mounts"],
|
||||
initialization_commands=config["initialization_commands"],
|
||||
setup_commands=init_commands,
|
||||
ray_start_commands=ray_start_commands,
|
||||
runtime_hash=runtime_hash,
|
||||
docker_config=config.get("docker"))
|
||||
updater.start()
|
||||
updater.join()
|
||||
if not no_restart:
|
||||
warn_about_bad_start_command(ray_start_commands)
|
||||
|
||||
# Refresh the node cache so we see the external ip if available
|
||||
provider.non_terminated_nodes(head_node_tags)
|
||||
updater = NodeUpdaterThread(
|
||||
node_id=head_node,
|
||||
provider_config=config["provider"],
|
||||
provider=provider,
|
||||
auth_config=config["auth"],
|
||||
cluster_name=config["cluster_name"],
|
||||
file_mounts=config["file_mounts"],
|
||||
initialization_commands=config["initialization_commands"],
|
||||
setup_commands=init_commands,
|
||||
ray_start_commands=ray_start_commands,
|
||||
runtime_hash=runtime_hash,
|
||||
docker_config=config.get("docker"))
|
||||
updater.start()
|
||||
updater.join()
|
||||
|
||||
if config.get("provider", {}).get("use_internal_ips", False) is True:
|
||||
head_node_ip = provider.internal_ip(head_node)
|
||||
else:
|
||||
head_node_ip = provider.external_ip(head_node)
|
||||
# Refresh the node cache so we see the external ip if available
|
||||
provider.non_terminated_nodes(head_node_tags)
|
||||
|
||||
if updater.exitcode != 0:
|
||||
logger.error("get_or_create_head_node: "
|
||||
"Updating {} failed".format(head_node_ip))
|
||||
sys.exit(1)
|
||||
logger.info(
|
||||
"get_or_create_head_node: "
|
||||
"Head node up-to-date, IP address is: {}".format(head_node_ip))
|
||||
if config.get("provider", {}).get("use_internal_ips",
|
||||
False) is True:
|
||||
head_node_ip = provider.internal_ip(head_node)
|
||||
else:
|
||||
head_node_ip = provider.external_ip(head_node)
|
||||
|
||||
monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*"
|
||||
if override_cluster_name:
|
||||
modifiers = " --cluster-name={}".format(
|
||||
quote(override_cluster_name))
|
||||
else:
|
||||
modifiers = ""
|
||||
print("To monitor auto-scaling activity, you can run:\n\n"
|
||||
" ray exec {} {}{}\n".format(config_file, quote(monitor_str),
|
||||
modifiers))
|
||||
print("To open a console on the cluster:\n\n"
|
||||
" ray attach {}{}\n".format(config_file, modifiers))
|
||||
if updater.exitcode != 0:
|
||||
# todo: this does not follow the mockup and is not good enough
|
||||
cli_logger.abort("Failed to setup head node.")
|
||||
|
||||
print("To get a remote shell to the cluster manually, run:\n\n"
|
||||
" {}\n".format(updater.cmd_runner.remote_shell_command_str()))
|
||||
cli_logger.old_error(
|
||||
logger, "get_or_create_head_node: "
|
||||
"Updating {} failed", head_node_ip)
|
||||
sys.exit(1)
|
||||
logger.info(
|
||||
"get_or_create_head_node: "
|
||||
"Head node up-to-date, IP address is: {}".format(head_node_ip))
|
||||
|
||||
monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*"
|
||||
if override_cluster_name:
|
||||
modifiers = " --cluster-name={}".format(
|
||||
quote(override_cluster_name))
|
||||
else:
|
||||
modifiers = ""
|
||||
print("To monitor auto-scaling activity, you can run:\n\n"
|
||||
" ray exec {} {}{}\n".format(config_file,
|
||||
quote(monitor_str), modifiers))
|
||||
print("To open a console on the cluster:\n\n"
|
||||
" ray attach {}{}\n".format(config_file, modifiers))
|
||||
|
||||
print("To get a remote shell to the cluster manually, run:\n\n"
|
||||
" {}\n".format(
|
||||
updater.cmd_runner.remote_shell_command_str()))
|
||||
finally:
|
||||
provider.cleanup()
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import datetime
|
||||
import logging
|
||||
|
||||
from ray.autoscaler.cli_logger import cli_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -13,6 +15,9 @@ class LogTimer:
|
|||
self._start_time = datetime.datetime.utcnow()
|
||||
|
||||
def __exit__(self, *error_vals):
|
||||
if not cli_logger.old_style:
|
||||
return
|
||||
|
||||
td = datetime.datetime.utcnow() - self._start_time
|
||||
status = ""
|
||||
if self._show_status:
|
||||
|
|
|
@ -77,6 +77,15 @@ NODE_PROVIDERS = {
|
|||
"docker": None,
|
||||
"external": import_external # Import an external module
|
||||
}
|
||||
PROVIDER_PRETTY_NAMES = {
|
||||
"local": "Local",
|
||||
"aws": "AWS",
|
||||
"gcp": "GCP",
|
||||
"azure": "Azure",
|
||||
"kubernetes": "Kubernetes",
|
||||
# "docker": "Docker", # not supported
|
||||
"external": "External"
|
||||
}
|
||||
|
||||
DEFAULT_CONFIGS = {
|
||||
"local": load_local_example_config,
|
||||
|
@ -88,6 +97,26 @@ DEFAULT_CONFIGS = {
|
|||
}
|
||||
|
||||
|
||||
def try_logging_config(config):
|
||||
if config["provider"]["type"] == "aws":
|
||||
from ray.autoscaler.aws.config import log_to_cli
|
||||
log_to_cli(config)
|
||||
|
||||
|
||||
def try_get_log_state(provider_config):
|
||||
if provider_config["type"] == "aws":
|
||||
from ray.autoscaler.aws.config import get_log_state
|
||||
return get_log_state()
|
||||
|
||||
|
||||
def try_reload_log_state(provider_config, log_state):
|
||||
if not log_state:
|
||||
return
|
||||
if provider_config["type"] == "aws":
|
||||
from ray.autoscaler.aws.config import reload_log_state
|
||||
return reload_log_state(log_state)
|
||||
|
||||
|
||||
def load_class(path):
|
||||
"""
|
||||
Load a class at runtime given a full path.
|
||||
|
|
|
@ -12,6 +12,9 @@ from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG, \
|
|||
from ray.autoscaler.command_runner import NODE_START_WAIT_S, SSHOptions
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
|
||||
from ray.autoscaler.cli_logger import cli_logger
|
||||
import colorful as cf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
READY_CHECK_INTERVAL = 5
|
||||
|
@ -57,8 +60,9 @@ class NodeUpdater:
|
|||
self.auth_config = auth_config
|
||||
|
||||
def run(self):
|
||||
logger.info(self.log_prefix +
|
||||
"Updating to {}".format(self.runtime_hash))
|
||||
cli_logger.old_info(logger, "{}Updating to {}", self.log_prefix,
|
||||
self.runtime_hash)
|
||||
|
||||
try:
|
||||
with LogTimer(self.log_prefix +
|
||||
"Applied config {}".format(self.runtime_hash)):
|
||||
|
@ -68,11 +72,28 @@ class NodeUpdater:
|
|||
if hasattr(e, "cmd"):
|
||||
error_str = "(Exit Status {}) {}".format(
|
||||
e.returncode, " ".join(e.cmd))
|
||||
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
|
||||
logger.error(self.log_prefix +
|
||||
"Error executing: {}".format(error_str) + "\n")
|
||||
cli_logger.error("New status: {}", cf.bold(STATUS_UPDATE_FAILED))
|
||||
|
||||
cli_logger.old_error(logger, "{}Error executing: {}\n",
|
||||
self.log_prefix, error_str)
|
||||
|
||||
cli_logger.error("!!!")
|
||||
if hasattr(e, "cmd"):
|
||||
cli_logger.error(
|
||||
"Setup command `{}` failed with exit code {}. stderr:",
|
||||
cf.bold(e.cmd), e.returncode)
|
||||
else:
|
||||
cli_logger.verbose_error(vars(e), _no_format=True)
|
||||
cli_logger.error(str(e)) # todo: handle this better somehow?
|
||||
# todo: print stderr here
|
||||
cli_logger.error("!!!")
|
||||
cli_logger.newline()
|
||||
|
||||
if isinstance(e, click.ClickException):
|
||||
# todo: why do we ignore this here
|
||||
return
|
||||
raise
|
||||
|
||||
|
@ -81,98 +102,173 @@ class NodeUpdater:
|
|||
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
|
||||
TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
|
||||
})
|
||||
cli_logger.labeled_value("New status", STATUS_UP_TO_DATE)
|
||||
|
||||
self.exitcode = 0
|
||||
|
||||
def sync_file_mounts(self, sync_cmd):
|
||||
# Rsync file mounts
|
||||
for remote_path, local_path in self.file_mounts.items():
|
||||
assert os.path.exists(local_path), local_path
|
||||
if os.path.isdir(local_path):
|
||||
if not local_path.endswith("/"):
|
||||
local_path += "/"
|
||||
if not remote_path.endswith("/"):
|
||||
remote_path += "/"
|
||||
nolog_paths = []
|
||||
if cli_logger.verbosity == 0:
|
||||
nolog_paths = [
|
||||
"~/ray_bootstrap_key.pem", "~/ray_bootstrap_config.yaml"
|
||||
]
|
||||
|
||||
with LogTimer(self.log_prefix +
|
||||
"Synced {} to {}".format(local_path, remote_path)):
|
||||
self.cmd_runner.run("mkdir -p {}".format(
|
||||
os.path.dirname(remote_path)))
|
||||
sync_cmd(local_path, remote_path)
|
||||
# Rsync file mounts
|
||||
with cli_logger.group(
|
||||
"Processing file mounts", _numbered=("[]", 2, 5)):
|
||||
for remote_path, local_path in self.file_mounts.items():
|
||||
assert os.path.exists(local_path), local_path
|
||||
if os.path.isdir(local_path):
|
||||
if not local_path.endswith("/"):
|
||||
local_path += "/"
|
||||
if not remote_path.endswith("/"):
|
||||
remote_path += "/"
|
||||
|
||||
with LogTimer(self.log_prefix + "Synced {} to {}".format(
|
||||
local_path, remote_path)):
|
||||
self.cmd_runner.run("mkdir -p {}".format(
|
||||
os.path.dirname(remote_path)))
|
||||
sync_cmd(local_path, remote_path)
|
||||
|
||||
if remote_path not in nolog_paths:
|
||||
# todo: timed here?
|
||||
cli_logger.print("{} from {}", cf.bold(remote_path),
|
||||
cf.bold(local_path))
|
||||
|
||||
def wait_ready(self, deadline):
|
||||
with LogTimer(self.log_prefix + "Got remote shell"):
|
||||
logger.info(self.log_prefix + "Waiting for remote shell...")
|
||||
with cli_logger.group(
|
||||
"Waiting for SSH to become available", _numbered=("[]", 1, 5)):
|
||||
with LogTimer(self.log_prefix + "Got remote shell"):
|
||||
cli_logger.old_info(logger, "{}Waiting for remote shell...",
|
||||
self.log_prefix)
|
||||
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
try:
|
||||
logger.debug(self.log_prefix +
|
||||
"Waiting for remote shell...")
|
||||
cli_logger.print("Running `{}` as a test.", cf.bold("uptime"))
|
||||
while time.time() < deadline and \
|
||||
not self.provider.is_terminated(self.node_id):
|
||||
try:
|
||||
cli_logger.old_debug(logger,
|
||||
"{}Waiting for remote shell...",
|
||||
self.log_prefix)
|
||||
|
||||
self.cmd_runner.run("uptime", timeout=5)
|
||||
logger.debug("Uptime succeeded.")
|
||||
return True
|
||||
self.cmd_runner.run("uptime")
|
||||
cli_logger.old_debug(logger, "Uptime succeeded.")
|
||||
cli_logger.success("Success.")
|
||||
return True
|
||||
except Exception as e:
|
||||
retry_str = str(e)
|
||||
if hasattr(e, "cmd"):
|
||||
retry_str = "(Exit Status {}): {}".format(
|
||||
e.returncode, " ".join(e.cmd))
|
||||
|
||||
except Exception as e:
|
||||
retry_str = str(e)
|
||||
if hasattr(e, "cmd"):
|
||||
retry_str = "(Exit Status {}): {}".format(
|
||||
e.returncode, " ".join(e.cmd))
|
||||
logger.debug(self.log_prefix +
|
||||
"Node not up, retrying: {}".format(retry_str))
|
||||
time.sleep(READY_CHECK_INTERVAL)
|
||||
cli_logger.print(
|
||||
"SSH still not available {}, "
|
||||
"retrying in {} seconds.", cf.gray(retry_str),
|
||||
cf.bold(str(READY_CHECK_INTERVAL)))
|
||||
cli_logger.old_debug(logger,
|
||||
"{}Node not up, retrying: {}",
|
||||
self.log_prefix, retry_str)
|
||||
|
||||
time.sleep(READY_CHECK_INTERVAL)
|
||||
|
||||
assert False, "Unable to connect to node"
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
|
||||
cli_logger.labeled_value("New status", STATUS_WAITING_FOR_SSH)
|
||||
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
self.wait_ready(deadline)
|
||||
|
||||
node_tags = self.provider.node_tags(self.node_id)
|
||||
logger.debug("Node tags: {}".format(str(node_tags)))
|
||||
if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash:
|
||||
logger.info(self.log_prefix +
|
||||
"{} already up-to-date, skip to ray start".format(
|
||||
self.node_id))
|
||||
# todo: we lie in the confirmation message since
|
||||
# full setup might be cancelled here
|
||||
cli_logger.print(
|
||||
"Configuration already up to date, "
|
||||
"skipping file mounts, initalization and setup commands.")
|
||||
cli_logger.old_info(logger,
|
||||
"{}{} already up-to-date, skip to ray start",
|
||||
self.log_prefix, self.node_id)
|
||||
else:
|
||||
cli_logger.print(
|
||||
"Updating cluster configuration.",
|
||||
_tags=dict(hash=self.runtime_hash))
|
||||
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES})
|
||||
cli_logger.labeled_value("New status", STATUS_SYNCING_FILES)
|
||||
self.sync_file_mounts(self.rsync_up)
|
||||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP})
|
||||
with LogTimer(
|
||||
self.log_prefix + "Initialization commands",
|
||||
show_status=True):
|
||||
for cmd in self.initialization_commands:
|
||||
self.cmd_runner.run(
|
||||
cmd,
|
||||
ssh_options_override=SSHOptions(
|
||||
self.auth_config.get("ssh_private_key")))
|
||||
cli_logger.labeled_value("New status", STATUS_SETTING_UP)
|
||||
|
||||
if self.initialization_commands:
|
||||
with cli_logger.group(
|
||||
"Running initialization commands",
|
||||
_numbered=("[]", 3, 5)): # todo: fix command numbering
|
||||
with LogTimer(
|
||||
self.log_prefix + "Initialization commands",
|
||||
show_status=True):
|
||||
|
||||
for cmd in self.initialization_commands:
|
||||
self.cmd_runner.run(
|
||||
cmd,
|
||||
ssh_options_override=SSHOptions(
|
||||
self.auth_config.get("ssh_private_key")))
|
||||
else:
|
||||
cli_logger.print(
|
||||
"No initialization commands to run.",
|
||||
_numbered=("[]", 3, 5))
|
||||
|
||||
if self.setup_commands:
|
||||
with cli_logger.group(
|
||||
"Running setup commands",
|
||||
_numbered=("[]", 4, 5)): # todo: fix command numbering
|
||||
with LogTimer(
|
||||
self.log_prefix + "Setup commands",
|
||||
show_status=True):
|
||||
|
||||
total = len(self.setup_commands)
|
||||
for i, cmd in enumerate(self.setup_commands):
|
||||
if cli_logger.verbosity == 0:
|
||||
cmd_to_print = cf.bold(cmd[:30]) + "..."
|
||||
else:
|
||||
cmd_to_print = cf.bold(cmd)
|
||||
|
||||
cli_logger.print(
|
||||
cmd_to_print, _numbered=("()", i, total))
|
||||
|
||||
self.cmd_runner.run(cmd)
|
||||
else:
|
||||
cli_logger.print(
|
||||
"No setup commands to run.", _numbered=("[]", 4, 5))
|
||||
|
||||
with cli_logger.group(
|
||||
"Starting the Ray runtime", _numbered=("[]", 5, 5)):
|
||||
with LogTimer(
|
||||
self.log_prefix + "Setup commands", show_status=True):
|
||||
for cmd in self.setup_commands:
|
||||
self.log_prefix + "Ray start commands", show_status=True):
|
||||
for cmd in self.ray_start_commands:
|
||||
self.cmd_runner.run(cmd)
|
||||
|
||||
with LogTimer(
|
||||
self.log_prefix + "Ray start commands", show_status=True):
|
||||
for cmd in self.ray_start_commands:
|
||||
self.cmd_runner.run(cmd)
|
||||
|
||||
def rsync_up(self, source, target):
|
||||
logger.info(self.log_prefix +
|
||||
"Syncing {} to {}...".format(source, target))
|
||||
cli_logger.old_info(logger, "{}Syncing {} to {}...", self.log_prefix,
|
||||
source, target)
|
||||
|
||||
self.cmd_runner.run_rsync_up(source, target)
|
||||
cli_logger.verbose("`rsync`ed {} (local) to {} (remote)",
|
||||
cf.bold(source), cf.bold(target))
|
||||
|
||||
def rsync_down(self, source, target):
|
||||
logger.info(self.log_prefix +
|
||||
"Syncing {} from {}...".format(source, target))
|
||||
cli_logger.old_info(logger, "{}Syncing {} from {}...", self.log_prefix,
|
||||
source, target)
|
||||
|
||||
self.cmd_runner.run_rsync_down(source, target)
|
||||
cli_logger.verbose("`rsync`ed {} (remote) to {} (local)",
|
||||
cf.bold(source), cf.bold(target))
|
||||
|
||||
|
||||
class NodeUpdaterThread(NodeUpdater, Thread):
|
||||
|
|
|
@ -646,6 +646,16 @@ def stop(force, verbose):
|
|||
|
||||
@cli.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.option(
|
||||
"--min-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help="Override the configured min worker node count for the cluster.")
|
||||
@click.option(
|
||||
"--max-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help="Override the configured max worker node count for the cluster.")
|
||||
@click.option(
|
||||
"--no-restart",
|
||||
is_flag=True,
|
||||
|
@ -659,20 +669,11 @@ def stop(force, verbose):
|
|||
help=("Whether to skip running setup commands and only restart Ray. "
|
||||
"This cannot be used with 'no-restart'."))
|
||||
@click.option(
|
||||
"--no-config-cache",
|
||||
"--yes",
|
||||
"-y",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Disable the local cluster config cache.")
|
||||
@click.option(
|
||||
"--min-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help="Override the configured min worker node count for the cluster.")
|
||||
@click.option(
|
||||
"--max-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help="Override the configured max worker node count for the cluster.")
|
||||
help="Don't ask for confirmation.")
|
||||
@click.option(
|
||||
"--cluster-name",
|
||||
"-n",
|
||||
|
@ -680,13 +681,25 @@ def stop(force, verbose):
|
|||
type=str,
|
||||
help="Override the configured cluster name.")
|
||||
@click.option(
|
||||
"--yes",
|
||||
"-y",
|
||||
"--no-config-cache",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Don't ask for confirmation.")
|
||||
help="Disable the local cluster config cache.")
|
||||
@click.option(
|
||||
"--log-old-style/--log-new-style",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help=("Use old logging."))
|
||||
@click.option(
|
||||
"--log-color",
|
||||
required=False,
|
||||
type=str,
|
||||
default="auto",
|
||||
help=("Use color logging. "
|
||||
"Valid values are: auto (if stdout is a tty), true, false."))
|
||||
@click.option("-v", "--verbose", count=True)
|
||||
def up(cluster_config_file, min_workers, max_workers, no_restart, restart_only,
|
||||
yes, cluster_name, no_config_cache):
|
||||
yes, cluster_name, no_config_cache, log_old_style, log_color, verbose):
|
||||
"""Create or update a Ray cluster."""
|
||||
if restart_only or no_restart:
|
||||
assert restart_only != no_restart, "Cannot set both 'restart_only' " \
|
||||
|
@ -703,38 +716,52 @@ def up(cluster_config_file, min_workers, max_workers, no_restart, restart_only,
|
|||
logger.info("Error downloading file: ", e)
|
||||
create_or_update_cluster(cluster_config_file, min_workers, max_workers,
|
||||
no_restart, restart_only, yes, cluster_name,
|
||||
no_config_cache)
|
||||
no_config_cache, log_old_style, log_color,
|
||||
verbose)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.option(
|
||||
"--workers-only",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Only destroy the workers.")
|
||||
@click.option(
|
||||
"--keep-min-workers",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Retain the minimal amount of workers specified in the config.")
|
||||
@click.option(
|
||||
"--yes",
|
||||
"-y",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Don't ask for confirmation.")
|
||||
@click.option(
|
||||
"--workers-only",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Only destroy the workers.")
|
||||
@click.option(
|
||||
"--cluster-name",
|
||||
"-n",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Override the configured cluster name.")
|
||||
@click.option(
|
||||
"--keep-min-workers",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Retain the minimal amount of workers specified in the config.")
|
||||
@click.option(
|
||||
"--log-old-style/--log-new-style",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help=("Use old logging."))
|
||||
@click.option(
|
||||
"--log-color",
|
||||
required=False,
|
||||
type=str,
|
||||
default="auto",
|
||||
help=("Use color logging. "
|
||||
"Valid values are: auto (if stdout is a tty), true, false."))
|
||||
@click.option("-v", "--verbose", count=True)
|
||||
def down(cluster_config_file, yes, workers_only, cluster_name,
|
||||
keep_min_workers):
|
||||
keep_min_workers, log_old_style, log_color, verbose):
|
||||
"""Tear down a Ray cluster."""
|
||||
teardown_cluster(cluster_config_file, yes, workers_only, cluster_name,
|
||||
keep_min_workers)
|
||||
keep_min_workers, log_old_style, log_color, verbose)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
|
|
@ -303,6 +303,7 @@ install_requires = [
|
|||
"aiohttp",
|
||||
"click >= 7.0",
|
||||
"colorama",
|
||||
"colorful",
|
||||
"filelock",
|
||||
"google",
|
||||
"gpustat",
|
||||
|
|
Loading…
Add table
Reference in a new issue