[autoscaler] Change sys.exit(1) in update ssh_cmd (#5266)

This commit is contained in:
Kristian Hartikainen 2019-07-31 13:45:05 -07:00 committed by Richard Liaw
parent b3c8091a35
commit 1345802c39
2 changed files with 43 additions and 37 deletions

View file

@ -374,13 +374,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
shutdown_cmd = wrap_docker(shutdown_cmd)
cmd += ("; {}; sudo shutdown -h now".format(shutdown_cmd))
_exec(
updater,
cmd,
screen,
tmux,
expect_error=stop,
port_forward=port_forward)
_exec(updater, cmd, screen, tmux, port_forward=port_forward)
if tmux or screen:
attach_command_parts = ["ray attach", config_file]
@ -400,7 +394,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
provider.cleanup()
def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None):
def _exec(updater, cmd, screen, tmux, port_forward=None):
if cmd:
if screen:
cmd = [
@ -418,7 +412,7 @@ def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None):
updater.ssh_cmd(
cmd,
allocate_tty=True,
expect_error=expect_error,
exit_on_fail=True,
port_forward=port_forward)
@ -461,7 +455,7 @@ def rsync(config_file, source, target, override_cluster_name, down):
rsync = updater.rsync_up
if source and target:
rsync(source, target, check_error=False)
rsync(source, target)
else:
updater.sync_file_mounts(rsync)

View file

@ -54,6 +54,7 @@ class NodeUpdater(object):
setup_commands,
runtime_hash,
process_runner=subprocess,
exit_on_update_fail=False,
use_internal_ip=False):
ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format(
@ -75,14 +76,9 @@ class NodeUpdater(object):
}
self.initialization_commands = initialization_commands
self.setup_commands = setup_commands
self.exit_on_update_fail = exit_on_update_fail
self.runtime_hash = runtime_hash
def get_caller(self, check_error):
if check_error:
return self.process_runner.call
else:
return self.process_runner.check_call
def get_node_ip(self):
if self.use_internal_ip:
return self.provider.internal_ip(self.node_id)
@ -118,15 +114,21 @@ class NodeUpdater(object):
# the ControlPath directory exists, allowing SSH to maintain
# persistent sessions later on.
with open("/dev/null", "w") as redirect:
self.get_caller(False)(
["mkdir", "-p", self.ssh_control_path],
stdout=redirect,
stderr=redirect)
try:
self.process_runner.check_call(
["mkdir", "-p", self.ssh_control_path],
stdout=redirect,
stderr=redirect)
except subprocess.CalledProcessError as e:
logger.warning(e)
self.get_caller(False)(
["chmod", "0700", self.ssh_control_path],
stdout=redirect,
stderr=redirect)
try:
self.process_runner.check_call(
["chmod", "0700", self.ssh_control_path],
stdout=redirect,
stderr=redirect)
except subprocess.CalledProcessError as e:
logger.warning(e)
def run(self):
logger.info("NodeUpdater: "
@ -226,19 +228,19 @@ class NodeUpdater(object):
m = "{}: Initialization commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.initialization_commands:
self.ssh_cmd(cmd)
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
m = "{}: Setup commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.setup_commands:
self.ssh_cmd(cmd)
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
def rsync_up(self, source, target, redirect=None, check_error=True):
def rsync_up(self, source, target, redirect=None):
logger.info("NodeUpdater: "
"{}: Syncing {} to {}...".format(self.node_id, source,
target))
self.set_ssh_ip_if_required()
self.get_caller(check_error)(
self.process_runner.check_call(
[
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
@ -247,12 +249,12 @@ class NodeUpdater(object):
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)
def rsync_down(self, source, target, redirect=None, check_error=True):
def rsync_down(self, source, target, redirect=None):
logger.info("NodeUpdater: "
"{}: Syncing {} from {}...".format(self.node_id, source,
target))
self.set_ssh_ip_if_required()
self.get_caller(check_error)(
self.process_runner.check_call(
[
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
@ -267,7 +269,7 @@ class NodeUpdater(object):
redirect=None,
allocate_tty=False,
emulate_interactive=True,
expect_error=False,
exit_on_fail=False,
port_forward=None):
self.set_ssh_ip_if_required()
@ -291,12 +293,22 @@ class NodeUpdater(object):
"-L", "{}:localhost:{}".format(port_forward, port_forward)
]
self.get_caller(expect_error)(
ssh + ssh_opt + get_default_ssh_options(
self.ssh_private_key, connect_timeout, self.ssh_control_path) +
["{}@{}".format(self.ssh_user, self.ssh_ip), cmd],
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)
final_cmd = ssh + ssh_opt + get_default_ssh_options(
self.ssh_private_key, connect_timeout, self.ssh_control_path) + [
"{}@{}".format(self.ssh_user, self.ssh_ip), cmd
]
try:
self.process_runner.check_call(
final_cmd,
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)
except subprocess.CalledProcessError:
if exit_on_fail:
logger.error("Command failed: \n\n {}\n".format(
" ".join(final_cmd)))
sys.exit(1)
else:
raise
class NodeUpdaterThread(NodeUpdater, Thread):