diff --git a/doc/source/cluster/config.rst b/doc/source/cluster/config.rst index 64457b099..f979c25af 100644 --- a/doc/source/cluster/config.rst +++ b/doc/source/cluster/config.rst @@ -63,8 +63,7 @@ The ``example-full.yaml`` configuration is enough to get started with Ray, but f InstanceType: p2.8xlarge **Docker**: Specify docker image. This executes all commands on all nodes in the docker container, -and opens all the necessary ports to support the Ray cluster. It will also automatically install -Docker if Docker is not installed. This currently does not have GPU support. +and opens all the necessary ports to support the Ray cluster. .. code-block:: yaml @@ -72,6 +71,17 @@ Docker if Docker is not installed. This currently does not have GPU support. image: tensorflow/tensorflow:1.5.0-py3 container_name: ray_docker +If Docker is not installed, add the following commands to ``initialization_commands`` to install it. + +.. code-block:: yaml + + initialization_commands: + - curl -fsSL https://get.docker.com -o get-docker.sh + - sudo sh get-docker.sh + - sudo usermod -aG docker $USER + - sudo systemctl restart docker -f + + **Mixed GPU and CPU nodes**: for RL applications that require proportionally more CPU than GPU resources, you can use additional CPU workers with a GPU head node. diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 467ca3b44..3d076c535 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -147,6 +147,44 @@ class KubernetesCommandRunner: self.node_id) +class SSHOptions: + def __init__(self, ssh_key, control_path=None, **kwargs): + self.ssh_key = ssh_key + self.arg_dict = { + # Suppresses initial fingerprint verification. + "StrictHostKeyChecking": "no", + # SSH IP and fingerprint pairs no longer added to known_hosts. 
+ # This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED" + # warning if a new node has the same IP as a previously + # deleted node, because the fingerprints will not match in + # that case. + "UserKnownHostsFile": os.devnull, + # Try fewer extraneous key pairs. + "IdentitiesOnly": "yes", + # Abort if port forwarding fails (instead of just printing to + # stderr). + "ExitOnForwardFailure": "yes", + # Quickly kill the connection if network connection breaks (as + # opposed to hanging/blocking). + "ServerAliveInterval": 5, + "ServerAliveCountMax": 3 + } + if control_path: + self.arg_dict.update({ + "ControlMaster": "auto", + "ControlPath": "{}/%C".format(control_path), + "ControlPersist": "10s", + }) + self.arg_dict.update(kwargs) + + def to_ssh_options_list(self, *, timeout=60): + self.arg_dict["ConnectTimeout"] = "{}s".format(timeout) + return ["-i", self.ssh_key] + [ + x for y in (["-o", "{}={}".format(k, v)] + for k, v in self.arg_dict.items()) for x in y + ] + + class SSHCommandRunner: def __init__(self, log_prefix, node_id, provider, auth_config, cluster_name, process_runner, use_internal_ip): @@ -166,36 +204,8 @@ class SSHCommandRunner: self.ssh_user = auth_config["ssh_user"] self.ssh_control_path = ssh_control_path self.ssh_ip = None - - def get_default_ssh_options(self, connect_timeout): - OPTS = [ - ("ConnectTimeout", "{}s".format(connect_timeout)), - # Supresses initial fingerprint verification. - ("StrictHostKeyChecking", "no"), - # SSH IP and fingerprint pairs no longer added to known_hosts. - # This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED" - # warning if a new node has the same IP as a previously - # deleted node, because the fingerprints will not match in - # that case. - ("UserKnownHostsFile", os.devnull), - ("ControlMaster", "auto"), - ("ControlPath", "{}/%C".format(self.ssh_control_path)), - ("ControlPersist", "10s"), - # Try fewer extraneous key pairs. 
- ("IdentitiesOnly", "yes"), - # Abort if port forwarding fails (instead of just printing to - # stderr). - ("ExitOnForwardFailure", "yes"), - # Quickly kill the connection if network connection breaks (as - # opposed to hanging/blocking). - ("ServerAliveInterval", 5), - ("ServerAliveCountMax", 3), - ] - - return ["-i", self.ssh_private_key] + [ - x for y in (["-o", "{}={}".format(k, v)] for k, v in OPTS) - for x in y - ] + self.base_ssh_options = SSHOptions(self.ssh_private_key, + self.ssh_control_path) def get_node_ip(self): if self.use_internal_ip: @@ -241,7 +251,14 @@ class SSHCommandRunner: exit_on_fail=False, port_forward=None, with_output=False, + ssh_options_override=None, **kwargs): + ssh_options = ssh_options_override or self.base_ssh_options + + assert isinstance( + ssh_options, SSHOptions + ), "ssh_options must be of type SSHOptions, got {}".format( + type(ssh_options)) self.set_ssh_ip_if_required() @@ -255,7 +272,7 @@ class SSHCommandRunner: "{} -> localhost:{}".format(local, remote)) ssh += ["-L", "{}:localhost:{}".format(remote, local)] - final_cmd = ssh + self.get_default_ssh_options(timeout) + [ + final_cmd = ssh + ssh_options.to_ssh_options_list(timeout=timeout) + [ "{}@{}".format(self.ssh_user, self.ssh_ip) ] if cmd: @@ -286,16 +303,20 @@ class SSHCommandRunner: self.set_ssh_ip_if_required() self.process_runner.check_call([ "rsync", "--rsh", - " ".join(["ssh"] + self.get_default_ssh_options(120)), "-avz", - source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip, target) + " ".join(["ssh"] + + self.base_ssh_options.to_ssh_options_list(timeout=120)), + "-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip, + target) ]) def run_rsync_down(self, source, target): self.set_ssh_ip_if_required() self.process_runner.check_call([ "rsync", "--rsh", - " ".join(["ssh"] + self.get_default_ssh_options(120)), "-avz", - "{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target + " ".join(["ssh"] + + 
self.base_ssh_options.to_ssh_options_list(timeout=120)), + "-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip, + source), target ]) def remote_shell_command_str(self): @@ -309,6 +330,7 @@ class DockerCommandRunner(SSHCommandRunner): self.docker_name = docker_config["container_name"] self.docker_config = docker_config self.home_dir = None + self.check_docker_installed() self.shutdown = False def run(self, @@ -318,6 +340,7 @@ class DockerCommandRunner(SSHCommandRunner): port_forward=None, with_output=False, run_env=True, + ssh_options_override=None, **kwargs): if run_env == "auto": run_env = "host" if cmd.find("docker") == 0 else "docker" @@ -335,7 +358,23 @@ class DockerCommandRunner(SSHCommandRunner): timeout=timeout, exit_on_fail=exit_on_fail, port_forward=None, - with_output=False) + with_output=False, + ssh_options_override=ssh_options_override) + + def check_docker_installed(self): + try: + self.ssh_command_runner.run("command -v docker") + return + except Exception: + install_commands = [ + "curl -fsSL https://get.docker.com -o get-docker.sh", + "sudo sh get-docker.sh", "sudo usermod -aG docker $USER", + "sudo systemctl restart docker -f" + ] + logger.error( + "Docker not installed. You can install Docker by adding the " + "following commands to 'initialization_commands':\n" + + "\n".join(install_commands)) def shutdown_after_next_cmd(self): self.shutdown = True @@ -422,6 +461,7 @@ class NodeUpdater: self.setup_commands = setup_commands self.ray_start_commands = ray_start_commands self.runtime_hash = runtime_hash + self.auth_config = auth_config def run(self): logger.info(self.log_prefix + @@ -516,7 +556,10 @@ class NodeUpdater: self.log_prefix + "Initialization commands", show_status=True): for cmd in self.initialization_commands: - self.cmd_runner.run(cmd) + self.cmd_runner.run( + cmd, + ssh_options_override=SSHOptions( + self.auth_config.get("ssh_private_key"))) with LogTimer( self.log_prefix + "Setup commands", show_status=True):