[autoscaler] Run initialization_commands without a persistent connection (#9020)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Ian Rodney 2020-07-06 16:34:59 -07:00 committed by GitHub
parent 139d21e068
commit 6fecd3cfce
2 changed files with 92 additions and 39 deletions


@@ -63,8 +63,7 @@ The ``example-full.yaml`` configuration is enough to get started with Ray, but f
InstanceType: p2.8xlarge
**Docker**: Specify docker image. This executes all commands on all nodes in the docker container,
and opens all the necessary ports to support the Ray cluster. It will also automatically install
Docker if Docker is not installed. This currently does not have GPU support.
and opens all the necessary ports to support the Ray cluster.
.. code-block:: yaml
@@ -72,6 +71,17 @@ Docker if Docker is not installed. This currently does not have GPU support.
image: tensorflow/tensorflow:1.5.0-py3
container_name: ray_docker
If Docker is not installed, add the following commands to ``initialization_commands`` to install it.
.. code-block:: yaml
initialization_commands:
- curl -fsSL https://get.docker.com -o get-docker.sh
- sudo sh get-docker.sh
- sudo usermod -aG docker $USER
- sudo systemctl restart docker -f
**Mixed GPU and CPU nodes**: for RL applications that require proportionally more
CPU than GPU resources, you can use additional CPU workers with a GPU head node.


@@ -147,6 +147,44 @@ class KubernetesCommandRunner:
self.node_id)
class SSHOptions:
def __init__(self, ssh_key, control_path=None, **kwargs):
self.ssh_key = ssh_key
self.arg_dict = {
# Suppresses initial fingerprint verification.
"StrictHostKeyChecking": "no",
# SSH IP and fingerprint pairs no longer added to known_hosts.
# This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED"
# warning if a new node has the same IP as a previously
# deleted node, because the fingerprints will not match in
# that case.
"UserKnownHostsFile": os.devnull,
# Try fewer extraneous key pairs.
"IdentitiesOnly": "yes",
# Abort if port forwarding fails (instead of just printing to
# stderr).
"ExitOnForwardFailure": "yes",
# Quickly kill the connection if network connection breaks (as
# opposed to hanging/blocking).
"ServerAliveInterval": 5,
"ServerAliveCountMax": 3
}
if control_path:
self.arg_dict.update({
"ControlMaster": "auto",
"ControlPath": "{}/%C".format(control_path),
"ControlPersist": "10s",
})
self.arg_dict.update(kwargs)
def to_ssh_options_list(self, *, timeout=60):
self.arg_dict["ConnectTimeout"] = "{}s".format(timeout)
return ["-i", self.ssh_key] + [
x for y in (["-o", "{}={}".format(k, v)]
for k, v in self.arg_dict.items()) for x in y
]
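For reference, a minimal usage sketch of the new SSHOptions helper above (not part of the diff; the key path and control path are illustrative values only):

# Minimal sketch, assuming the SSHOptions class defined above.
# "~/.ssh/ray-key.pem" and "/tmp/ray_ssh_sockets" are made-up example values.
opts = SSHOptions("~/.ssh/ray-key.pem", control_path="/tmp/ray_ssh_sockets")
# Yields roughly:
#   ['-i', '~/.ssh/ray-key.pem',
#    '-o', 'StrictHostKeyChecking=no', '-o', 'UserKnownHostsFile=/dev/null',
#    ..., '-o', 'ControlMaster=auto', '-o', 'ControlPath=/tmp/ray_ssh_sockets/%C',
#    '-o', 'ControlPersist=10s', '-o', 'ConnectTimeout=30s']
print(opts.to_ssh_options_list(timeout=30))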
class SSHCommandRunner:
def __init__(self, log_prefix, node_id, provider, auth_config,
cluster_name, process_runner, use_internal_ip):
@@ -166,36 +204,8 @@ class SSHCommandRunner:
self.ssh_user = auth_config["ssh_user"]
self.ssh_control_path = ssh_control_path
self.ssh_ip = None
def get_default_ssh_options(self, connect_timeout):
OPTS = [
("ConnectTimeout", "{}s".format(connect_timeout)),
# Suppresses initial fingerprint verification.
("StrictHostKeyChecking", "no"),
# SSH IP and fingerprint pairs no longer added to known_hosts.
# This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED"
# warning if a new node has the same IP as a previously
# deleted node, because the fingerprints will not match in
# that case.
("UserKnownHostsFile", os.devnull),
("ControlMaster", "auto"),
("ControlPath", "{}/%C".format(self.ssh_control_path)),
("ControlPersist", "10s"),
# Try fewer extraneous key pairs.
("IdentitiesOnly", "yes"),
# Abort if port forwarding fails (instead of just printing to
# stderr).
("ExitOnForwardFailure", "yes"),
# Quickly kill the connection if network connection breaks (as
# opposed to hanging/blocking).
("ServerAliveInterval", 5),
("ServerAliveCountMax", 3),
]
return ["-i", self.ssh_private_key] + [
x for y in (["-o", "{}={}".format(k, v)] for k, v in OPTS)
for x in y
]
self.base_ssh_options = SSHOptions(self.ssh_private_key,
self.ssh_control_path)
def get_node_ip(self):
if self.use_internal_ip:
@@ -241,7 +251,14 @@ class SSHCommandRunner:
exit_on_fail=False,
port_forward=None,
with_output=False,
ssh_options_override=None,
**kwargs):
ssh_options = ssh_options_override or self.base_ssh_options
assert isinstance(
ssh_options, SSHOptions
), "ssh_options must be of type SSHOptions, got {}".format(
type(ssh_options))
self.set_ssh_ip_if_required()
@@ -255,7 +272,7 @@ class SSHCommandRunner:
"{} -> localhost:{}".format(local, remote))
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
final_cmd = ssh + self.get_default_ssh_options(timeout) + [
final_cmd = ssh + ssh_options.to_ssh_options_list(timeout=timeout) + [
"{}@{}".format(self.ssh_user, self.ssh_ip)
]
if cmd:
@@ -286,16 +303,20 @@ class SSHCommandRunner:
self.set_ssh_ip_if_required()
self.process_runner.check_call([
"rsync", "--rsh",
" ".join(["ssh"] + self.get_default_ssh_options(120)), "-avz",
source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip, target)
" ".join(["ssh"] +
self.base_ssh_options.to_ssh_options_list(timeout=120)),
"-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
target)
])
def run_rsync_down(self, source, target):
self.set_ssh_ip_if_required()
self.process_runner.check_call([
"rsync", "--rsh",
" ".join(["ssh"] + self.get_default_ssh_options(120)), "-avz",
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target
" ".join(["ssh"] +
self.base_ssh_options.to_ssh_options_list(timeout=120)),
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
source), target
])
def remote_shell_command_str(self):
@@ -309,6 +330,7 @@ class DockerCommandRunner(SSHCommandRunner):
self.docker_name = docker_config["container_name"]
self.docker_config = docker_config
self.home_dir = None
self.check_docker_installed()
self.shutdown = False
def run(self,
@@ -318,6 +340,7 @@ class DockerCommandRunner(SSHCommandRunner):
port_forward=None,
with_output=False,
run_env=True,
ssh_options_override=None,
**kwargs):
if run_env == "auto":
run_env = "host" if cmd.find("docker") == 0 else "docker"
@@ -335,7 +358,23 @@ class DockerCommandRunner(SSHCommandRunner):
timeout=timeout,
exit_on_fail=exit_on_fail,
port_forward=None,
with_output=False)
with_output=False,
ssh_options_override=ssh_options_override)
def check_docker_installed(self):
try:
self.ssh_command_runner.run("command -v docker")
return
except Exception:
install_commands = [
"curl -fsSL https://get.docker.com -o get-docker.sh",
"sudo sh get-docker.sh", "sudo usermod -aG docker $USER",
"sudo systemctl restart docker -f"
]
logger.error(
"Docker not installed. You can install Docker by adding the "
"following commands to 'initialization_commands':\n" +
"\n".join(install_commands))
def shutdown_after_next_cmd(self):
self.shutdown = True
@@ -422,6 +461,7 @@ class NodeUpdater:
self.setup_commands = setup_commands
self.ray_start_commands = ray_start_commands
self.runtime_hash = runtime_hash
self.auth_config = auth_config
def run(self):
logger.info(self.log_prefix +
@@ -516,7 +556,10 @@ class NodeUpdater:
self.log_prefix + "Initialization commands",
show_status=True):
for cmd in self.initialization_commands:
self.cmd_runner.run(cmd)
self.cmd_runner.run(
cmd,
ssh_options_override=SSHOptions(
self.auth_config.get("ssh_private_key")))
with LogTimer(
self.log_prefix + "Setup commands", show_status=True):
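The net effect of the NodeUpdater change above, sketched here under the assumption that auth_config carries the usual ssh_private_key entry: the per-command SSHOptions override is constructed without a control_path, so the ControlMaster/ControlPath/ControlPersist options are never added and each initialization command runs over a fresh, non-multiplexed SSH connection instead of the persistent one used for the other setup commands.

# Sketch only; the key path is an illustrative placeholder, not taken from the diff.
persistent = SSHOptions("~/.ssh/ray-key.pem", control_path="/tmp/ray_ssh_sockets")
one_shot = SSHOptions("~/.ssh/ray-key.pem")  # what initialization_commands now use

assert "ControlMaster" in persistent.arg_dict      # multiplexed, persistent connection
assert "ControlMaster" not in one_shot.arg_dict    # plain per-command connection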