From 1345802c39a551ed92a6227f1f4b4395bb59b77b Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Wed, 31 Jul 2019 13:45:05 -0700 Subject: [PATCH] [autoscaler] Change sys.exit(1) in update ssh_cmd (#5266) --- python/ray/autoscaler/commands.py | 14 ++----- python/ray/autoscaler/updater.py | 66 ++++++++++++++++++------------- 2 files changed, 43 insertions(+), 37 deletions(-) diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index be1b5a144..91718d8fc 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -374,13 +374,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start, shutdown_cmd = wrap_docker(shutdown_cmd) cmd += ("; {}; sudo shutdown -h now".format(shutdown_cmd)) - _exec( - updater, - cmd, - screen, - tmux, - expect_error=stop, - port_forward=port_forward) + _exec(updater, cmd, screen, tmux, port_forward=port_forward) if tmux or screen: attach_command_parts = ["ray attach", config_file] @@ -400,7 +394,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start, provider.cleanup() -def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None): +def _exec(updater, cmd, screen, tmux, port_forward=None): if cmd: if screen: cmd = [ @@ -418,7 +412,7 @@ def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None): updater.ssh_cmd( cmd, allocate_tty=True, - expect_error=expect_error, + exit_on_fail=True, port_forward=port_forward) @@ -461,7 +455,7 @@ def rsync(config_file, source, target, override_cluster_name, down): rsync = updater.rsync_up if source and target: - rsync(source, target, check_error=False) + rsync(source, target) else: updater.sync_file_mounts(rsync) diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 9ddf9b56f..c5c3d3c4c 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -54,6 +54,7 @@ class NodeUpdater(object): setup_commands, runtime_hash, process_runner=subprocess, + exit_on_update_fail=False, use_internal_ip=False): ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format( @@ -75,14 +76,9 @@ class NodeUpdater(object): } self.initialization_commands = initialization_commands self.setup_commands = setup_commands + self.exit_on_update_fail = exit_on_update_fail self.runtime_hash = runtime_hash - def get_caller(self, check_error): - if check_error: - return self.process_runner.call - else: - return self.process_runner.check_call - def get_node_ip(self): if self.use_internal_ip: return self.provider.internal_ip(self.node_id) @@ -118,15 +114,21 @@ class NodeUpdater(object): # the ControlPath directory exists, allowing SSH to maintain # persistent sessions later on. with open("/dev/null", "w") as redirect: - self.get_caller(False)( - ["mkdir", "-p", self.ssh_control_path], - stdout=redirect, - stderr=redirect) + try: + self.process_runner.check_call( + ["mkdir", "-p", self.ssh_control_path], + stdout=redirect, + stderr=redirect) + except subprocess.CalledProcessError as e: + logger.warning(e) - self.get_caller(False)( - ["chmod", "0700", self.ssh_control_path], - stdout=redirect, - stderr=redirect) + try: + self.process_runner.check_call( + ["chmod", "0700", self.ssh_control_path], + stdout=redirect, + stderr=redirect) + except subprocess.CalledProcessError as e: + logger.warning(e) def run(self): logger.info("NodeUpdater: " @@ -226,19 +228,19 @@ class NodeUpdater(object): m = "{}: Initialization commands completed".format(self.node_id) with LogTimer("NodeUpdater: {}".format(m)): for cmd in self.initialization_commands: - self.ssh_cmd(cmd) + self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) m = "{}: Setup commands completed".format(self.node_id) with LogTimer("NodeUpdater: {}".format(m)): for cmd in self.setup_commands: - self.ssh_cmd(cmd) + self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail) - def rsync_up(self, source, target, redirect=None, check_error=True): + def rsync_up(self, source, target, redirect=None): logger.info("NodeUpdater: " "{}: Syncing {} to {}...".format(self.node_id, source, target)) self.set_ssh_ip_if_required() - self.get_caller(check_error)( + self.process_runner.check_call( [ "rsync", "-e", " ".join(["ssh"] + get_default_ssh_options( self.ssh_private_key, 120, self.ssh_control_path)), "-avz", @@ -247,12 +249,12 @@ class NodeUpdater(object): stdout=redirect or sys.stdout, stderr=redirect or sys.stderr) - def rsync_down(self, source, target, redirect=None, check_error=True): + def rsync_down(self, source, target, redirect=None): logger.info("NodeUpdater: " "{}: Syncing {} from {}...".format(self.node_id, source, target)) self.set_ssh_ip_if_required() - self.get_caller(check_error)( + self.process_runner.check_call( [ "rsync", "-e", " ".join(["ssh"] + get_default_ssh_options( self.ssh_private_key, 120, self.ssh_control_path)), "-avz", @@ -267,7 +269,7 @@ class NodeUpdater(object): redirect=None, allocate_tty=False, emulate_interactive=True, - expect_error=False, + exit_on_fail=False, port_forward=None): self.set_ssh_ip_if_required() @@ -291,12 +293,22 @@ class NodeUpdater(object): "-L", "{}:localhost:{}".format(port_forward, port_forward) ] - self.get_caller(expect_error)( - ssh + ssh_opt + get_default_ssh_options( - self.ssh_private_key, connect_timeout, self.ssh_control_path) + - ["{}@{}".format(self.ssh_user, self.ssh_ip), cmd], - stdout=redirect or sys.stdout, - stderr=redirect or sys.stderr) + final_cmd = ssh + ssh_opt + get_default_ssh_options( + self.ssh_private_key, connect_timeout, self.ssh_control_path) + [ + "{}@{}".format(self.ssh_user, self.ssh_ip), cmd + ] + try: + self.process_runner.check_call( + final_cmd, + stdout=redirect or sys.stdout, + stderr=redirect or sys.stderr) + except subprocess.CalledProcessError: + if exit_on_fail: + logger.error("Command failed: \n\n {}\n".format( + " ".join(final_cmd))) + sys.exit(1) + else: + raise class NodeUpdaterThread(NodeUpdater, Thread):