mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[autoscaler] Run initialization_commands without a persistent connection (#9020)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu> Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
This commit is contained in:
parent
139d21e068
commit
6fecd3cfce
2 changed files with 92 additions and 39 deletions
|
@ -63,8 +63,7 @@ The ``example-full.yaml`` configuration is enough to get started with Ray, but f
|
|||
InstanceType: p2.8xlarge
|
||||
|
||||
**Docker**: Specify docker image. This executes all commands on all nodes in the docker container,
|
||||
and opens all the necessary ports to support the Ray cluster. It will also automatically install
|
||||
Docker if Docker is not installed. This currently does not have GPU support.
|
||||
and opens all the necessary ports to support the Ray cluster.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
|
@ -72,6 +71,17 @@ Docker if Docker is not installed. This currently does not have GPU support.
|
|||
image: tensorflow/tensorflow:1.5.0-py3
|
||||
container_name: ray_docker
|
||||
|
||||
If Docker is not installed, add the following commands to ``initialization_commands`` to install it.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
initialization_commands:
|
||||
- curl -fsSL https://get.docker.com -o get-docker.sh
|
||||
- sudo sh get-docker.sh
|
||||
- sudo usermod -aG docker $USER
|
||||
- sudo systemctl restart docker -f
|
||||
|
||||
|
||||
**Mixed GPU and CPU nodes**: for RL applications that require proportionally more
|
||||
CPU than GPU resources, you can use additional CPU workers with a GPU head node.
|
||||
|
||||
|
|
|
@ -147,6 +147,44 @@ class KubernetesCommandRunner:
|
|||
self.node_id)
|
||||
|
||||
|
||||
class SSHOptions:
|
||||
def __init__(self, ssh_key, control_path=None, **kwargs):
|
||||
self.ssh_key = ssh_key
|
||||
self.arg_dict = {
|
||||
# Supresses initial fingerprint verification.
|
||||
"StrictHostKeyChecking": "no",
|
||||
# SSH IP and fingerprint pairs no longer added to known_hosts.
|
||||
# This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED"
|
||||
# warning if a new node has the same IP as a previously
|
||||
# deleted node, because the fingerprints will not match in
|
||||
# that case.
|
||||
"UserKnownHostsFile": os.devnull,
|
||||
# Try fewer extraneous key pairs.
|
||||
"IdentitiesOnly": "yes",
|
||||
# Abort if port forwarding fails (instead of just printing to
|
||||
# stderr).
|
||||
"ExitOnForwardFailure": "yes",
|
||||
# Quickly kill the connection if network connection breaks (as
|
||||
# opposed to hanging/blocking).
|
||||
"ServerAliveInterval": 5,
|
||||
"ServerAliveCountMax": 3
|
||||
}
|
||||
if control_path:
|
||||
self.arg_dict.update({
|
||||
"ControlMaster": "auto",
|
||||
"ControlPath": "{}/%C".format(control_path),
|
||||
"ControlPersist": "10s",
|
||||
})
|
||||
self.arg_dict.update(kwargs)
|
||||
|
||||
def to_ssh_options_list(self, *, timeout=60):
|
||||
self.arg_dict["ConnectTimeout"] = "{}s".format(timeout)
|
||||
return ["-i", self.ssh_key] + [
|
||||
x for y in (["-o", "{}={}".format(k, v)]
|
||||
for k, v in self.arg_dict.items()) for x in y
|
||||
]
|
||||
|
||||
|
||||
class SSHCommandRunner:
|
||||
def __init__(self, log_prefix, node_id, provider, auth_config,
|
||||
cluster_name, process_runner, use_internal_ip):
|
||||
|
@ -166,36 +204,8 @@ class SSHCommandRunner:
|
|||
self.ssh_user = auth_config["ssh_user"]
|
||||
self.ssh_control_path = ssh_control_path
|
||||
self.ssh_ip = None
|
||||
|
||||
def get_default_ssh_options(self, connect_timeout):
|
||||
OPTS = [
|
||||
("ConnectTimeout", "{}s".format(connect_timeout)),
|
||||
# Supresses initial fingerprint verification.
|
||||
("StrictHostKeyChecking", "no"),
|
||||
# SSH IP and fingerprint pairs no longer added to known_hosts.
|
||||
# This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED"
|
||||
# warning if a new node has the same IP as a previously
|
||||
# deleted node, because the fingerprints will not match in
|
||||
# that case.
|
||||
("UserKnownHostsFile", os.devnull),
|
||||
("ControlMaster", "auto"),
|
||||
("ControlPath", "{}/%C".format(self.ssh_control_path)),
|
||||
("ControlPersist", "10s"),
|
||||
# Try fewer extraneous key pairs.
|
||||
("IdentitiesOnly", "yes"),
|
||||
# Abort if port forwarding fails (instead of just printing to
|
||||
# stderr).
|
||||
("ExitOnForwardFailure", "yes"),
|
||||
# Quickly kill the connection if network connection breaks (as
|
||||
# opposed to hanging/blocking).
|
||||
("ServerAliveInterval", 5),
|
||||
("ServerAliveCountMax", 3),
|
||||
]
|
||||
|
||||
return ["-i", self.ssh_private_key] + [
|
||||
x for y in (["-o", "{}={}".format(k, v)] for k, v in OPTS)
|
||||
for x in y
|
||||
]
|
||||
self.base_ssh_options = SSHOptions(self.ssh_private_key,
|
||||
self.ssh_control_path)
|
||||
|
||||
def get_node_ip(self):
|
||||
if self.use_internal_ip:
|
||||
|
@ -241,7 +251,14 @@ class SSHCommandRunner:
|
|||
exit_on_fail=False,
|
||||
port_forward=None,
|
||||
with_output=False,
|
||||
ssh_options_override=None,
|
||||
**kwargs):
|
||||
ssh_options = ssh_options_override or self.base_ssh_options
|
||||
|
||||
assert isinstance(
|
||||
ssh_options, SSHOptions
|
||||
), "ssh_options must be of type SSHOptions, got {}".format(
|
||||
type(ssh_options))
|
||||
|
||||
self.set_ssh_ip_if_required()
|
||||
|
||||
|
@ -255,7 +272,7 @@ class SSHCommandRunner:
|
|||
"{} -> localhost:{}".format(local, remote))
|
||||
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
|
||||
|
||||
final_cmd = ssh + self.get_default_ssh_options(timeout) + [
|
||||
final_cmd = ssh + ssh_options.to_ssh_options_list(timeout=timeout) + [
|
||||
"{}@{}".format(self.ssh_user, self.ssh_ip)
|
||||
]
|
||||
if cmd:
|
||||
|
@ -286,16 +303,20 @@ class SSHCommandRunner:
|
|||
self.set_ssh_ip_if_required()
|
||||
self.process_runner.check_call([
|
||||
"rsync", "--rsh",
|
||||
" ".join(["ssh"] + self.get_default_ssh_options(120)), "-avz",
|
||||
source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip, target)
|
||||
" ".join(["ssh"] +
|
||||
self.base_ssh_options.to_ssh_options_list(timeout=120)),
|
||||
"-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
|
||||
target)
|
||||
])
|
||||
|
||||
def run_rsync_down(self, source, target):
|
||||
self.set_ssh_ip_if_required()
|
||||
self.process_runner.check_call([
|
||||
"rsync", "--rsh",
|
||||
" ".join(["ssh"] + self.get_default_ssh_options(120)), "-avz",
|
||||
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target
|
||||
" ".join(["ssh"] +
|
||||
self.base_ssh_options.to_ssh_options_list(timeout=120)),
|
||||
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
|
||||
source), target
|
||||
])
|
||||
|
||||
def remote_shell_command_str(self):
|
||||
|
@ -309,6 +330,7 @@ class DockerCommandRunner(SSHCommandRunner):
|
|||
self.docker_name = docker_config["container_name"]
|
||||
self.docker_config = docker_config
|
||||
self.home_dir = None
|
||||
self.check_docker_installed()
|
||||
self.shutdown = False
|
||||
|
||||
def run(self,
|
||||
|
@ -318,6 +340,7 @@ class DockerCommandRunner(SSHCommandRunner):
|
|||
port_forward=None,
|
||||
with_output=False,
|
||||
run_env=True,
|
||||
ssh_options_override=None,
|
||||
**kwargs):
|
||||
if run_env == "auto":
|
||||
run_env = "host" if cmd.find("docker") == 0 else "docker"
|
||||
|
@ -335,7 +358,23 @@ class DockerCommandRunner(SSHCommandRunner):
|
|||
timeout=timeout,
|
||||
exit_on_fail=exit_on_fail,
|
||||
port_forward=None,
|
||||
with_output=False)
|
||||
with_output=False,
|
||||
ssh_options_override=ssh_options_override)
|
||||
|
||||
def check_docker_installed(self):
|
||||
try:
|
||||
self.ssh_command_runner.run("command -v docker")
|
||||
return
|
||||
except Exception:
|
||||
install_commands = [
|
||||
"curl -fsSL https://get.docker.com -o get-docker.sh",
|
||||
"sudo sh get-docker.sh", "sudo usermod -aG docker $USER",
|
||||
"sudo systemctl restart docker -f"
|
||||
]
|
||||
logger.error(
|
||||
"Docker not installed. You can install Docker by adding the "
|
||||
"following commands to 'initialization_commands':\n" +
|
||||
"\n".join(install_commands))
|
||||
|
||||
def shutdown_after_next_cmd(self):
|
||||
self.shutdown = True
|
||||
|
@ -422,6 +461,7 @@ class NodeUpdater:
|
|||
self.setup_commands = setup_commands
|
||||
self.ray_start_commands = ray_start_commands
|
||||
self.runtime_hash = runtime_hash
|
||||
self.auth_config = auth_config
|
||||
|
||||
def run(self):
|
||||
logger.info(self.log_prefix +
|
||||
|
@ -516,7 +556,10 @@ class NodeUpdater:
|
|||
self.log_prefix + "Initialization commands",
|
||||
show_status=True):
|
||||
for cmd in self.initialization_commands:
|
||||
self.cmd_runner.run(cmd)
|
||||
self.cmd_runner.run(
|
||||
cmd,
|
||||
ssh_options_override=SSHOptions(
|
||||
self.auth_config.get("ssh_private_key")))
|
||||
|
||||
with LogTimer(
|
||||
self.log_prefix + "Setup commands", show_status=True):
|
||||
|
|
Loading…
Add table
Reference in a new issue