[autoscaler] Change sys.exit(1) in update ssh_cmd (#5266)

This commit is contained in:
Kristian Hartikainen 2019-07-31 13:45:05 -07:00 committed by Richard Liaw
parent b3c8091a35
commit 1345802c39
2 changed files with 43 additions and 37 deletions

View file

@ -374,13 +374,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
shutdown_cmd = wrap_docker(shutdown_cmd) shutdown_cmd = wrap_docker(shutdown_cmd)
cmd += ("; {}; sudo shutdown -h now".format(shutdown_cmd)) cmd += ("; {}; sudo shutdown -h now".format(shutdown_cmd))
_exec( _exec(updater, cmd, screen, tmux, port_forward=port_forward)
updater,
cmd,
screen,
tmux,
expect_error=stop,
port_forward=port_forward)
if tmux or screen: if tmux or screen:
attach_command_parts = ["ray attach", config_file] attach_command_parts = ["ray attach", config_file]
@ -400,7 +394,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
provider.cleanup() provider.cleanup()
def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None): def _exec(updater, cmd, screen, tmux, port_forward=None):
if cmd: if cmd:
if screen: if screen:
cmd = [ cmd = [
@ -418,7 +412,7 @@ def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None):
updater.ssh_cmd( updater.ssh_cmd(
cmd, cmd,
allocate_tty=True, allocate_tty=True,
expect_error=expect_error, exit_on_fail=True,
port_forward=port_forward) port_forward=port_forward)
@ -461,7 +455,7 @@ def rsync(config_file, source, target, override_cluster_name, down):
rsync = updater.rsync_up rsync = updater.rsync_up
if source and target: if source and target:
rsync(source, target, check_error=False) rsync(source, target)
else: else:
updater.sync_file_mounts(rsync) updater.sync_file_mounts(rsync)

View file

@ -54,6 +54,7 @@ class NodeUpdater(object):
setup_commands, setup_commands,
runtime_hash, runtime_hash,
process_runner=subprocess, process_runner=subprocess,
exit_on_update_fail=False,
use_internal_ip=False): use_internal_ip=False):
ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format( ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format(
@ -75,14 +76,9 @@ class NodeUpdater(object):
} }
self.initialization_commands = initialization_commands self.initialization_commands = initialization_commands
self.setup_commands = setup_commands self.setup_commands = setup_commands
self.exit_on_update_fail = exit_on_update_fail
self.runtime_hash = runtime_hash self.runtime_hash = runtime_hash
def get_caller(self, check_error):
if check_error:
return self.process_runner.call
else:
return self.process_runner.check_call
def get_node_ip(self): def get_node_ip(self):
if self.use_internal_ip: if self.use_internal_ip:
return self.provider.internal_ip(self.node_id) return self.provider.internal_ip(self.node_id)
@ -118,15 +114,21 @@ class NodeUpdater(object):
# the ControlPath directory exists, allowing SSH to maintain # the ControlPath directory exists, allowing SSH to maintain
# persistent sessions later on. # persistent sessions later on.
with open("/dev/null", "w") as redirect: with open("/dev/null", "w") as redirect:
self.get_caller(False)( try:
["mkdir", "-p", self.ssh_control_path], self.process_runner.check_call(
stdout=redirect, ["mkdir", "-p", self.ssh_control_path],
stderr=redirect) stdout=redirect,
stderr=redirect)
except subprocess.CalledProcessError as e:
logger.warning(e)
self.get_caller(False)( try:
["chmod", "0700", self.ssh_control_path], self.process_runner.check_call(
stdout=redirect, ["chmod", "0700", self.ssh_control_path],
stderr=redirect) stdout=redirect,
stderr=redirect)
except subprocess.CalledProcessError as e:
logger.warning(e)
def run(self): def run(self):
logger.info("NodeUpdater: " logger.info("NodeUpdater: "
@ -226,19 +228,19 @@ class NodeUpdater(object):
m = "{}: Initialization commands completed".format(self.node_id) m = "{}: Initialization commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)): with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.initialization_commands: for cmd in self.initialization_commands:
self.ssh_cmd(cmd) self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
m = "{}: Setup commands completed".format(self.node_id) m = "{}: Setup commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)): with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.setup_commands: for cmd in self.setup_commands:
self.ssh_cmd(cmd) self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
def rsync_up(self, source, target, redirect=None, check_error=True): def rsync_up(self, source, target, redirect=None):
logger.info("NodeUpdater: " logger.info("NodeUpdater: "
"{}: Syncing {} to {}...".format(self.node_id, source, "{}: Syncing {} to {}...".format(self.node_id, source,
target)) target))
self.set_ssh_ip_if_required() self.set_ssh_ip_if_required()
self.get_caller(check_error)( self.process_runner.check_call(
[ [
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options( "rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
self.ssh_private_key, 120, self.ssh_control_path)), "-avz", self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
@ -247,12 +249,12 @@ class NodeUpdater(object):
stdout=redirect or sys.stdout, stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr) stderr=redirect or sys.stderr)
def rsync_down(self, source, target, redirect=None, check_error=True): def rsync_down(self, source, target, redirect=None):
logger.info("NodeUpdater: " logger.info("NodeUpdater: "
"{}: Syncing {} from {}...".format(self.node_id, source, "{}: Syncing {} from {}...".format(self.node_id, source,
target)) target))
self.set_ssh_ip_if_required() self.set_ssh_ip_if_required()
self.get_caller(check_error)( self.process_runner.check_call(
[ [
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options( "rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
self.ssh_private_key, 120, self.ssh_control_path)), "-avz", self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
@ -267,7 +269,7 @@ class NodeUpdater(object):
redirect=None, redirect=None,
allocate_tty=False, allocate_tty=False,
emulate_interactive=True, emulate_interactive=True,
expect_error=False, exit_on_fail=False,
port_forward=None): port_forward=None):
self.set_ssh_ip_if_required() self.set_ssh_ip_if_required()
@ -291,12 +293,22 @@ class NodeUpdater(object):
"-L", "{}:localhost:{}".format(port_forward, port_forward) "-L", "{}:localhost:{}".format(port_forward, port_forward)
] ]
self.get_caller(expect_error)( final_cmd = ssh + ssh_opt + get_default_ssh_options(
ssh + ssh_opt + get_default_ssh_options( self.ssh_private_key, connect_timeout, self.ssh_control_path) + [
self.ssh_private_key, connect_timeout, self.ssh_control_path) + "{}@{}".format(self.ssh_user, self.ssh_ip), cmd
["{}@{}".format(self.ssh_user, self.ssh_ip), cmd], ]
stdout=redirect or sys.stdout, try:
stderr=redirect or sys.stderr) self.process_runner.check_call(
final_cmd,
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)
except subprocess.CalledProcessError:
if exit_on_fail:
logger.error("Command failed: \n\n {}\n".format(
" ".join(final_cmd)))
sys.exit(1)
else:
raise
class NodeUpdaterThread(NodeUpdater, Thread): class NodeUpdaterThread(NodeUpdater, Thread):