[Autoscaler] Command Line Interface improvements (#9322)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Maksim Smolin 2020-07-22 12:21:44 -07:00 committed by GitHub
parent 456e012029
commit 908c0c630a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 1475 additions and 332 deletions

View file

@ -13,7 +13,11 @@ import botocore
from ray.ray_constants import BOTO_MAX_RETRIES
from ray.autoscaler.tags import NODE_TYPE_WORKER, NODE_TYPE_HEAD
from ray.autoscaler.aws.utils import LazyDefaultDict
from ray.autoscaler.aws.utils import LazyDefaultDict, handle_boto_error
from ray.autoscaler.node_provider import PROVIDER_PRETTY_NAMES
from ray.autoscaler.cli_logger import cli_logger
import colorful as cf
logger = logging.getLogger(__name__)
@ -45,7 +49,8 @@ DEFAULT_AMI = {
"sa-east-1": "ami-0da2c49fe75e7e5ed", # SA (Sao Paulo)
}
# todo: cli_logger should handle this assert properly
# this should probably also happens somewhere else
assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
"Boto3 version >= 1.4.8 required, try `pip install -U boto3`"
@ -70,6 +75,112 @@ def key_pair(i, region, key_name):
# Suppress excessive connection dropped logs from boto
logging.getLogger("botocore").setLevel(logging.WARNING)
_log_info = {}
def reload_log_state(override_log_info):
    """Restore config-logging state previously captured by get_log_state()."""
    for key, value in override_log_info.items():
        _log_info[key] = value
def get_log_state():
    """Return a shallow snapshot of the recorded config-logging state."""
    return dict(_log_info)
def _set_config_info(**kwargs):
    """Record configuration artifacts useful for logging."""
    # todo: this is technically fragile iff we ever use multiple configs
    _log_info.update(kwargs)
def _arn_to_name(arn):
return arn.split(":")[-1].split("/")[-1]
def log_to_cli(config):
    """Pretty-print the resolved AWS node configuration via cli_logger.

    Relies on the value sources recorded by _set_config_info() (whether each
    setting came from the user config or was filled in by a default).

    Args:
        config: The bootstrapped autoscaler config dict, with "head_node"
            and "worker_nodes" sections already populated.
    """
    provider_name = PROVIDER_PRETTY_NAMES.get("aws", None)

    cli_logger.doassert(provider_name is not None,
                        "Could not find a pretty name for the AWS provider.")

    with cli_logger.group("{} config", provider_name):

        def same_everywhere(key):
            # True when the head node and the workers share the same value.
            return config["head_node"][key] == config["worker_nodes"][key]

        def print_info(resource_string,
                       key,
                       head_src_key,
                       workers_src_key,
                       allowed_tags=("default", ),
                       list_value=False):
            """Print one labeled config value, tagged with its source.

            Note: ``allowed_tags`` defaults to an immutable tuple rather
            than a list to avoid the shared-mutable-default pitfall.
            """
            head_tags = {}
            workers_tags = {}

            if _log_info[head_src_key] in allowed_tags:
                head_tags[_log_info[head_src_key]] = True
            if _log_info[workers_src_key] in allowed_tags:
                workers_tags[_log_info[workers_src_key]] = True

            head_value_str = config["head_node"][key]
            if list_value:
                head_value_str = cli_logger.render_list(head_value_str)

            if same_everywhere(key):
                cli_logger.labeled_value(  # todo: handle plural vs singular?
                    resource_string + " (head & workers)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
            else:
                workers_value_str = config["worker_nodes"][key]
                if list_value:
                    workers_value_str = cli_logger.render_list(
                        workers_value_str)

                cli_logger.labeled_value(
                    resource_string + " (head)",
                    "{}",
                    head_value_str,
                    _tags=head_tags)
                cli_logger.labeled_value(
                    resource_string + " (workers)",
                    "{}",
                    workers_value_str,
                    _tags=workers_tags)

        tags = {"default": _log_info["head_instance_profile_src"] == "default"}
        # The IAM profile only applies to the head node, so it is printed
        # directly instead of via print_info.
        cli_logger.labeled_value(
            "IAM Profile",
            "{}",
            _arn_to_name(config["head_node"]["IamInstanceProfile"]["Arn"]),
            _tags=tags)

        print_info("EC2 Key pair", "KeyName", "keypair_src", "keypair_src")
        print_info(
            "VPC Subnets",
            "SubnetIds",
            "head_subnet_src",
            "workers_subnet_src",
            list_value=True)
        print_info(
            "EC2 Security groups",
            "SecurityGroupIds",
            "head_security_group_src",
            "workers_security_group_src",
            list_value=True)
        print_info(
            "EC2 AMI",
            "ImageId",
            "head_ami_src",
            "workers_ami_src",
            allowed_tags=("dlami", ))

        cli_logger.newline()
def bootstrap_aws(config):
# The head node needs to have an IAM role that allows it to create further
@ -94,27 +205,38 @@ def bootstrap_aws(config):
def _configure_iam_role(config):
if "IamInstanceProfile" in config["head_node"]:
_set_config_info(head_instance_profile_src="config")
return config
_set_config_info(head_instance_profile_src="default")
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
if profile is None:
logger.info("_configure_iam_role: "
"Creating new instance profile {}".format(
DEFAULT_RAY_INSTANCE_PROFILE))
cli_logger.verbose(
"Creating new IAM instance profile {} for use as the default.",
cf.bold(DEFAULT_RAY_INSTANCE_PROFILE))
cli_logger.old_info(
logger, "_configure_iam_role: "
"Creating new instance profile {}", DEFAULT_RAY_INSTANCE_PROFILE)
client = _client("iam", config)
client.create_instance_profile(
InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)
time.sleep(15) # wait for propagation
cli_logger.doassert(profile is not None,
"Failed to create instance profile.") # todo: err msg
assert profile is not None, "Failed to create instance profile"
if not profile.roles:
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
if role is None:
logger.info("_configure_iam_role: "
"Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
cli_logger.verbose(
"Creating new IAM role {} for "
"use as the default instance role.",
cf.bold(DEFAULT_RAY_IAM_ROLE))
cli_logger.old_info(logger, "_configure_iam_role: "
"Creating new role {}", DEFAULT_RAY_IAM_ROLE)
iam = _resource("iam", config)
iam.create_role(
RoleName=DEFAULT_RAY_IAM_ROLE,
@ -130,6 +252,9 @@ def _configure_iam_role(config):
],
}))
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
cli_logger.doassert(role is not None,
"Failed to create role.") # todo: err msg
assert role is not None, "Failed to create role"
role.attach_policy(
PolicyArn="arn:aws:iam::aws:policy/AmazonEC2FullAccess")
@ -138,9 +263,9 @@ def _configure_iam_role(config):
profile.add_role(RoleName=role.name)
time.sleep(15) # wait for propagation
logger.info("_configure_iam_role: "
"Role not specified for head node, using {}".format(
profile.arn))
cli_logger.old_info(
logger, "_configure_iam_role: "
"Role not specified for head node, using {}", profile.arn)
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
return config
@ -148,9 +273,19 @@ def _configure_iam_role(config):
def _configure_key_pair(config):
if "ssh_private_key" in config["auth"]:
_set_config_info(keypair_src="config")
cli_logger.doassert( # todo: verify schema beforehand?
"KeyName" in config["head_node"],
"`KeyName` missing for head node.") # todo: err msg
cli_logger.doassert(
"KeyName" in config["worker_nodes"],
"`KeyName` missing for worker nodes.") # todo: err msg
assert "KeyName" in config["head_node"]
assert "KeyName" in config["worker_nodes"]
return config
_set_config_info(keypair_src="default")
ec2 = _resource("ec2", config)
@ -170,8 +305,12 @@ def _configure_key_pair(config):
# We can safely create a new key.
if not key and not os.path.exists(key_path):
logger.info("_configure_key_pair: "
"Creating new key pair {}".format(key_name))
cli_logger.verbose(
"Creating new key pair {} for use as the default.",
cf.bold(key_name))
cli_logger.old_info(
logger, "_configure_key_pair: "
"Creating new key pair {}", key_name)
key = ec2.create_key_pair(KeyName=key_name)
# We need to make sure to _create_ the file with the right
@ -182,16 +321,25 @@ def _configure_key_pair(config):
break
if not key:
cli_logger.abort(
"No matching local key file for any of the key pairs in this "
"account with ids from 0..{}. "
"Consider deleting some unused keys pairs from your account.",
key_name) # todo: err msg
raise ValueError(
"No matching local key file for any of the key pairs in this "
"account with ids from 0..{}. ".format(key_name) +
"Consider deleting some unused keys pairs from your account.")
cli_logger.doassert(
os.path.exists(key_path), "Private key file " + cf.bold("{}") +
" not found for " + cf.bold("{}"), key_path, key_name) # todo: err msg
assert os.path.exists(key_path), \
"Private key file {} not found for {}".format(key_path, key_name)
logger.info("_configure_key_pair: "
"KeyName not specified for nodes, using {}".format(key_name))
cli_logger.old_info(
logger, "_configure_key_pair: "
"KeyName not specified for nodes, using {}", key_name)
config["auth"]["ssh_private_key"] = key_path
config["head_node"]["KeyName"] = key_name
@ -203,12 +351,25 @@ def _configure_key_pair(config):
def _configure_subnet(config):
ec2 = _resource("ec2", config)
use_internal_ips = config["provider"].get("use_internal_ips", False)
subnets = sorted(
(s for s in ec2.subnets.all() if s.state == "available" and (
use_internal_ips or s.map_public_ip_on_launch)),
reverse=True, # sort from Z-A
key=lambda subnet: subnet.availability_zone)
try:
subnets = sorted(
(s for s in ec2.subnets.all() if s.state == "available" and (
use_internal_ips or s.map_public_ip_on_launch)),
reverse=True, # sort from Z-A
key=lambda subnet: subnet.availability_zone)
except botocore.exceptions.ClientError as exc:
handle_boto_error(exc, "Failed to fetch available subnets from AWS.")
raise exc
if not subnets:
cli_logger.abort(
"No usable subnets found, try manually creating an instance in "
"your specified region to populate the list of subnets "
"and trying this again.\n"
"Note that the subnet must map public IPs "
"on instance launch unless you set `use_internal_ips: true` in "
"the `provider` config.") # todo: err msg
raise Exception(
"No usable subnets found, try manually creating an instance in "
"your specified region to populate the list of subnets "
@ -219,6 +380,12 @@ def _configure_subnet(config):
azs = config["provider"]["availability_zone"].split(",")
subnets = [s for s in subnets if s.availability_zone in azs]
if not subnets:
cli_logger.abort(
"No usable subnets matching availability zone {} found.\n"
"Choose a different availability zone or try "
"manually creating an instance in your specified region "
"to populate the list of subnets and trying this again.",
config["provider"]["availability_zone"]) # todo: err msg
raise Exception(
"No usable subnets matching availability zone {} "
"found. Choose a different availability zone or try "
@ -229,21 +396,31 @@ def _configure_subnet(config):
subnet_ids = [s.subnet_id for s in subnets]
subnet_descr = [(s.subnet_id, s.availability_zone) for s in subnets]
if "SubnetIds" not in config["head_node"]:
_set_config_info(head_subnet_src="default")
config["head_node"]["SubnetIds"] = subnet_ids
logger.info("_configure_subnet: "
"SubnetIds not specified for head node, using {}".format(
subnet_descr))
cli_logger.old_info(
logger, "_configure_subnet: "
"SubnetIds not specified for head node, using {}", subnet_descr)
else:
_set_config_info(head_subnet_src="config")
if "SubnetIds" not in config["worker_nodes"]:
_set_config_info(workers_subnet_src="default")
config["worker_nodes"]["SubnetIds"] = subnet_ids
logger.info("_configure_subnet: "
"SubnetId not specified for workers,"
" using {}".format(subnet_descr))
cli_logger.old_info(
logger, "_configure_subnet: "
"SubnetId not specified for workers,"
" using {}", subnet_descr)
else:
_set_config_info(workers_subnet_src="config")
return config
def _configure_security_group(config):
_set_config_info(
head_security_group_src="config", workers_security_group_src="config")
node_types_to_configure = [
node_type for node_type, config_key in NODE_TYPE_CONFIG_KEYS.items()
if "SecurityGroupIds" not in config[NODE_TYPE_CONFIG_KEYS[node_type]]
@ -255,17 +432,22 @@ def _configure_security_group(config):
if NODE_TYPE_HEAD in node_types_to_configure:
head_sg = security_groups[NODE_TYPE_HEAD]
logger.info(
"_configure_security_group: "
"SecurityGroupIds not specified for head node, using {} ({})"
.format(head_sg.group_name, head_sg.id))
_set_config_info(head_security_group_src="default")
cli_logger.old_info(
logger, "_configure_security_group: "
"SecurityGroupIds not specified for head node, using {} ({})",
head_sg.group_name, head_sg.id)
config["head_node"]["SecurityGroupIds"] = [head_sg.id]
if NODE_TYPE_WORKER in node_types_to_configure:
workers_sg = security_groups[NODE_TYPE_WORKER]
logger.info("_configure_security_group: "
"SecurityGroupIds not specified for workers, using {} ({})"
.format(workers_sg.group_name, workers_sg.id))
_set_config_info(workers_security_group_src="default")
cli_logger.old_info(
logger, "_configure_security_group: "
"SecurityGroupIds not specified for workers, using {} ({})",
workers_sg.group_name, workers_sg.id)
config["worker_nodes"]["SecurityGroupIds"] = [workers_sg.id]
return config
@ -274,6 +456,8 @@ def _configure_security_group(config):
def _check_ami(config):
"""Provide helpful message for missing ImageId for node configuration."""
_set_config_info(head_ami_src="config", workers_ami_src="config")
region = config["provider"]["region"]
default_ami = DEFAULT_AMI.get(region)
if not default_ami:
@ -282,21 +466,27 @@ def _check_ami(config):
if config["head_node"].get("ImageId", "").lower() == "latest_dlami":
config["head_node"]["ImageId"] = default_ami
logger.info("_check_ami: head node ImageId is 'latest_dlami'. "
"Using '{ami_id}', which is the default {ami_name} "
"for your region ({region}).".format(
ami_id=default_ami,
ami_name=DEFAULT_AMI_NAME,
region=region))
_set_config_info(head_ami_src="dlami")
cli_logger.old_info(
logger,
"_check_ami: head node ImageId is 'latest_dlami'. "
"Using '{ami_id}', which is the default {ami_name} "
"for your region ({region}).",
ami_id=default_ami,
ami_name=DEFAULT_AMI_NAME,
region=region)
if config["worker_nodes"].get("ImageId", "").lower() == "latest_dlami":
config["worker_nodes"]["ImageId"] = default_ami
logger.info("_check_ami: worker nodes ImageId is 'latest_dlami'. "
"Using '{ami_id}', which is the default {ami_name} "
"for your region ({region}).".format(
ami_id=default_ami,
ami_name=DEFAULT_AMI_NAME,
region=region))
_set_config_info(workers_ami_src="dlami")
cli_logger.old_info(
logger,
"_check_ami: worker nodes ImageId is 'latest_dlami'. "
"Using '{ami_id}', which is the default {ami_name} "
"for your region ({region}).",
ami_id=default_ami,
ami_name=DEFAULT_AMI_NAME,
region=region)
def _upsert_security_groups(config, node_types):
@ -350,6 +540,9 @@ def _get_vpc_id_or_die(ec2, subnet_id):
"Name": "subnet-id",
"Values": [subnet_id]
}]))
# TODO: better error message
cli_logger.doassert(len(subnet) == 1, "Subnet ID not found: {}", subnet_id)
assert len(subnet) == 1, "Subnet ID not found: {}".format(subnet_id)
subnet = subnet[0]
return subnet.vpc_id
@ -383,8 +576,17 @@ def _create_security_group(config, vpc_id, group_name):
GroupName=group_name,
VpcId=vpc_id)
security_group = _get_security_group(config, vpc_id, group_name)
logger.info("_create_security_group: Created new security group {} ({})"
.format(security_group.group_name, security_group.id))
cli_logger.verbose(
"Created new security group {}",
cf.bold(security_group.group_name),
_tags=dict(id=security_group.id))
cli_logger.old_info(
logger, "_create_security_group: Created new security group {} ({})",
security_group.group_name, security_group.id)
cli_logger.doassert(security_group,
"Failed to create security group") # err msg
assert security_group, "Failed to create security group"
return security_group
@ -454,6 +656,9 @@ def _get_role(role_name, config):
if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
return None
else:
handle_boto_error(
exc, "Failed to fetch IAM role data for {} from AWS.",
cf.bold(role_name))
raise exc
@ -467,17 +672,26 @@ def _get_instance_profile(profile_name, config):
if exc.response.get("Error", {}).get("Code") == "NoSuchEntity":
return None
else:
handle_boto_error(
exc,
"Failed to fetch IAM instance profile data for {} from AWS.",
cf.bold(profile_name))
raise exc
def _get_key(key_name, config):
ec2 = _resource("ec2", config)
for key in ec2.key_pairs.filter(Filters=[{
"Name": "key-name",
"Values": [key_name]
}]):
if key.name == key_name:
return key
try:
for key in ec2.key_pairs.filter(Filters=[{
"Name": "key-name",
"Values": [key_name]
}]):
if key.name == key_name:
return key
except botocore.exceptions.ClientError as exc:
handle_boto_error(exc, "Failed to fetch EC2 key pair {} from AWS.",
cf.bold(key_name))
raise exc
def _client(name, config):

View file

@ -15,6 +15,10 @@ from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, \
from ray.ray_constants import BOTO_MAX_RETRIES, BOTO_CREATE_MAX_RETRIES
from ray.autoscaler.log_timer import LogTimer
from ray.autoscaler.aws.utils import boto_exception_handler
from ray.autoscaler.cli_logger import cli_logger
import colorful as cf
logger = logging.getLogger(__name__)
@ -133,7 +137,10 @@ class AWSNodeProvider(NodeProvider):
"Values": [v],
})
nodes = list(self.ec2.instances.filter(Filters=filters))
with boto_exception_handler(
"Failed to fetch running instances from AWS."):
nodes = list(self.ec2.instances.filter(Filters=filters))
# Populate the tag cache with initial information if necessary
for node in nodes:
if node.id in self.tag_cache:
@ -228,19 +235,32 @@ class AWSNodeProvider(NodeProvider):
self.ec2.instances.filter(Filters=filters))[:count]
reuse_node_ids = [n.id for n in reuse_nodes]
if reuse_nodes:
logger.info("AWSNodeProvider: reusing instances {}. "
"To disable reuse, set "
"'cache_stopped_nodes: False' in the provider "
"config.".format(reuse_node_ids))
cli_logger.print(
# todo: handle plural vs singular?
"Reusing nodes {}. "
"To disable reuse, set `cache_stopped_nodes: False` "
"under `provider` in the cluster configuration.",
cli_logger.render_list(reuse_node_ids))
cli_logger.old_info(
logger, "AWSNodeProvider: reusing instances {}. "
"To disable reuse, set "
"'cache_stopped_nodes: False' in the provider "
"config.", reuse_node_ids)
for node in reuse_nodes:
self.tag_cache[node.id] = from_aws_format(
{x["Key"]: x["Value"]
for x in node.tags})
if node.state["Name"] == "stopping":
logger.info("AWSNodeProvider: waiting for instance "
"{} to fully stop...".format(node.id))
node.wait_until_stopped()
# todo: timed?
with cli_logger.group("Stopping instances to reuse"):
for node in reuse_nodes:
self.tag_cache[node.id] = from_aws_format(
{x["Key"]: x["Value"]
for x in node.tags})
if node.state["Name"] == "stopping":
cli_logger.print("Waiting for instance {} to stop",
node.id)
cli_logger.old_info(
logger,
"AWSNodeProvider: waiting for instance "
"{} to fully stop...", node.id)
node.wait_until_stopped()
self.ec2.meta.client.start_instances(
InstanceIds=reuse_node_ids)
@ -300,8 +320,11 @@ class AWSNodeProvider(NodeProvider):
for attempt in range(1, BOTO_CREATE_MAX_RETRIES + 1):
try:
subnet_id = subnet_ids[self.subnet_idx % len(subnet_ids)]
logger.info("NodeProvider: calling create_instances "
"with {} (count={}).".format(subnet_id, count))
cli_logger.old_info(
logger, "NodeProvider: calling create_instances "
"with {} (count={}).", subnet_id, count)
self.subnet_idx += 1
conf.update({
"MinCount": 1,
@ -310,32 +333,64 @@ class AWSNodeProvider(NodeProvider):
"TagSpecifications": tag_specs
})
created = self.ec2_fail_fast.create_instances(**conf)
for instance in created:
logger.info("NodeProvider: Created instance "
"[id={}, name={}, info={}]".format(
instance.instance_id,
instance.state["Name"],
instance.state_reason["Message"]))
# todo: timed?
# todo: handle plurality?
with cli_logger.group(
"Launching {} nodes",
count,
_tags=dict(subnet_id=subnet_id)):
for instance in created:
cli_logger.print(
"Launched instance {}",
instance.instance_id,
_tags=dict(
state=instance.state["Name"],
info=instance.state_reason["Message"]))
cli_logger.old_info(
logger, "NodeProvider: Created instance "
"[id={}, name={}, info={}]", instance.instance_id,
instance.state["Name"],
instance.state_reason["Message"])
break
except botocore.exceptions.ClientError as exc:
if attempt == BOTO_CREATE_MAX_RETRIES:
logger.error(
"create_instances: Max attempts ({}) exceeded.".format(
BOTO_CREATE_MAX_RETRIES))
# todo: err msg
cli_logger.abort(
"Failed to launch instances. Max attempts exceeded.")
cli_logger.old_error(
logger,
"create_instances: Max attempts ({}) exceeded.",
BOTO_CREATE_MAX_RETRIES)
raise exc
else:
logger.error(exc)
# todo: err msg
cli_logger.abort(exc)
cli_logger.old_error(logger, exc)
def terminate_node(self, node_id):
node = self._get_cached_node(node_id)
if self.cache_stopped_nodes:
if node.spot_instance_request_id:
logger.info(
cli_logger.print(
"Terminating instance {} " +
cf.gray("(cannot stop spot instances, only terminate)"),
node_id) # todo: show node name?
cli_logger.old_info(
logger,
"AWSNodeProvider: terminating node {} (spot nodes cannot "
"be stopped, only terminated)".format(node_id))
"be stopped, only terminated)", node_id)
node.terminate()
else:
logger.info(
cli_logger.print("Stopping instance {} " + cf.gray(
"(to terminate instead, "
"set `cache_stopped_nodes: False` "
"under `provider` in the cluster configuration)"),
node_id) # todo: show node name?
cli_logger.old_info(
logger,
"AWSNodeProvider: stopping node {}. To terminate nodes "
"on stop, set 'cache_stopped_nodes: False' in the "
"provider config.".format(node_id))
@ -360,15 +415,30 @@ class AWSNodeProvider(NodeProvider):
on_demand_ids += [node_id]
if on_demand_ids:
logger.info(
# todo: show node names?
cli_logger.print(
"Stopping instances {} " + cf.gray(
"(to terminate instead, "
"set `cache_stopped_nodes: False` "
"under `provider` in the cluster configuration)"),
cli_logger.render_list(on_demand_ids))
cli_logger.old_info(
logger,
"AWSNodeProvider: stopping nodes {}. To terminate nodes "
"on stop, set 'cache_stopped_nodes: False' in the "
"provider config.".format(on_demand_ids))
"provider config.", on_demand_ids)
self.ec2.meta.client.stop_instances(InstanceIds=on_demand_ids)
if spot_ids:
logger.info(
cli_logger.print(
"Terminating instances {} " +
cf.gray("(cannot stop spot instances, only terminate)"),
cli_logger.render_list(spot_ids))
cli_logger.old_info(
logger,
"AWSNodeProvider: terminating nodes {} (spot nodes cannot "
"be stopped, only terminated)".format(spot_ids))
"be stopped, only terminated)", spot_ids)
self.ec2.meta.client.terminate_instances(InstanceIds=spot_ids)
else:
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)

View file

@ -1,5 +1,8 @@
from collections import defaultdict
from ray.autoscaler.cli_logger import cli_logger
import colorful as cf
class LazyDefaultDict(defaultdict):
"""
@ -21,3 +24,99 @@ class LazyDefaultDict(defaultdict):
"""
self[key] = self.default_factory(key)
return self[key]
def handle_boto_error(exc, msg, *args, **kwargs):
    """Report a boto3/botocore error through cli_logger and abort.

    Detects expired-credential errors and prints session-refresh
    instructions for them; all other errors get a generic report.

    Args:
        exc: The exception raised by a boto3/botocore call.
        msg: cli_logger-style format string describing the failed operation.
        *args, **kwargs: Formatting arguments for ``msg``.

    Raises:
        Aborts via cli_logger.abort() (which raises) unless running in
        old-style logging mode, in which case this is a no-op and the
        caller is expected to re-raise ``exc``.
    """
    if cli_logger.old_style:
        # old-style logging doesn't do anything here
        # so we exit early
        return

    error_code = None
    error_info = None
    # todo: not sure if these exceptions always have response
    if hasattr(exc, "response"):
        error_info = exc.response.get("Error", None)
    if error_info is not None:
        error_code = error_info.get("Code", None)

    # Arguments shared by both the verbose pre-report (for credential
    # errors) and the generic error report below.
    generic_message_args = [
        "{}\n"
        "Error code: {}",
        msg.format(*args, **kwargs),
        cf.bold(error_code)
    ]

    # apparently
    # ExpiredTokenException
    # ExpiredToken
    # RequestExpired
    # are all the same pretty much
    credentials_expiration_codes = [
        "ExpiredTokenException", "ExpiredToken", "RequestExpired"
    ]

    if error_code in credentials_expiration_codes:
        # "An error occurred (ExpiredToken) when calling the
        # GetInstanceProfile operation: The security token
        # included in the request is expired"

        # "An error occurred (RequestExpired) when calling the
        # DescribeKeyPairs operation: Request has expired."

        # Sample commands shown to the user for refreshing their session;
        # the underlined parts are placeholders they must fill in.
        token_command = (
            "aws sts get-session-token "
            "--serial-number arn:aws:iam::" + cf.underlined("ROOT_ACCOUNT_ID")
            + ":mfa/" + cf.underlined("AWS_USERNAME") + " --token-code " +
            cf.underlined("TWO_FACTOR_AUTH_CODE"))

        secret_key_var = (
            "export AWS_SECRET_ACCESS_KEY = " + cf.underlined("REPLACE_ME") +
            " # found at Credentials.SecretAccessKey")
        session_token_var = (
            "export AWS_SESSION_TOKEN = " + cf.underlined("REPLACE_ME") +
            " # found at Credentials.SessionToken")
        access_key_id_var = (
            "export AWS_ACCESS_KEY_ID = " + cf.underlined("REPLACE_ME") +
            " # found at Credentials.AccessKeyId")

        # fixme: replace with a Github URL that points
        # to our repo
        aws_session_script_url = ("https://gist.github.com/maximsmol/"
                                  "a0284e1d97b25d417bd9ae02e5f450cf")

        cli_logger.verbose_error(*generic_message_args)
        cli_logger.verbose(vars(exc))

        # abort() raises, so execution never reaches the generic report.
        cli_logger.abort(
            "Your AWS session has expired.\n\n"
            "You can request a new one using\n{}\n"
            "then expose it to Ray by setting\n{}\n{}\n{}\n\n"
            "You can find a script that automates this at:\n{}",
            cf.bold(token_command), cf.bold(secret_key_var),
            cf.bold(session_token_var), cf.bold(access_key_id_var),
            cf.underlined(aws_session_script_url))

    # todo: any other errors that we should catch separately?

    cli_logger.error(*generic_message_args)
    cli_logger.newline()
    with cli_logger.verbatim_error_ctx("Boto3 error:"):
        cli_logger.verbose(vars(exc))
        cli_logger.error(exc)
    cli_logger.abort()
def boto_exception_handler(msg, *args, **kwargs):
    """Return a context manager routing botocore ClientErrors to
    handle_boto_error.

    Usage:
        with boto_exception_handler("Failed to fetch instances."):
            ...boto call...

    The exception is never suppressed: ``__exit__`` returns a falsy value, so
    after handle_boto_error runs (a no-op in old-style logging mode, an abort
    otherwise) the original exception propagates.

    Args:
        msg: cli_logger-style format string describing the operation.
        *args, **kwargs: Formatting arguments for ``msg``.
    """
    # todo: implement timer
    class ExceptionHandlerContextManager():
        def __enter__(self):
            pass

        def __exit__(self, exc_type, exc_value, tb):
            if exc_type is None:
                # Happy path: nothing to handle (also avoids importing
                # botocore when no boto call actually failed).
                return
            import botocore
            # issubclass, not `is`: boto3 raises service-specific modeled
            # exceptions that subclass ClientError, and those must be
            # handled too.
            if issubclass(exc_type, botocore.exceptions.ClientError):
                handle_boto_error(exc_value, msg, *args, **kwargs)

    return ExceptionHandlerContextManager()

View file

@ -0,0 +1,299 @@
import sys
import click
import colorama
from colorful.core import ColorfulString
import colorful as cf
colorama.init()
def _strip_codes(msg):
return msg # todo
# we could bold "{}" strings automatically but do we want that?
# todo:
def _format_msg(msg,
                *args,
                _tags=None,
                _numbered=None,
                _no_format=None,
                **kwargs):
    """Render a cli_logger message into a single string.

    Args:
        msg: A str.format-style template (or ColorfulString), or an
            arbitrary object for the fallback path.
        *args, **kwargs: Formatting arguments for ``msg``.
        _tags (dict): Optional annotations rendered as a gray trailing
            " [k=v, flag]" suffix. A value of True renders the bare key;
            False hides the entry; any other value renders "k=v".
        _numbered (tuple): Optional ``(chars, i, n)`` where ``chars`` is a
            2-char bracket pair, rendering a gray "(i/n) " prefix.
        _no_format: If truthy, ``msg`` is used verbatim (no .format call).

    Returns:
        str: The rendered message. For non-string ``msg``, returns ``msg``
        and ``args`` joined with ", " (kwargs are unsupported there).
    """
    if isinstance(msg, (str, ColorfulString)):
        tags_str = ""
        if _tags is not None:
            tags_list = []
            for k, v in _tags.items():
                if v is True:
                    tags_list += [k]
                    continue
                if v is False:
                    continue

                # Coerce to str so non-string tag values (ints, enums, ...)
                # do not raise a TypeError during concatenation.
                tags_list += [k + "=" + str(v)]
            if tags_list:
                tags_str = cf.reset(
                    cf.gray(" [{}]".format(", ".join(tags_list))))

        numbering_str = ""
        if _numbered is not None:
            chars, i, n = _numbered
            numbering_str = cf.gray(
                chars[0] + str(i) + "/" + str(n) + chars[1]) + " "

        if _no_format:
            # todo: throw if given args/kwargs?
            return numbering_str + msg + tags_str
        return numbering_str + msg.format(*args, **kwargs) + tags_str

    if kwargs:
        raise ValueError("We do not support printing kwargs yet.")

    # Fallback: e.g. an exception object passed directly to cli_logger.error.
    res = [msg, *args]
    res = [str(x) for x in res]
    return ", ".join(res)
class _CliLogger():
    """Modern CLI output logger (exposed as the module-level ``cli_logger``).

    In ``old_style`` mode (the default) all new-style output methods are
    no-ops and only the ``old_*`` passthroughs to the standard ``logging``
    module produce output.
    """

    def __init__(self):
        # If True, strip ANSI color codes from output (see detect_colors).
        self.strip = False
        # If True, fall back to legacy `logging`-based output.
        self.old_style = True
        # One of "true", "false", "auto" (see detect_colors).
        self.color_mode = "auto"
        # Current indentation depth used by _print.
        self.indent_level = 0
        # 0 = normal, 1 = verbose, 2 = very verbose.
        self.verbosity = 0

        self.dump_command_output = False

        # Arbitrary shared state for CLI commands.
        self.info = {}

    def detect_colors(self):
        """Resolve ``color_mode`` into the ``strip`` flag.

        Raises:
            ValueError: If ``color_mode`` is not one of "true"/"false"/"auto".
        """
        if self.color_mode == "true":
            self.strip = False
            return
        if self.color_mode == "false":
            self.strip = True
            return
        if self.color_mode == "auto":
            # Fix: strip color codes when stdout is NOT a terminal.
            # The previous check was inverted (stripped exactly when
            # stdout WAS a tty, i.e. when colors are wanted).
            self.strip = not sys.stdout.isatty()
            return

        raise ValueError("Invalid log color setting: " + self.color_mode)

    def newline(self):
        """Print an empty line."""
        self._print("")

    def _print(self, msg, linefeed=True):
        """Low-level output; honors old_style, strip, and indentation."""
        if self.old_style:
            return

        if self.strip:
            msg = _strip_codes(msg)

        rendered_message = " " * self.indent_level + msg

        if not linefeed:
            # Used by confirm() to keep the cursor on the prompt line.
            sys.stdout.write(rendered_message)
            sys.stdout.flush()
            return

        print(rendered_message)

    def indented(self, cls=False):
        """Return a context manager that indents output by one level.

        If ``cls`` is True, returns the context-manager class itself
        instead of an instance (marked fixme in the original).
        """
        cli_logger = self

        class IndentedContextManager():
            def __enter__(self):
                cli_logger.indent_level += 1

            def __exit__(self, type, value, tb):
                cli_logger.indent_level -= 1

        if cls:
            # fixme: this does not work :()
            return IndentedContextManager
        return IndentedContextManager()

    def timed(self, msg, *args, **kwargs):
        # todo: implement timing; currently an alias for group().
        return self.group(msg, *args, **kwargs)

    def group(self, msg, *args, **kwargs):
        """Print a blue section header and return an indenting context."""
        self._print(_format_msg(cf.cornflowerBlue(msg), *args, **kwargs))

        return self.indented()

    def verbatim_error_ctx(self, msg, *args, **kwargs):
        """Context manager bracketing verbatim error output with "!!!"."""
        cli_logger = self

        class VerbatimErorContextManager():
            def __enter__(self):
                cli_logger.error(cf.bold("!!! ") + msg, *args, **kwargs)

            def __exit__(self, type, value, tb):
                cli_logger.error(cf.bold("!!!"))

        return VerbatimErorContextManager()

    def labeled_value(self, key, msg, *args, **kwargs):
        """Print a "key: value" line with the key in cyan, value bold."""
        self._print(
            cf.cyan(key) + ": " + _format_msg(cf.bold(msg), *args, **kwargs))

    def verbose(self, msg, *args, **kwargs):
        """Print only at verbosity >= 1."""
        if self.verbosity > 0:
            self.print(msg, *args, **kwargs)

    def verbose_error(self, msg, *args, **kwargs):
        """Print an error only at verbosity >= 1."""
        if self.verbosity > 0:
            self.error(msg, *args, **kwargs)

    def very_verbose(self, msg, *args, **kwargs):
        """Print only at verbosity >= 2."""
        if self.verbosity > 1:
            self.print(msg, *args, **kwargs)

    def success(self, msg, *args, **kwargs):
        """Print a green success message."""
        self._print(_format_msg(cf.green(msg), *args, **kwargs))

    def warning(self, msg, *args, **kwargs):
        """Print a yellow warning message."""
        self._print(_format_msg(cf.yellow(msg), *args, **kwargs))

    def error(self, msg, *args, **kwargs):
        """Print a red error message."""
        self._print(_format_msg(cf.red(msg), *args, **kwargs))

    def print(self, msg, *args, **kwargs):
        """Print a plain message."""
        self._print(_format_msg(msg, *args, **kwargs))

    def abort(self, msg=None, *args, **kwargs):
        """Print an optional error and exit by raising SilentClickException."""
        if msg is not None:
            self.error(msg, *args, **kwargs)

        raise SilentClickException("Exiting due to cli_logger.abort()")

    def doassert(self, val, msg, *args, **kwargs):
        """New-style assert: abort with ``msg`` if ``val`` is falsy.

        No-op in old_style mode; callers pair this with a plain `assert`
        for that case.
        """
        if self.old_style:
            return

        if not val:
            self.abort(msg, *args, **kwargs)

    def old_debug(self, logger, msg, *args, **kwargs):
        """Forward to logger.debug in old_style mode; otherwise no-op."""
        if self.old_style:
            logger.debug(_format_msg(msg, *args, **kwargs))
            return

    def old_info(self, logger, msg, *args, **kwargs):
        """Forward to logger.info in old_style mode; otherwise no-op."""
        if self.old_style:
            logger.info(_format_msg(msg, *args, **kwargs))
            return

    def old_warning(self, logger, msg, *args, **kwargs):
        """Forward to logger.warning in old_style mode; otherwise no-op."""
        if self.old_style:
            logger.warning(_format_msg(msg, *args, **kwargs))
            return

    def old_error(self, logger, msg, *args, **kwargs):
        """Forward to logger.error in old_style mode; otherwise no-op."""
        if self.old_style:
            logger.error(_format_msg(msg, *args, **kwargs))
            return

    def old_exception(self, logger, msg, *args, **kwargs):
        """Forward to logger.exception in old_style mode; otherwise no-op."""
        if self.old_style:
            logger.exception(_format_msg(msg, *args, **kwargs))
            return

    def render_list(self, xs, separator=cf.reset(", ")):
        """Render ``xs`` with each element bolded, joined by ``separator``."""
        return separator.join([str(cf.bold(x)) for x in xs])

    def confirm(self, yes, msg, *args, _abort=False, _default=False, **kwargs):
        """Interactively ask a y/n question; auto-answer "y" when ``yes``.

        Args:
            yes: If True, answer "y" without prompting (e.g. --yes flag).
            msg: Format string for the question.
            _abort: If True, raise SilentClickException on a "no" answer.
            _default: Answer assumed on empty input or Ctrl-C.

        Returns:
            bool: The user's answer (None in old_style mode, where this
            method is a no-op).
        """
        if self.old_style:
            return

        should_abort = _abort
        default = _default

        if default:
            yn_str = cf.green("Y") + "/" + cf.red("n")
        else:
            yn_str = cf.green("y") + "/" + cf.red("N")

        confirm_str = cf.underlined("Confirm [" + yn_str + "]:") + " "

        rendered_message = _format_msg(msg, *args, **kwargs)
        if rendered_message and rendered_message[-1] != "\n":
            rendered_message += " "

        # Width of the prompt's last line, used to align re-prompts.
        msg_len = len(rendered_message.split("\n")[-1])
        complete_str = rendered_message + confirm_str

        if yes:
            self._print(complete_str + "y " +
                        cf.gray("[automatic, due to --yes]"))
            return True

        self._print(complete_str, linefeed=False)

        res = None
        yes_answers = ["y", "yes", "true", "1"]
        no_answers = ["n", "no", "false", "0"]
        try:
            while True:
                ans = sys.stdin.readline()
                ans = ans.lower()

                if ans == "\n":
                    # Empty input: take the default answer.
                    res = default
                    break

                ans = ans.strip()
                if ans in yes_answers:
                    res = True
                    break
                if ans in no_answers:
                    res = False
                    break

                indent = " " * msg_len
                self.error("{}Invalid answer: {}. "
                           "Expected {} or {}", indent, cf.bold(ans.strip()),
                           self.render_list(yes_answers, "/"),
                           self.render_list(no_answers, "/"))
                self._print(indent + confirm_str, linefeed=False)
        except KeyboardInterrupt:
            self.newline()
            res = default

        if not res and should_abort:
            # todo: make sure we tell the user if they
            # need to do cleanup
            self._print("Exiting...")
            raise SilentClickException(
                "Exiting due to the response to confirm(should_abort=True).")

        return res

    def old_confirm(self, msg, yes):
        """Old-style confirm via click.confirm; no-op in new-style mode."""
        if not self.old_style:
            return

        return None if yes else click.confirm(msg, abort=True)
class SilentClickException(click.ClickException):
    """A ClickException whose ``show()`` prints nothing.

    Some of our tooling relies on catching ClickException in particular.
    However the default prints a message, which is undesirable since we
    expect our code to log errors manually using `cli_logger.error()` to
    allow for colors and other formatting.
    """

    def __init__(self, message):
        super().__init__(message)

    def show(self, file=None):
        # Deliberately silent: the error was already reported via cli_logger.
        pass
cli_logger = _CliLogger()

View file

@ -12,6 +12,9 @@ import time
from ray.autoscaler.docker import check_docker_running_cmd, with_docker_exec
from ray.autoscaler.log_timer import LogTimer
from ray.autoscaler.cli_logger import cli_logger
import colorful as cf
logger = logging.getLogger(__name__)
# How long to wait for a node to start, in seconds
@ -21,6 +24,24 @@ KUBECTL_RSYNC = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh")
class ProcessRunnerError(Exception):
    """Raised when a command executed on a node fails.

    Attributes:
        msg_type: Category of the failure.
        code: Process exit code, if known.
        command: The command that failed, if known.
        message_discovered: Special error text detected in the output, if any.
    """

    def __init__(self,
                 msg,
                 msg_type,
                 code=None,
                 command=None,
                 message_discovered=None):
        detail = "{} (discovered={}): type={}, code={}, command={}".format(
            msg, message_discovered, msg_type, code, command)
        super().__init__(detail)

        self.msg_type = msg_type
        self.code = code
        self.command = command
        self.message_discovered = message_discovered
def _with_interactive(cmd):
force_interactive = ("true && source ~/.bashrc && "
"export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ")
@ -256,14 +277,27 @@ class SSHCommandRunner(CommandRunnerInterface):
else:
return self.provider.external_ip(self.node_id)
def _wait_for_ip(self, deadline):
    """Poll the provider for this node's IP until ``deadline``.

    Args:
        deadline: Absolute time (epoch seconds) after which to give up.

    Returns:
        The IP address string as soon as the provider reports one, or
        None (implicitly) if the deadline passes or the node is
        terminated first.
    """
    while time.time() < deadline and \
            not self.provider.is_terminated(self.node_id):
        logger.info(self.log_prefix + "Waiting for IP...")
        ip = self._get_node_ip()
        if ip is not None:
            return ip
        # Fixed 10s back-off between polls to avoid hammering the
        # provider API.
        time.sleep(10)
def wait_for_ip(self, deadline):
    """Wait until the node has an IP address, logging progress to the CLI.

    Args:
        deadline: Absolute time (epoch seconds) after which to give up.

    Returns:
        The IP address string, or None if the deadline passes or the
        node is terminated before an IP becomes available.
    """
    # if we have IP do not print waiting info
    ip = self._get_node_ip()
    if ip is not None:
        cli_logger.labeled_value("Fetched IP", ip)
        return ip

    # Seconds between polls of the provider.
    interval = 10
    with cli_logger.timed("Waiting for IP"):
        while time.time() < deadline and \
                not self.provider.is_terminated(self.node_id):
            cli_logger.old_info(logger, "{}Waiting for IP...",
                                self.log_prefix)
            ip = self._get_node_ip()
            if ip is not None:
                cli_logger.labeled_value("Received", ip)
                return ip
            cli_logger.print("Not yet available, retrying in {} seconds",
                             cf.bold(str(interval)))
            time.sleep(interval)

    return None
@ -275,7 +309,10 @@ class SSHCommandRunner(CommandRunnerInterface):
# I think that's reasonable.
deadline = time.time() + NODE_START_WAIT_S
with LogTimer(self.log_prefix + "Got IP"):
ip = self._wait_for_ip(deadline)
ip = self.wait_for_ip(deadline)
cli_logger.doassert(ip is not None,
"Could not get node IP.") # todo: msg
assert ip is not None, "Unable to find IP of node"
self.ssh_ip = ip
@ -286,7 +323,8 @@ class SSHCommandRunner(CommandRunnerInterface):
try:
os.makedirs(self.ssh_control_path, mode=0o700, exist_ok=True)
except OSError as e:
logger.warning(e)
cli_logger.warning(e) # todo: msg
cli_logger.old_warning(logger, e)
def run(self,
cmd,
@ -308,59 +346,91 @@ class SSHCommandRunner(CommandRunnerInterface):
ssh = ["ssh", "-tt"]
if port_forward:
if not isinstance(port_forward, list):
port_forward = [port_forward]
for local, remote in port_forward:
logger.info(self.log_prefix + "Forwarding " +
"{} -> localhost:{}".format(local, remote))
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
with cli_logger.group("Forwarding ports"):
if not isinstance(port_forward, list):
port_forward = [port_forward]
for local, remote in port_forward:
cli_logger.verbose(
"Forwarding port {} to port {} on localhost.",
cf.bold(local), cf.bold(remote)) # todo: msg
cli_logger.old_info(logger,
"{}Forwarding {} -> localhost:{}",
self.log_prefix, local, remote)
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
final_cmd = ssh + ssh_options.to_ssh_options_list(timeout=timeout) + [
"{}@{}".format(self.ssh_user, self.ssh_ip)
]
if cmd:
final_cmd += _with_interactive(cmd)
logger.info(self.log_prefix +
"Running {}".format(" ".join(final_cmd)))
cli_logger.old_info(logger, "{}Running {}", self.log_prefix,
" ".join(final_cmd))
else:
# We do this because `-o ControlMaster` causes the `-N` flag to
# still create an interactive shell in some ssh versions.
final_cmd.append(quote("while true; do sleep 86400; done"))
try:
if with_output:
return self.process_runner.check_output(final_cmd)
else:
self.process_runner.check_call(final_cmd)
except subprocess.CalledProcessError:
if exit_on_fail:
# todo: add a flag for this, we might
# wanna log commands with print sometimes
cli_logger.verbose("Running `{}`", cf.bold(cmd))
with cli_logger.indented():
cli_logger.very_verbose("Full command is `{}`",
cf.bold(" ".join(final_cmd)))
def start_process():
try:
if with_output:
return self.process_runner.check_output(final_cmd)
else:
self.process_runner.check_call(final_cmd)
except subprocess.CalledProcessError as e:
quoted_cmd = " ".join(final_cmd[:-1] + [quote(final_cmd[-1])])
raise click.ClickException(
"Command failed: \n\n {}\n".format(quoted_cmd)) from None
else:
raise click.ClickException(
"SSH command Failed. See above for the output from the"
" failure.") from None
if not cli_logger.old_style:
raise ProcessRunnerError(
"Command failed",
"ssh_command_failed",
code=e.returncode,
command=quoted_cmd)
if exit_on_fail:
raise click.ClickException(
"Command failed: \n\n {}\n".format(quoted_cmd)) \
from None
else:
raise click.ClickException(
"SSH command Failed. See above for the output from the"
" failure.") from None
if cli_logger.verbosity > 0:
with cli_logger.indented():
start_process()
else:
start_process()
def run_rsync_up(self, source, target):
self._set_ssh_ip_if_required()
self.process_runner.check_call([
command = [
"rsync", "--rsh",
subprocess.list2cmdline(
["ssh"] + self.ssh_options.to_ssh_options_list(timeout=120)),
"-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
target)
])
]
cli_logger.verbose("Running `{}`", cf.bold(" ".join(command)))
self.process_runner.check_call(command)
def run_rsync_down(self, source, target):
self._set_ssh_ip_if_required()
self.process_runner.check_call([
command = [
"rsync", "--rsh",
subprocess.list2cmdline(
["ssh"] + self.ssh_options.to_ssh_options_list(timeout=120)),
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
source), target
])
]
cli_logger.verbose("Running `{}`", cf.bold(" ".join(command)))
self.process_runner.check_call(command)
def remote_shell_command_str(self):
return "ssh -o IdentitiesOnly=yes -i {} {}@{}\n".format(

View file

@ -21,7 +21,9 @@ import ray.services as services
from ray.autoscaler.util import validate_config, hash_runtime_conf, \
hash_launch_conf, prepare_config, DEBUG_AUTOSCALING_ERROR, \
DEBUG_AUTOSCALING_STATUS
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS
from ray.autoscaler.node_provider import get_node_provider, NODE_PROVIDERS, \
PROVIDER_PRETTY_NAMES, try_get_log_state, try_logging_config, \
try_reload_log_state
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
TAG_RAY_NODE_NAME, NODE_TYPE_WORKER, NODE_TYPE_HEAD
@ -31,6 +33,9 @@ from ray.autoscaler.command_runner import DockerCommandRunner
from ray.autoscaler.log_timer import LogTimer
from ray.worker import global_worker
from ray.autoscaler.cli_logger import cli_logger
import colorful as cf
logger = logging.getLogger(__name__)
redis_client = None
@ -89,20 +94,93 @@ def create_or_update_cluster(
config_file: str, override_min_workers: Optional[int],
override_max_workers: Optional[int], no_restart: bool,
restart_only: bool, yes: bool, override_cluster_name: Optional[str],
no_config_cache: bool) -> None:
no_config_cache: bool, log_old_style: bool, log_color: str,
verbose: int) -> None:
"""Create or updates an autoscaling Ray cluster from a config json."""
config = yaml.safe_load(open(config_file).read())
if override_min_workers is not None:
config["min_workers"] = override_min_workers
if override_max_workers is not None:
config["max_workers"] = override_max_workers
if override_cluster_name is not None:
config["cluster_name"] = override_cluster_name
cli_logger.old_style = log_old_style
cli_logger.color_mode = log_color
cli_logger.verbosity = verbose
# todo: disable by default when the command output handling PR makes it in
cli_logger.dump_command_output = True
cli_logger.detect_colors()
def handle_yaml_error(e):
cli_logger.error(
"Cluster config invalid.\n"
"Failed to load YAML file " + cf.bold("{}"), config_file)
cli_logger.newline()
with cli_logger.verbatim_error_ctx("PyYAML error:"):
cli_logger.error(e)
cli_logger.abort()
try:
config = yaml.safe_load(open(config_file).read())
except FileNotFoundError:
cli_logger.abort(
"Provided cluster configuration file ({}) does not exist.",
cf.bold(config_file))
except yaml.parser.ParserError as e:
handle_yaml_error(e)
except yaml.scanner.ScannerError as e:
handle_yaml_error(e)
# todo: validate file_mounts, ssh keys, etc.
importer = NODE_PROVIDERS.get(config["provider"]["type"])
if not importer:
cli_logger.abort(
"Unknown provider type " + cf.bold("{}") + "\n"
"Available providers are: {}", config["provider"]["type"],
cli_logger.render_list([
k for k in NODE_PROVIDERS.keys()
if NODE_PROVIDERS[k] is not None
]))
raise NotImplementedError("Unsupported provider {}".format(
config["provider"]))
cli_logger.success("Cluster configuration valid.\n")
printed_overrides = False
def handle_cli_override(key, override):
if override is not None:
if key in config:
nonlocal printed_overrides
printed_overrides = True
cli_logger.warning(
"`{}` override provided on the command line.\n"
" Using " + cf.bold("{}") + cf.dimmed(
" [configuration file has " + cf.bold("{}") + "]"),
key, override, config[key])
config[key] = override
handle_cli_override("min_workers", override_min_workers)
handle_cli_override("max_workers", override_max_workers)
handle_cli_override("cluster_name", override_cluster_name)
if printed_overrides:
cli_logger.newline()
cli_logger.labeled_value("Cluster", config["cluster_name"])
# disable the cli_logger here if needed
# because it only supports aws
if config["provider"]["type"] != "aws":
cli_logger.old_style = True
config = _bootstrap_config(config, no_config_cache)
if config["provider"]["type"] != "aws":
cli_logger.old_style = False
try_logging_config(config)
get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
override_cluster_name)
CONFIG_CACHE_VERSION = 1
def _bootstrap_config(config: Dict[str, Any],
no_config_cache: bool = False) -> Dict[str, Any]:
config = prepare_config(config)
@ -111,9 +189,30 @@ def _bootstrap_config(config: Dict[str, Any],
hasher.update(json.dumps([config], sort_keys=True).encode("utf-8"))
cache_key = os.path.join(tempfile.gettempdir(),
"ray-config-{}".format(hasher.hexdigest()))
if os.path.exists(cache_key) and not no_config_cache:
logger.info("Using cached config at {}".format(cache_key))
return json.loads(open(cache_key).read())
cli_logger.old_info(logger, "Using cached config at {}", cache_key)
config_cache = json.loads(open(cache_key).read())
if config_cache.get("_version", -1) == CONFIG_CACHE_VERSION:
# todo: is it fine to re-resolve? afaik it should be.
# we can have migrations otherwise or something
# but this seems overcomplicated given that resolving is
# relatively cheap
try_reload_log_state(config_cache["config"]["provider"],
config_cache.get("provider_log_info"))
cli_logger.verbose("Loaded cached config from " + cf.bold("{}"),
cache_key)
return config_cache["config"]
else:
cli_logger.warning(
"Found cached cluster config "
"but the version " + cf.bold("{}") + " "
"(expected " + cf.bold("{}") + ") does not match.\n"
"This is normal if cluster launcher was updated.\n"
"Config will be re-resolved.",
config_cache.get("_version", "none"), CONFIG_CACHE_VERSION)
validate_config(config)
importer = NODE_PROVIDERS.get(config["provider"]["type"])
@ -122,17 +221,32 @@ def _bootstrap_config(config: Dict[str, Any],
config["provider"]))
provider_cls = importer(config["provider"])
resolved_config = provider_cls.bootstrap_config(config)
with cli_logger.timed( # todo: better message
"Bootstraping {} config",
PROVIDER_PRETTY_NAMES.get(config["provider"]["type"])):
resolved_config = provider_cls.bootstrap_config(config)
if not no_config_cache:
with open(cache_key, "w") as f:
f.write(json.dumps(resolved_config))
config_cache = {
"_version": CONFIG_CACHE_VERSION,
"provider_log_info": try_get_log_state(config["provider"]),
"config": resolved_config
}
f.write(json.dumps(config_cache))
return resolved_config
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
override_cluster_name: Optional[str],
keep_min_workers: bool):
keep_min_workers: bool, log_old_style: bool,
log_color: str, verbose: int):
"""Destroys all nodes of a Ray cluster described by a config json."""
cli_logger.old_style = log_old_style
cli_logger.color_mode = log_color
cli_logger.verbosity = verbose
cli_logger.dump_command_output = verbose == 3 # todo: add a separate flag?
config = yaml.safe_load(open(config_file).read())
if override_cluster_name is not None:
@ -140,7 +254,8 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
config = prepare_config(config)
validate_config(config)
confirm("This will destroy your cluster", yes)
cli_logger.confirm(yes, "Destroying cluster.", _abort=True)
cli_logger.old_confirm("This will destroy your cluster", yes)
if not workers_only:
try:
@ -155,8 +270,17 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
override_cluster_name=override_cluster_name,
port_forward=None,
with_output=False)
except Exception:
logger.exception("Ignoring error attempting a clean shutdown.")
except Exception as e:
cli_logger.verbose_error(e) # todo: add better exception info
cli_logger.warning(
"Exception occured when stopping the cluster Ray runtime "
"(use -v to dump teardown exceptions).")
cli_logger.warning(
"Ignoring the exception and "
"attempting to shut down the cluster nodes anyway.")
cli_logger.old_exception(
logger, "Ignoring error attempting a clean shutdown.")
provider = get_node_provider(config["provider"], config["cluster_name"])
try:
@ -169,11 +293,23 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
if keep_min_workers:
min_workers = config.get("min_workers", 0)
logger.info("teardown_cluster: "
"Keeping {} nodes...".format(min_workers))
cli_logger.print(
"{} random worker nodes will not be shut down. " +
cf.gray("(due to {})"), cf.bold(min_workers),
cf.bold("--keep-min-workers"))
cli_logger.old_info(logger,
"teardown_cluster: Keeping {} nodes...",
min_workers)
workers = random.sample(workers, len(workers) - min_workers)
# todo: it's weird to kill the head node but not all workers
if workers_only:
cli_logger.print(
"The head node will not be shut down. " +
cf.gray("(due to {})"), cf.bold("--workers-only"))
return workers
head = provider.non_terminated_nodes({
@ -187,11 +323,21 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
A = remaining_nodes()
with LogTimer("teardown_cluster: done."):
while A:
logger.info("teardown_cluster: "
"Shutting down {} nodes...".format(len(A)))
cli_logger.old_info(
logger, "teardown_cluster: "
"Shutting down {} nodes...", len(A))
provider.terminate_nodes(A)
time.sleep(1)
cli_logger.print(
"Requested {} nodes to shut down.",
cf.bold(len(A)),
_tags=dict(interval="1s"))
time.sleep(1) # todo: interval should be a variable
A = remaining_nodes()
cli_logger.print("{} nodes remaining after 1 second.",
cf.bold(len(A)))
finally:
provider.cleanup()
@ -261,12 +407,23 @@ def monitor_cluster(cluster_config_file, num_lines, override_cluster_name):
def warn_about_bad_start_command(start_commands):
    """Warn about common misconfigurations in the cluster start commands.

    Checks that ``ray start`` appears in the start commands at all, and
    that the head node's ``ray start`` passes ``--autoscaling-config``
    (without it the head node will not launch any workers).

    Args:
        start_commands: List of shell command strings that will be run
            on the head node.
    """
    ray_start_cmd = list(filter(lambda x: "ray start" in x, start_commands))
    if len(ray_start_cmd) == 0:
        cli_logger.warning(
            "Ray runtime will not be started because `{}` is not in `{}`.",
            cf.bold("ray start"), cf.bold("head_start_ray_commands"))
        cli_logger.old_warning(
            logger,
            "Ray start is not included in the head_start_ray_commands section."
        )
    if not any("autoscaling-config" in x for x in ray_start_cmd):
        cli_logger.warning(
            "The head node will not launch any workers because "
            "`{}` does not have `{}` set.\n"
            "Potential fix: add `{}` to the `{}` command under `{}`.",
            cf.bold("ray start"), cf.bold("--autoscaling-config"),
            cf.bold("--autoscaling-config=~/ray_bootstrap_config.yaml"),
            cf.bold("ray start"), cf.bold("head_start_ray_commands"))
        # BUG FIX: this was `logger.old_warning(...)` — `logger` is a plain
        # logging.Logger, which has no `old_warning` method and would raise
        # AttributeError whenever the warning fired. The `old_warning` shim
        # lives on cli_logger (see the call in the branch above). Also added
        # the missing spaces between the concatenated message fragments.
        cli_logger.old_warning(
            logger, "Ray start on the head node does not have the flag "
            "--autoscaling-config set. The head node will not launch "
            "workers. Add --autoscaling-config=~/ray_bootstrap_config.yaml "
            "to ray start in the head_start_ray_commands section.")
@ -288,129 +445,205 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
head_node = None
if not head_node:
confirm("This will create a new cluster", yes)
cli_logger.confirm(
yes,
"No head node found. "
"Launching a new cluster.",
_abort=True)
cli_logger.old_confirm("This will create a new cluster", yes)
elif not no_restart:
confirm("This will restart cluster services", yes)
cli_logger.old_confirm("This will restart cluster services", yes)
if head_node:
if restart_only:
cli_logger.confirm(
yes,
"Updating cluster configuration and "
"restarting the cluster Ray runtime. "
"Setup commands will not be run due to `{}`.\n",
cf.bold("--restart-only"),
_abort=True)
elif no_restart:
cli_logger.print(
"Cluster Ray runtime will not be restarted due "
"to `{}`.", cf.bold("--no-restart"))
cli_logger.confirm(
yes,
"Updating cluster configuration and "
"running setup commands.",
_abort=True)
else:
cli_logger.print(
"Updating cluster configuration and running full setup.")
cli_logger.confirm(
yes,
cf.bold("Cluster Ray runtime will be restarted."),
_abort=True)
cli_logger.newline()
launch_hash = hash_launch_conf(config["head_node"], config["auth"])
if head_node is None or provider.node_tags(head_node).get(
TAG_RAY_LAUNCH_CONFIG) != launch_hash:
if head_node is not None:
confirm("Head node config out-of-date. It will be terminated",
with cli_logger.group("Acquiring an up-to-date head node"):
if head_node is not None:
cli_logger.print(
"Currently running head node is out-of-date with "
"cluster configuration")
cli_logger.print(
"hash is {}, expected {}",
cf.bold(
provider.node_tags(head_node)
.get(TAG_RAY_LAUNCH_CONFIG)), cf.bold(launch_hash))
cli_logger.confirm(yes, "Relaunching it.", _abort=True)
cli_logger.old_confirm(
"Head node config out-of-date. It will be terminated",
yes)
logger.info(
"get_or_create_head_node: "
"Shutting down outdated head node {}".format(head_node))
provider.terminate_node(head_node)
logger.info("get_or_create_head_node: Launching new head node...")
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
config["cluster_name"])
provider.create_node(config["head_node"], head_node_tags, 1)
start = time.time()
head_node = None
while True:
if time.time() - start > 50:
raise RuntimeError("Failed to create head node.")
nodes = provider.non_terminated_nodes(head_node_tags)
if len(nodes) == 1:
head_node = nodes[0]
break
time.sleep(1)
cli_logger.old_info(
logger, "get_or_create_head_node: "
"Shutting down outdated head node {}", head_node)
# TODO(ekl) right now we always update the head node even if the hash
# matches. We could prompt the user for what they want to do here.
runtime_hash = hash_runtime_conf(config["file_mounts"], config)
logger.info("get_or_create_head_node: Updating files on head node...")
provider.terminate_node(head_node)
cli_logger.print("Terminated head node {}", head_node)
# Rewrite the auth config so that the head node can update the workers
remote_config = copy.deepcopy(config)
# drop proxy options if they exist, otherwise
# head node won't be able to connect to workers
remote_config["auth"].pop("ssh_proxy_command", None)
if config["provider"]["type"] != "kubernetes":
remote_key_path = "~/ray_bootstrap_key.pem"
remote_config["auth"]["ssh_private_key"] = remote_key_path
cli_logger.old_info(
logger,
"get_or_create_head_node: Launching new head node...")
# Adjust for new file locations
new_mounts = {}
for remote_path in config["file_mounts"]:
new_mounts[remote_path] = remote_path
remote_config["file_mounts"] = new_mounts
remote_config["no_restart"] = no_restart
head_node_tags[TAG_RAY_LAUNCH_CONFIG] = launch_hash
head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
config["cluster_name"])
provider.create_node(config["head_node"], head_node_tags, 1)
cli_logger.print("Launched a new head node")
# Now inject the rewritten config and SSH key into the head node
remote_config_file = tempfile.NamedTemporaryFile(
"w", prefix="ray-bootstrap-")
remote_config_file.write(json.dumps(remote_config))
remote_config_file.flush()
config["file_mounts"].update({
"~/ray_bootstrap_config.yaml": remote_config_file.name
})
if config["provider"]["type"] != "kubernetes":
start = time.time()
head_node = None
with cli_logger.timed("Fetching the new head node"):
while True:
if time.time() - start > 50:
cli_logger.abort(
"Head node fetch timed out.") # todo: msg
raise RuntimeError("Failed to create head node.")
nodes = provider.non_terminated_nodes(head_node_tags)
if len(nodes) == 1:
head_node = nodes[0]
break
time.sleep(1)
cli_logger.newline()
with cli_logger.group(
"Setting up head node",
_numbered=("<>", 1, 1),
# cf.bold(provider.node_tags(head_node)[TAG_RAY_NODE_NAME]),
_tags=dict()): # add id, ARN to tags?
# TODO(ekl) right now we always update the head node even if the
# hash matches.
# We could prompt the user for what they want to do here.
runtime_hash = hash_runtime_conf(config["file_mounts"], config)
cli_logger.old_info(
logger,
"get_or_create_head_node: Updating files on head node...")
# Rewrite the auth config so that the head
# node can update the workers
remote_config = copy.deepcopy(config)
# drop proxy options if they exist, otherwise
# head node won't be able to connect to workers
remote_config["auth"].pop("ssh_proxy_command", None)
if config["provider"]["type"] != "kubernetes":
remote_key_path = "~/ray_bootstrap_key.pem"
remote_config["auth"]["ssh_private_key"] = remote_key_path
# Adjust for new file locations
new_mounts = {}
for remote_path in config["file_mounts"]:
new_mounts[remote_path] = remote_path
remote_config["file_mounts"] = new_mounts
remote_config["no_restart"] = no_restart
# Now inject the rewritten config and SSH key into the head node
remote_config_file = tempfile.NamedTemporaryFile(
"w", prefix="ray-bootstrap-")
remote_config_file.write(json.dumps(remote_config))
remote_config_file.flush()
config["file_mounts"].update({
remote_key_path: config["auth"]["ssh_private_key"],
"~/ray_bootstrap_config.yaml": remote_config_file.name
})
if restart_only:
init_commands = []
ray_start_commands = config["head_start_ray_commands"]
elif no_restart:
init_commands = config["head_setup_commands"]
ray_start_commands = []
else:
init_commands = config["head_setup_commands"]
ray_start_commands = config["head_start_ray_commands"]
if config["provider"]["type"] != "kubernetes":
config["file_mounts"].update({
remote_key_path: config["auth"]["ssh_private_key"],
})
cli_logger.print("Prepared bootstrap config")
if not no_restart:
warn_about_bad_start_command(ray_start_commands)
if restart_only:
init_commands = []
ray_start_commands = config["head_start_ray_commands"]
elif no_restart:
init_commands = config["head_setup_commands"]
ray_start_commands = []
else:
init_commands = config["head_setup_commands"]
ray_start_commands = config["head_start_ray_commands"]
updater = NodeUpdaterThread(
node_id=head_node,
provider_config=config["provider"],
provider=provider,
auth_config=config["auth"],
cluster_name=config["cluster_name"],
file_mounts=config["file_mounts"],
initialization_commands=config["initialization_commands"],
setup_commands=init_commands,
ray_start_commands=ray_start_commands,
runtime_hash=runtime_hash,
docker_config=config.get("docker"))
updater.start()
updater.join()
if not no_restart:
warn_about_bad_start_command(ray_start_commands)
# Refresh the node cache so we see the external ip if available
provider.non_terminated_nodes(head_node_tags)
updater = NodeUpdaterThread(
node_id=head_node,
provider_config=config["provider"],
provider=provider,
auth_config=config["auth"],
cluster_name=config["cluster_name"],
file_mounts=config["file_mounts"],
initialization_commands=config["initialization_commands"],
setup_commands=init_commands,
ray_start_commands=ray_start_commands,
runtime_hash=runtime_hash,
docker_config=config.get("docker"))
updater.start()
updater.join()
if config.get("provider", {}).get("use_internal_ips", False) is True:
head_node_ip = provider.internal_ip(head_node)
else:
head_node_ip = provider.external_ip(head_node)
# Refresh the node cache so we see the external ip if available
provider.non_terminated_nodes(head_node_tags)
if updater.exitcode != 0:
logger.error("get_or_create_head_node: "
"Updating {} failed".format(head_node_ip))
sys.exit(1)
logger.info(
"get_or_create_head_node: "
"Head node up-to-date, IP address is: {}".format(head_node_ip))
if config.get("provider", {}).get("use_internal_ips",
False) is True:
head_node_ip = provider.internal_ip(head_node)
else:
head_node_ip = provider.external_ip(head_node)
monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*"
if override_cluster_name:
modifiers = " --cluster-name={}".format(
quote(override_cluster_name))
else:
modifiers = ""
print("To monitor auto-scaling activity, you can run:\n\n"
" ray exec {} {}{}\n".format(config_file, quote(monitor_str),
modifiers))
print("To open a console on the cluster:\n\n"
" ray attach {}{}\n".format(config_file, modifiers))
if updater.exitcode != 0:
# todo: this does not follow the mockup and is not good enough
cli_logger.abort("Failed to setup head node.")
print("To get a remote shell to the cluster manually, run:\n\n"
" {}\n".format(updater.cmd_runner.remote_shell_command_str()))
cli_logger.old_error(
logger, "get_or_create_head_node: "
"Updating {} failed", head_node_ip)
sys.exit(1)
logger.info(
"get_or_create_head_node: "
"Head node up-to-date, IP address is: {}".format(head_node_ip))
monitor_str = "tail -n 100 -f /tmp/ray/session_*/logs/monitor*"
if override_cluster_name:
modifiers = " --cluster-name={}".format(
quote(override_cluster_name))
else:
modifiers = ""
print("To monitor auto-scaling activity, you can run:\n\n"
" ray exec {} {}{}\n".format(config_file,
quote(monitor_str), modifiers))
print("To open a console on the cluster:\n\n"
" ray attach {}{}\n".format(config_file, modifiers))
print("To get a remote shell to the cluster manually, run:\n\n"
" {}\n".format(
updater.cmd_runner.remote_shell_command_str()))
finally:
provider.cleanup()

View file

@ -1,6 +1,8 @@
import datetime
import logging
from ray.autoscaler.cli_logger import cli_logger
logger = logging.getLogger(__name__)
@ -13,6 +15,9 @@ class LogTimer:
self._start_time = datetime.datetime.utcnow()
def __exit__(self, *error_vals):
if not cli_logger.old_style:
return
td = datetime.datetime.utcnow() - self._start_time
status = ""
if self._show_status:

View file

@ -77,6 +77,15 @@ NODE_PROVIDERS = {
"docker": None,
"external": import_external # Import an external module
}
PROVIDER_PRETTY_NAMES = {
"local": "Local",
"aws": "AWS",
"gcp": "GCP",
"azure": "Azure",
"kubernetes": "Kubernetes",
# "docker": "Docker", # not supported
"external": "External"
}
DEFAULT_CONFIGS = {
"local": load_local_example_config,
@ -88,6 +97,26 @@ DEFAULT_CONFIGS = {
}
def try_logging_config(config):
    """Emit provider-specific configuration logging, where supported.

    Currently only the AWS provider has pretty CLI config logging; all
    other provider types are a no-op.
    """
    provider_type = config["provider"]["type"]
    if provider_type != "aws":
        return
    # Imported lazily so non-AWS providers never pull in boto-backed code.
    from ray.autoscaler.aws.config import log_to_cli
    log_to_cli(config)
def try_get_log_state(provider_config):
    """Return a snapshot of provider logging state, if the provider has one.

    Only the AWS provider tracks CLI logging state; for every other
    provider type this returns None.
    """
    if provider_config["type"] != "aws":
        return None
    # Lazy import keeps AWS-only dependencies out of other providers.
    from ray.autoscaler.aws.config import get_log_state
    return get_log_state()
def try_reload_log_state(provider_config, log_state):
    """Restore previously captured provider logging state, if any.

    A falsy ``log_state`` (e.g. None from a cache miss) is a no-op, and
    only the AWS provider supports reloading state.
    """
    if not log_state:
        return None
    if provider_config["type"] != "aws":
        return None
    # Lazy import keeps AWS-only dependencies out of other providers.
    from ray.autoscaler.aws.config import reload_log_state
    return reload_log_state(log_state)
def load_class(path):
"""
Load a class at runtime given a full path.

View file

@ -12,6 +12,9 @@ from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG, \
from ray.autoscaler.command_runner import NODE_START_WAIT_S, SSHOptions
from ray.autoscaler.log_timer import LogTimer
from ray.autoscaler.cli_logger import cli_logger
import colorful as cf
logger = logging.getLogger(__name__)
READY_CHECK_INTERVAL = 5
@ -57,8 +60,9 @@ class NodeUpdater:
self.auth_config = auth_config
def run(self):
logger.info(self.log_prefix +
"Updating to {}".format(self.runtime_hash))
cli_logger.old_info(logger, "{}Updating to {}", self.log_prefix,
self.runtime_hash)
try:
with LogTimer(self.log_prefix +
"Applied config {}".format(self.runtime_hash)):
@ -68,11 +72,28 @@ class NodeUpdater:
if hasattr(e, "cmd"):
error_str = "(Exit Status {}) {}".format(
e.returncode, " ".join(e.cmd))
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
logger.error(self.log_prefix +
"Error executing: {}".format(error_str) + "\n")
cli_logger.error("New status: {}", cf.bold(STATUS_UPDATE_FAILED))
cli_logger.old_error(logger, "{}Error executing: {}\n",
self.log_prefix, error_str)
cli_logger.error("!!!")
if hasattr(e, "cmd"):
cli_logger.error(
"Setup command `{}` failed with exit code {}. stderr:",
cf.bold(e.cmd), e.returncode)
else:
cli_logger.verbose_error(vars(e), _no_format=True)
cli_logger.error(str(e)) # todo: handle this better somehow?
# todo: print stderr here
cli_logger.error("!!!")
cli_logger.newline()
if isinstance(e, click.ClickException):
# todo: why do we ignore this here
return
raise
@ -81,98 +102,173 @@ class NodeUpdater:
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_RUNTIME_CONFIG: self.runtime_hash
})
cli_logger.labeled_value("New status", STATUS_UP_TO_DATE)
self.exitcode = 0
def sync_file_mounts(self, sync_cmd):
# Rsync file mounts
for remote_path, local_path in self.file_mounts.items():
assert os.path.exists(local_path), local_path
if os.path.isdir(local_path):
if not local_path.endswith("/"):
local_path += "/"
if not remote_path.endswith("/"):
remote_path += "/"
nolog_paths = []
if cli_logger.verbosity == 0:
nolog_paths = [
"~/ray_bootstrap_key.pem", "~/ray_bootstrap_config.yaml"
]
with LogTimer(self.log_prefix +
"Synced {} to {}".format(local_path, remote_path)):
self.cmd_runner.run("mkdir -p {}".format(
os.path.dirname(remote_path)))
sync_cmd(local_path, remote_path)
# Rsync file mounts
with cli_logger.group(
"Processing file mounts", _numbered=("[]", 2, 5)):
for remote_path, local_path in self.file_mounts.items():
assert os.path.exists(local_path), local_path
if os.path.isdir(local_path):
if not local_path.endswith("/"):
local_path += "/"
if not remote_path.endswith("/"):
remote_path += "/"
with LogTimer(self.log_prefix + "Synced {} to {}".format(
local_path, remote_path)):
self.cmd_runner.run("mkdir -p {}".format(
os.path.dirname(remote_path)))
sync_cmd(local_path, remote_path)
if remote_path not in nolog_paths:
# todo: timed here?
cli_logger.print("{} from {}", cf.bold(remote_path),
cf.bold(local_path))
def wait_ready(self, deadline):
with LogTimer(self.log_prefix + "Got remote shell"):
logger.info(self.log_prefix + "Waiting for remote shell...")
with cli_logger.group(
"Waiting for SSH to become available", _numbered=("[]", 1, 5)):
with LogTimer(self.log_prefix + "Got remote shell"):
cli_logger.old_info(logger, "{}Waiting for remote shell...",
self.log_prefix)
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
try:
logger.debug(self.log_prefix +
"Waiting for remote shell...")
cli_logger.print("Running `{}` as a test.", cf.bold("uptime"))
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
try:
cli_logger.old_debug(logger,
"{}Waiting for remote shell...",
self.log_prefix)
self.cmd_runner.run("uptime", timeout=5)
logger.debug("Uptime succeeded.")
return True
self.cmd_runner.run("uptime")
cli_logger.old_debug(logger, "Uptime succeeded.")
cli_logger.success("Success.")
return True
except Exception as e:
retry_str = str(e)
if hasattr(e, "cmd"):
retry_str = "(Exit Status {}): {}".format(
e.returncode, " ".join(e.cmd))
except Exception as e:
retry_str = str(e)
if hasattr(e, "cmd"):
retry_str = "(Exit Status {}): {}".format(
e.returncode, " ".join(e.cmd))
logger.debug(self.log_prefix +
"Node not up, retrying: {}".format(retry_str))
time.sleep(READY_CHECK_INTERVAL)
cli_logger.print(
"SSH still not available {}, "
"retrying in {} seconds.", cf.gray(retry_str),
cf.bold(str(READY_CHECK_INTERVAL)))
cli_logger.old_debug(logger,
"{}Node not up, retrying: {}",
self.log_prefix, retry_str)
time.sleep(READY_CHECK_INTERVAL)
assert False, "Unable to connect to node"
def do_update(self):
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_WAITING_FOR_SSH})
cli_logger.labeled_value("New status", STATUS_WAITING_FOR_SSH)
deadline = time.time() + NODE_START_WAIT_S
self.wait_ready(deadline)
node_tags = self.provider.node_tags(self.node_id)
logger.debug("Node tags: {}".format(str(node_tags)))
if node_tags.get(TAG_RAY_RUNTIME_CONFIG) == self.runtime_hash:
logger.info(self.log_prefix +
"{} already up-to-date, skip to ray start".format(
self.node_id))
# todo: we lie in the confirmation message since
# full setup might be cancelled here
cli_logger.print(
"Configuration already up to date, "
"skipping file mounts, initalization and setup commands.")
cli_logger.old_info(logger,
"{}{} already up-to-date, skip to ray start",
self.log_prefix, self.node_id)
else:
cli_logger.print(
"Updating cluster configuration.",
_tags=dict(hash=self.runtime_hash))
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SYNCING_FILES})
cli_logger.labeled_value("New status", STATUS_SYNCING_FILES)
self.sync_file_mounts(self.rsync_up)
# Run init commands
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: STATUS_SETTING_UP})
with LogTimer(
self.log_prefix + "Initialization commands",
show_status=True):
for cmd in self.initialization_commands:
self.cmd_runner.run(
cmd,
ssh_options_override=SSHOptions(
self.auth_config.get("ssh_private_key")))
cli_logger.labeled_value("New status", STATUS_SETTING_UP)
if self.initialization_commands:
with cli_logger.group(
"Running initialization commands",
_numbered=("[]", 3, 5)): # todo: fix command numbering
with LogTimer(
self.log_prefix + "Initialization commands",
show_status=True):
for cmd in self.initialization_commands:
self.cmd_runner.run(
cmd,
ssh_options_override=SSHOptions(
self.auth_config.get("ssh_private_key")))
else:
cli_logger.print(
"No initialization commands to run.",
_numbered=("[]", 3, 5))
if self.setup_commands:
with cli_logger.group(
"Running setup commands",
_numbered=("[]", 4, 5)): # todo: fix command numbering
with LogTimer(
self.log_prefix + "Setup commands",
show_status=True):
total = len(self.setup_commands)
for i, cmd in enumerate(self.setup_commands):
if cli_logger.verbosity == 0:
cmd_to_print = cf.bold(cmd[:30]) + "..."
else:
cmd_to_print = cf.bold(cmd)
cli_logger.print(
cmd_to_print, _numbered=("()", i, total))
self.cmd_runner.run(cmd)
else:
cli_logger.print(
"No setup commands to run.", _numbered=("[]", 4, 5))
with cli_logger.group(
"Starting the Ray runtime", _numbered=("[]", 5, 5)):
with LogTimer(
self.log_prefix + "Setup commands", show_status=True):
for cmd in self.setup_commands:
self.log_prefix + "Ray start commands", show_status=True):
for cmd in self.ray_start_commands:
self.cmd_runner.run(cmd)
with LogTimer(
self.log_prefix + "Ray start commands", show_status=True):
for cmd in self.ray_start_commands:
self.cmd_runner.run(cmd)
def rsync_up(self, source, target):
logger.info(self.log_prefix +
"Syncing {} to {}...".format(source, target))
cli_logger.old_info(logger, "{}Syncing {} to {}...", self.log_prefix,
source, target)
self.cmd_runner.run_rsync_up(source, target)
cli_logger.verbose("`rsync`ed {} (local) to {} (remote)",
cf.bold(source), cf.bold(target))
def rsync_down(self, source, target):
logger.info(self.log_prefix +
"Syncing {} from {}...".format(source, target))
cli_logger.old_info(logger, "{}Syncing {} from {}...", self.log_prefix,
source, target)
self.cmd_runner.run_rsync_down(source, target)
cli_logger.verbose("`rsync`ed {} (remote) to {} (local)",
cf.bold(source), cf.bold(target))
class NodeUpdaterThread(NodeUpdater, Thread):

View file

@ -646,6 +646,16 @@ def stop(force, verbose):
@cli.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
"--min-workers",
required=False,
type=int,
help="Override the configured min worker node count for the cluster.")
@click.option(
"--max-workers",
required=False,
type=int,
help="Override the configured max worker node count for the cluster.")
@click.option(
"--no-restart",
is_flag=True,
@ -659,20 +669,11 @@ def stop(force, verbose):
help=("Whether to skip running setup commands and only restart Ray. "
"This cannot be used with 'no-restart'."))
@click.option(
"--no-config-cache",
"--yes",
"-y",
is_flag=True,
default=False,
help="Disable the local cluster config cache.")
@click.option(
"--min-workers",
required=False,
type=int,
help="Override the configured min worker node count for the cluster.")
@click.option(
"--max-workers",
required=False,
type=int,
help="Override the configured max worker node count for the cluster.")
help="Don't ask for confirmation.")
@click.option(
"--cluster-name",
"-n",
@ -680,13 +681,25 @@ def stop(force, verbose):
type=str,
help="Override the configured cluster name.")
@click.option(
"--yes",
"-y",
"--no-config-cache",
is_flag=True,
default=False,
help="Don't ask for confirmation.")
help="Disable the local cluster config cache.")
@click.option(
"--log-old-style/--log-new-style",
is_flag=True,
default=True,
help=("Use old logging."))
@click.option(
"--log-color",
required=False,
type=str,
default="auto",
help=("Use color logging. "
"Valid values are: auto (if stdout is a tty), true, false."))
@click.option("-v", "--verbose", count=True)
def up(cluster_config_file, min_workers, max_workers, no_restart, restart_only,
yes, cluster_name, no_config_cache):
yes, cluster_name, no_config_cache, log_old_style, log_color, verbose):
"""Create or update a Ray cluster."""
if restart_only or no_restart:
assert restart_only != no_restart, "Cannot set both 'restart_only' " \
@ -703,38 +716,52 @@ def up(cluster_config_file, min_workers, max_workers, no_restart, restart_only,
logger.info("Error downloading file: ", e)
create_or_update_cluster(cluster_config_file, min_workers, max_workers,
no_restart, restart_only, yes, cluster_name,
no_config_cache)
no_config_cache, log_old_style, log_color,
verbose)
@cli.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
"--workers-only",
is_flag=True,
default=False,
help="Only destroy the workers.")
@click.option(
"--keep-min-workers",
is_flag=True,
default=False,
help="Retain the minimal amount of workers specified in the config.")
@click.option(
"--yes",
"-y",
is_flag=True,
default=False,
help="Don't ask for confirmation.")
@click.option(
"--workers-only",
is_flag=True,
default=False,
help="Only destroy the workers.")
@click.option(
"--cluster-name",
"-n",
required=False,
type=str,
help="Override the configured cluster name.")
@click.option(
"--keep-min-workers",
is_flag=True,
default=False,
help="Retain the minimal amount of workers specified in the config.")
@click.option(
"--log-old-style/--log-new-style",
is_flag=True,
default=True,
help=("Use old logging."))
@click.option(
"--log-color",
required=False,
type=str,
default="auto",
help=("Use color logging. "
"Valid values are: auto (if stdout is a tty), true, false."))
@click.option("-v", "--verbose", count=True)
def down(cluster_config_file, yes, workers_only, cluster_name,
keep_min_workers):
keep_min_workers, log_old_style, log_color, verbose):
"""Tear down a Ray cluster."""
teardown_cluster(cluster_config_file, yes, workers_only, cluster_name,
keep_min_workers)
keep_min_workers, log_old_style, log_color, verbose)
@cli.command()

View file

@ -303,6 +303,7 @@ install_requires = [
"aiohttp",
"click >= 7.0",
"colorama",
"colorful",
"filelock",
"google",
"gpustat",