mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
295 lines
9 KiB
Python
Executable file
295 lines
9 KiB
Python
Executable file
#!/usr/bin/env python
|
|
"""
|
|
This command must be run in a git repository.
|
|
|
|
It watches the remote branch for changes and terminates this CI job once the
remote branch no longer points to the expected commit, or has been deleted.
Commits that GitHub spuriously re-creates with the same tree and parents are
not counted as changes.

If the message of a commit being tested contains a line saying "CI_KEEP_ALIVE"
(optionally "CI_KEEP_ALIVE: <CI names>" to limit it to specific CIs), the branch
is not monitored at all.

No PID can be specified: termination signals the parent process and this
process's entire process group.
|
|
"""
|
|
|
|
# Prefer to keep this file Python 2-compatible so that it can easily run early
|
|
# in the CI process on any system.
|
|
|
|
import argparse
|
|
import errno
|
|
import logging
|
|
import os
|
|
import signal
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
|
|
logger = logging.getLogger(__name__)

# Identifiers returned by get_current_ci() for the supported CI providers.
GITHUB, TRAVIS = "GitHub", "Travis"
|
|
|
|
|
|
def git(*args):
    """Run ``git`` with the given arguments and return its standard output.

    The output is decoded as UTF-8 and trailing whitespace is stripped.
    Raises subprocess.CalledProcessError when git exits with a non-zero status.
    """
    command = ["git"]
    command.extend(args)
    raw_output = subprocess.check_output(command)
    return raw_output.decode("utf-8").rstrip()
|
|
|
|
|
|
def get_current_ci():
    """Identify the CI provider this process is running under.

    Returns GITHUB when GITHUB_WORKFLOW is set in the environment, otherwise
    TRAVIS when TRAVIS is set, otherwise None. GitHub wins if both are set.
    """
    environment = os.environ
    if "GITHUB_WORKFLOW" in environment:
        return GITHUB
    if "TRAVIS" in environment:
        return TRAVIS
    return None
|
|
|
|
|
|
def get_ci_event_name():
    """Return the CI's name for the event that triggered this build.

    Reads GITHUB_EVENT_NAME on GitHub and TRAVIS_EVENT_TYPE on Travis, and
    returns None when no supported CI is detected. Raises KeyError if the CI is
    detected but its variable is not set.
    """
    variable = {
        GITHUB: "GITHUB_EVENT_NAME",
        TRAVIS: "TRAVIS_EVENT_TYPE",
    }.get(get_current_ci())
    return os.environ[variable] if variable is not None else None
|
|
|
|
|
|
def get_repo_slug():
    """Return the repository slug reported by the CI, or None outside CI.

    Reads GITHUB_REPOSITORY on GitHub and TRAVIS_REPO_SLUG on Travis. Raises
    KeyError if the CI is detected but its variable is not set.
    """
    ci = get_current_ci()
    slug_variables = {
        GITHUB: "GITHUB_REPOSITORY",
        TRAVIS: "TRAVIS_REPO_SLUG",
    }
    if ci not in slug_variables:
        return None
    return os.environ[slug_variables[ci]]
|
|
|
|
|
|
def get_remote_url(remote):
    """Return the URL configured for *remote*, via ``git ls-remote --get-url``."""
    url = git("ls-remote", "--get-url", remote)
    return url
|
|
|
|
|
|
def replace_suffix(base, old_suffix, new_suffix=""):
    """Swap a trailing *old_suffix* on *base* for *new_suffix*.

    If *base* does not end with *old_suffix*, it is returned unchanged.
    With the default *new_suffix* this simply removes the suffix.
    """
    if not base.endswith(old_suffix):
        return base
    cut = len(base) - len(old_suffix)
    return base[:cut] + new_suffix
|
|
|
|
|
|
def git_branch_info_to_track():
    """Obtains the remote branch name, remote name, and commit hash that
    should be tracked for changes.

    The remote is the first one listed by ``git remote show -n``. The ref and
    commit come from the CI environment:

    * GitHub: GITHUB_HEAD_SHA (falling back to GITHUB_SHA when unset or empty)
      and GITHUB_REF, with a trailing "/merge" rewritten to "/head".
    * Travis: for pull requests (TRAVIS_PULL_REQUEST != "false"),
      TRAVIS_PULL_REQUEST_SHA and refs/pull/<n>/head; otherwise TRAVIS_COMMIT
      and refs/heads/<TRAVIS_BRANCH>.

    Returns:
        ("refs/heads/mybranch", "origin", "1A2B3C4...")

    Raises:
        ValueError: if no remote is configured, no supported CI is detected,
            or the CI supplied an empty ref or hash.
        KeyError: if a required CI environment variable is missing.
    """
    expected_sha = None
    ref = None
    remotes = git("remote", "show", "-n").splitlines()
    # With no remotes configured the list is empty; leave remote as None so the
    # descriptive ValueError below is raised instead of an opaque IndexError.
    remote = remotes[0] if remotes else None

    ci = get_current_ci()
    if ci == GITHUB:
        expected_sha = os.getenv("GITHUB_HEAD_SHA") or os.environ["GITHUB_SHA"]
        ref = replace_suffix(os.environ["GITHUB_REF"], "/merge", "/head")
    elif ci == TRAVIS:
        pr = os.getenv("TRAVIS_PULL_REQUEST", "false")
        if pr != "false":
            expected_sha = os.environ["TRAVIS_PULL_REQUEST_SHA"]
            ref = "refs/pull/{}/head".format(pr)
        else:
            expected_sha = os.environ["TRAVIS_COMMIT"]
            ref = "refs/heads/{}".format(os.environ["TRAVIS_BRANCH"])

    result = (ref, remote, expected_sha)

    if not all(result):
        msg = "Invalid remote {!r}, ref {!r}, or hash {!r} for CI {!r}"
        raise ValueError(msg.format(remote, ref, expected_sha, ci))

    return result
|
|
|
|
|
|
def get_commit_metadata(hash):
    """Get the commit info (content hash, parents, message, etc.) as a list of
    key-value pairs.

    The raw object from ``git cat-file -p`` is split at its first blank line
    into header records and the commit message. Each header line becomes a
    (key, value) pair in original order. Indented continuation lines (such as
    a multi-line signature) stay attached to their header's value. Values keep
    their trailing newline, except for the last header. A final
    ("message", <text or None>) pair is appended; it is None when the object
    has no blank line.
    """
    raw = git("cat-file", "-p", hash)
    header, separator, body = raw.partition("\n\n")  # split off the message
    message = body if separator else None

    # Temporarily fold continuation lines ("\n ") into their parent record so
    # splitting on line boundaries yields one entry per header.
    folded = header.replace("\n ", "\0 ")
    pairs = []
    for entry in folded.splitlines(True):
        name, content = entry.split(" ", 1)
        pairs.append((name, content.replace("\0 ", "\n ")))  # unfold again
    pairs.append(("message", message))
    return pairs
|
|
|
|
|
|
def terminate_my_process_group():
    """Kill this CI job, returning 0 or the errno of a tolerated kill failure.

    On non-Windows platforms: SIGTERM the parent process, wait 15 seconds,
    then SIGKILL via os.kill(0, ...), which targets the caller's whole process
    group (including this process). On Windows: send CTRL_BREAK_EVENT via
    os.kill(0, ...), which may be ignored, wait 15 seconds, then SIGTERM the
    parent.

    An OSError with errno EBADF or ESRCH is logged and its errno returned;
    any other OSError propagates.
    """
    grace_seconds = 15
    try:
        logger.warning("Attempting kill...")
        if sys.platform != "win32":
            # This SIGTERM seems to be needed to prevent jobs from lingering.
            os.kill(os.getppid(), signal.SIGTERM)
            time.sleep(grace_seconds)
            os.kill(0, signal.SIGKILL)
        else:
            os.kill(0, signal.CTRL_BREAK_EVENT)  # This might get ignored.
            time.sleep(grace_seconds)
            os.kill(os.getppid(), signal.SIGTERM)
    except OSError as err:
        if err.errno not in (errno.EBADF, errno.ESRCH):
            raise
        logger.error("Kill error %s: %s", err.errno, err.strerror)
        return err.errno
    return 0
|
|
|
|
|
|
def yield_poll_schedule():
    """Yield successive delays (in seconds) to sleep between remote polls.

    Polls quickly at first, then backs off: 0, 5, 5, 10, 20, 40, 40, five 60s,
    ten 120s, then 300 forever. The generator never terminates.
    """
    delays = [0, 5, 5, 10, 20, 40, 40]
    delays.extend([60] * 5)
    delays.extend([120] * 10)
    delays.append(300)
    for delay in delays:
        yield delay
    steady_state = delays[-1]
    while True:
        yield steady_state
|
|
|
|
|
|
def detect_spurious_commit(actual, expected, remote):
    """GitHub sometimes spuriously generates commits multiple times with
    different dates but identical contents. See here:
    https://github.com/travis-ci/travis-ci/issues/7459#issuecomment-601346831
    We need to detect whether this might be the case, and we do so by
    comparing the commits' contents ("tree" objects) and their parents.

    Args:
        actual: The commit line on the remote from git ls-remote, e.g.:
            da39a3ee5e6b4b0d3255bfef95601890afd80709 refs/heads/master
        expected: The commit line initially expected.
        remote: The remote to fetch the actual commit from when it differs.

    Returns:
        The new (actual) commit line, if it is suspected to be spurious.
        Otherwise, the previously expected commit line.

    Side effects:
        When the lines differ, runs ``git fetch -q <remote> <actual hash>``.
        Git failures raise subprocess.CalledProcessError.
    """
    actual_sha = actual.split(None, 1)[0]
    expected_sha = expected.split(None, 1)[0]
    if actual == expected:
        return expected

    git("fetch", "-q", remote, actual_sha)

    def _content_identity(sha):
        # Only the tree and parent headers matter: commits with the same tree
        # and parents have the same content regardless of their dates.
        return [
            pair
            for pair in get_commit_metadata(sha)
            if pair[0] in ("tree", "parent")
        ]

    if _content_identity(actual_sha) == _content_identity(expected_sha):
        return actual
    return expected
|
|
|
|
|
|
def should_keep_alive(commit_msg):
    """Return True if *commit_msg* asks for this CI job to be kept alive.

    A line counts when, after stripping surrounding '#' and space characters,
    the text before its first ':' is exactly ``CI_KEEP_ALIVE``. If nothing
    follows the colon (or there is no colon), the request applies to every CI.
    Otherwise the rest is a comma- and/or whitespace-separated list of CI
    names, matched case-insensitively against get_current_ci().
    """
    current = (get_current_ci() or "").lower()
    keep = False
    for raw_line in commit_msg.splitlines():
        key, _, value = raw_line.strip("# ").partition(":")
        if key != "CI_KEEP_ALIVE":
            continue
        requested = value.replace(",", " ").lower().split()
        if not requested or current in requested:
            keep = True
    return keep
|
|
|
|
|
|
def monitor():
    """Watch the tracked remote ref and terminate this job once it changes.

    The ref, remote and expected commit come from git_branch_info_to_track().
    If the commit messages checked by should_keep_alive() request a
    keep-alive for this CI, nothing is monitored and None is returned.

    Otherwise the remote ref is polled following yield_poll_schedule() until
    one of two things happens:

    * the ref is deleted on the remote (ls-remote exits with status 2), or
    * the ref points to a commit that is not a spurious re-creation of the
      expected one (see detect_spurious_commit).

    The result of terminate_my_process_group() is then returned.

    Git errors while polling are logged and retried on the next poll rather
    than ending the watch. This covers any other ls-remote failure and any
    failure while checking for a spurious commit.
    """
    (ref, remote, expected_sha) = git_branch_info_to_track()
    expected_line = "{}\t{}".format(expected_sha, ref)

    # gitrevisions: HEAD^- is shorthand for HEAD^1..HEAD, i.e. the commits HEAD
    # adds on top of its first parent.
    if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")):
        logger.info(
            "Not monitoring %s on %s due to keep-alive on: %s",
            ref,
            remote,
            expected_line,
        )
        return

    logger.info(
        "Monitoring %s (%s) for changes in %s: %s",
        remote,
        get_remote_url(remote),
        ref,
        expected_line,
    )

    for to_wait in yield_poll_schedule():
        time.sleep(to_wait)
        status = 0
        line = None
        try:
            # Query the commit on the remote ref (without fetching the commit).
            line = git("ls-remote", "--exit-code", remote, ref)
        except subprocess.CalledProcessError as ex:
            status = ex.returncode

        if status == 2:
            # With --exit-code, ls-remote exits with 2 when no ref matched.
            logger.info(
                "Terminating job as %s has been deleted on %s: %s",
                ref,
                remote,
                expected_line,
            )
            break
        elif status != 0:
            logger.error(
                "Error %d: unable to check %s on %s: %s",
                status,
                ref,
                remote,
                expected_line,
            )
            continue

        prev = expected_line
        try:
            # This fetches the new commit and reads both commits' metadata.
            # Those git calls can fail transiently just like ls-remote above,
            # so retry on the next poll instead of crashing the watchdog.
            expected_line = detect_spurious_commit(line, expected_line, remote)
        except subprocess.CalledProcessError as ex:
            logger.error(
                "Error %d: unable to compare %s with %s on %s",
                ex.returncode,
                line,
                expected_line,
                remote,
            )
            continue

        if expected_line != line:
            logger.info(
                "Terminating job as %s has been updated on %s\n"
                "    from:\t%s\n"
                "    to:  \t%s",
                ref,
                remote,
                expected_line,
                line,
            )
            time.sleep(1)  # wait for CI to flush output
            break
        if expected_line != prev:
            logger.info(
                "%s appeared to spuriously change on %s\n"
                "    from:\t%s\n"
                "    to:  \t%s",
                ref,
                remote,
                prev,
                expected_line,
            )

    return terminate_my_process_group()
|
|
|
|
|
|
def main(program, *args):
    """Entry point: monitor the branch unless this repo's build is skipped.

    Args:
        program: argv[0]; not used.
        *args: Command-line arguments. ``--skip_repo SLUG`` may be repeated to
            name repositories whose builds are not monitored, unless the build
            event is "pull_request".

    Returns:
        The result of monitor(), or 0 when monitoring is skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip_repo", action="append", help="Repo to exclude.")
    options = parser.parse_args(args)
    excluded = options.skip_repo or []
    slug = get_repo_slug()
    event = get_ci_event_name()
    if slug in excluded and event != "pull_request":
        logger.info("Skipping monitoring %s %s build", slug, event)
        return 0
    return monitor()
|
|
|
|
|
|
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.DEBUG,
        stream=sys.stderr,
        format="%(levelname)s: %(message)s",
    )
    try:
        # main() can return None (monitor() returns None on keep-alive);
        # treat that as a successful exit status.
        exit_status = main(*sys.argv) or 0
        raise SystemExit(exit_status)
    except KeyboardInterrupt:
        pass
|