mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[autoscaler] Change sys.exit(1) in update ssh_cmd (#5266)
This commit is contained in:
parent
b3c8091a35
commit
1345802c39
2 changed files with 43 additions and 37 deletions
|
@ -374,13 +374,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
|
||||||
shutdown_cmd = wrap_docker(shutdown_cmd)
|
shutdown_cmd = wrap_docker(shutdown_cmd)
|
||||||
cmd += ("; {}; sudo shutdown -h now".format(shutdown_cmd))
|
cmd += ("; {}; sudo shutdown -h now".format(shutdown_cmd))
|
||||||
|
|
||||||
_exec(
|
_exec(updater, cmd, screen, tmux, port_forward=port_forward)
|
||||||
updater,
|
|
||||||
cmd,
|
|
||||||
screen,
|
|
||||||
tmux,
|
|
||||||
expect_error=stop,
|
|
||||||
port_forward=port_forward)
|
|
||||||
|
|
||||||
if tmux or screen:
|
if tmux or screen:
|
||||||
attach_command_parts = ["ray attach", config_file]
|
attach_command_parts = ["ray attach", config_file]
|
||||||
|
@ -400,7 +394,7 @@ def exec_cluster(config_file, cmd, docker, screen, tmux, stop, start,
|
||||||
provider.cleanup()
|
provider.cleanup()
|
||||||
|
|
||||||
|
|
||||||
def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None):
|
def _exec(updater, cmd, screen, tmux, port_forward=None):
|
||||||
if cmd:
|
if cmd:
|
||||||
if screen:
|
if screen:
|
||||||
cmd = [
|
cmd = [
|
||||||
|
@ -418,7 +412,7 @@ def _exec(updater, cmd, screen, tmux, expect_error=False, port_forward=None):
|
||||||
updater.ssh_cmd(
|
updater.ssh_cmd(
|
||||||
cmd,
|
cmd,
|
||||||
allocate_tty=True,
|
allocate_tty=True,
|
||||||
expect_error=expect_error,
|
exit_on_fail=True,
|
||||||
port_forward=port_forward)
|
port_forward=port_forward)
|
||||||
|
|
||||||
|
|
||||||
|
@ -461,7 +455,7 @@ def rsync(config_file, source, target, override_cluster_name, down):
|
||||||
rsync = updater.rsync_up
|
rsync = updater.rsync_up
|
||||||
|
|
||||||
if source and target:
|
if source and target:
|
||||||
rsync(source, target, check_error=False)
|
rsync(source, target)
|
||||||
else:
|
else:
|
||||||
updater.sync_file_mounts(rsync)
|
updater.sync_file_mounts(rsync)
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ class NodeUpdater(object):
|
||||||
setup_commands,
|
setup_commands,
|
||||||
runtime_hash,
|
runtime_hash,
|
||||||
process_runner=subprocess,
|
process_runner=subprocess,
|
||||||
|
exit_on_update_fail=False,
|
||||||
use_internal_ip=False):
|
use_internal_ip=False):
|
||||||
|
|
||||||
ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format(
|
ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format(
|
||||||
|
@ -75,14 +76,9 @@ class NodeUpdater(object):
|
||||||
}
|
}
|
||||||
self.initialization_commands = initialization_commands
|
self.initialization_commands = initialization_commands
|
||||||
self.setup_commands = setup_commands
|
self.setup_commands = setup_commands
|
||||||
|
self.exit_on_update_fail = exit_on_update_fail
|
||||||
self.runtime_hash = runtime_hash
|
self.runtime_hash = runtime_hash
|
||||||
|
|
||||||
def get_caller(self, check_error):
|
|
||||||
if check_error:
|
|
||||||
return self.process_runner.call
|
|
||||||
else:
|
|
||||||
return self.process_runner.check_call
|
|
||||||
|
|
||||||
def get_node_ip(self):
|
def get_node_ip(self):
|
||||||
if self.use_internal_ip:
|
if self.use_internal_ip:
|
||||||
return self.provider.internal_ip(self.node_id)
|
return self.provider.internal_ip(self.node_id)
|
||||||
|
@ -118,15 +114,21 @@ class NodeUpdater(object):
|
||||||
# the ControlPath directory exists, allowing SSH to maintain
|
# the ControlPath directory exists, allowing SSH to maintain
|
||||||
# persistent sessions later on.
|
# persistent sessions later on.
|
||||||
with open("/dev/null", "w") as redirect:
|
with open("/dev/null", "w") as redirect:
|
||||||
self.get_caller(False)(
|
try:
|
||||||
["mkdir", "-p", self.ssh_control_path],
|
self.process_runner.check_call(
|
||||||
stdout=redirect,
|
["mkdir", "-p", self.ssh_control_path],
|
||||||
stderr=redirect)
|
stdout=redirect,
|
||||||
|
stderr=redirect)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.warning(e)
|
||||||
|
|
||||||
self.get_caller(False)(
|
try:
|
||||||
["chmod", "0700", self.ssh_control_path],
|
self.process_runner.check_call(
|
||||||
stdout=redirect,
|
["chmod", "0700", self.ssh_control_path],
|
||||||
stderr=redirect)
|
stdout=redirect,
|
||||||
|
stderr=redirect)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
logger.warning(e)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
logger.info("NodeUpdater: "
|
logger.info("NodeUpdater: "
|
||||||
|
@ -226,19 +228,19 @@ class NodeUpdater(object):
|
||||||
m = "{}: Initialization commands completed".format(self.node_id)
|
m = "{}: Initialization commands completed".format(self.node_id)
|
||||||
with LogTimer("NodeUpdater: {}".format(m)):
|
with LogTimer("NodeUpdater: {}".format(m)):
|
||||||
for cmd in self.initialization_commands:
|
for cmd in self.initialization_commands:
|
||||||
self.ssh_cmd(cmd)
|
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||||
|
|
||||||
m = "{}: Setup commands completed".format(self.node_id)
|
m = "{}: Setup commands completed".format(self.node_id)
|
||||||
with LogTimer("NodeUpdater: {}".format(m)):
|
with LogTimer("NodeUpdater: {}".format(m)):
|
||||||
for cmd in self.setup_commands:
|
for cmd in self.setup_commands:
|
||||||
self.ssh_cmd(cmd)
|
self.ssh_cmd(cmd, exit_on_fail=self.exit_on_update_fail)
|
||||||
|
|
||||||
def rsync_up(self, source, target, redirect=None, check_error=True):
|
def rsync_up(self, source, target, redirect=None):
|
||||||
logger.info("NodeUpdater: "
|
logger.info("NodeUpdater: "
|
||||||
"{}: Syncing {} to {}...".format(self.node_id, source,
|
"{}: Syncing {} to {}...".format(self.node_id, source,
|
||||||
target))
|
target))
|
||||||
self.set_ssh_ip_if_required()
|
self.set_ssh_ip_if_required()
|
||||||
self.get_caller(check_error)(
|
self.process_runner.check_call(
|
||||||
[
|
[
|
||||||
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
|
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
|
||||||
self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
|
self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
|
||||||
|
@ -247,12 +249,12 @@ class NodeUpdater(object):
|
||||||
stdout=redirect or sys.stdout,
|
stdout=redirect or sys.stdout,
|
||||||
stderr=redirect or sys.stderr)
|
stderr=redirect or sys.stderr)
|
||||||
|
|
||||||
def rsync_down(self, source, target, redirect=None, check_error=True):
|
def rsync_down(self, source, target, redirect=None):
|
||||||
logger.info("NodeUpdater: "
|
logger.info("NodeUpdater: "
|
||||||
"{}: Syncing {} from {}...".format(self.node_id, source,
|
"{}: Syncing {} from {}...".format(self.node_id, source,
|
||||||
target))
|
target))
|
||||||
self.set_ssh_ip_if_required()
|
self.set_ssh_ip_if_required()
|
||||||
self.get_caller(check_error)(
|
self.process_runner.check_call(
|
||||||
[
|
[
|
||||||
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
|
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
|
||||||
self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
|
self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
|
||||||
|
@ -267,7 +269,7 @@ class NodeUpdater(object):
|
||||||
redirect=None,
|
redirect=None,
|
||||||
allocate_tty=False,
|
allocate_tty=False,
|
||||||
emulate_interactive=True,
|
emulate_interactive=True,
|
||||||
expect_error=False,
|
exit_on_fail=False,
|
||||||
port_forward=None):
|
port_forward=None):
|
||||||
|
|
||||||
self.set_ssh_ip_if_required()
|
self.set_ssh_ip_if_required()
|
||||||
|
@ -291,12 +293,22 @@ class NodeUpdater(object):
|
||||||
"-L", "{}:localhost:{}".format(port_forward, port_forward)
|
"-L", "{}:localhost:{}".format(port_forward, port_forward)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.get_caller(expect_error)(
|
final_cmd = ssh + ssh_opt + get_default_ssh_options(
|
||||||
ssh + ssh_opt + get_default_ssh_options(
|
self.ssh_private_key, connect_timeout, self.ssh_control_path) + [
|
||||||
self.ssh_private_key, connect_timeout, self.ssh_control_path) +
|
"{}@{}".format(self.ssh_user, self.ssh_ip), cmd
|
||||||
["{}@{}".format(self.ssh_user, self.ssh_ip), cmd],
|
]
|
||||||
stdout=redirect or sys.stdout,
|
try:
|
||||||
stderr=redirect or sys.stderr)
|
self.process_runner.check_call(
|
||||||
|
final_cmd,
|
||||||
|
stdout=redirect or sys.stdout,
|
||||||
|
stderr=redirect or sys.stderr)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
if exit_on_fail:
|
||||||
|
logger.error("Command failed: \n\n {}\n".format(
|
||||||
|
" ".join(final_cmd)))
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
class NodeUpdaterThread(NodeUpdater, Thread):
|
class NodeUpdaterThread(NodeUpdater, Thread):
|
||||||
|
|
Loading…
Add table
Reference in a new issue