From 75b2fc4313eb3f9c71640dc411cbf378b1bc0d6f Mon Sep 17 00:00:00 2001 From: mehrdadn Date: Thu, 23 Jul 2020 16:07:00 -0700 Subject: [PATCH] Auto-cancel build when a new commit is pushed (#8043) Co-authored-by: Mehrdad --- .github/workflows/main.yml | 1 + .travis.yml | 1 + ci/remote-watch.py | 261 +++++++++++++++++++++++++++++++++++++ 3 files changed, 263 insertions(+) create mode 100755 ci/remote-watch.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e794b7c96..c5f6f8878 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -91,6 +91,7 @@ jobs: RAY_DEFAULT_BUILD: 1 WINDOWS_WHEELS: 1 run: | + python -u ci/remote-watch.py --skip_repo=ray-project/ray & . ./ci/travis/ci.sh init . ./ci/travis/ci.sh build . ./ci/travis/ci.sh upload_wheels || true diff --git a/.travis.yml b/.travis.yml index 5646399a7..2856d6839 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,6 +16,7 @@ before_install: if [ false != "${TRAVIS_PULL_REQUEST-}" ]; then to_fetch+=("+refs/pull/${TRAVIS_PULL_REQUEST}/merge:"); fi git fetch -q -- origin "${to_fetch[@]}" git checkout -qf "${TRAVIS_COMMIT}" -- + python -u ci/remote-watch.py --skip_repo=ray-project/ray & matrix: include: diff --git a/ci/remote-watch.py b/ci/remote-watch.py new file mode 100755 index 000000000..c1c20d8bf --- /dev/null +++ b/ci/remote-watch.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python + +""" +This command must be run in a git repository. + +It watches the remote branch for changes, killing the given PID when it detects +that the remote branch no longer points to the local commit. + +If the commit message contains a line saying "CI_KEEP_ALIVE", then killing +will not occur until the branch is deleted from the remote. + +If no PID is given, then the entire process group of this process is killed. +""" + +# Prefer to keep this file Python 2-compatible so that it can easily run early +# in the CI process on any system. + +import argparse +import errno +import logging +import os +import signal +import subprocess +import sys +import time + +logger = logging.getLogger(__name__) + +GITHUB = "GitHub" +TRAVIS = "Travis" + + +def git(*args): + cmdline = ["git"] + list(args) + return subprocess.check_output(cmdline).decode("utf-8").rstrip() + + +def get_current_ci(): + if "GITHUB_WORKFLOW" in os.environ: + return GITHUB + elif "TRAVIS" in os.environ: + return TRAVIS + return None + + +def get_ci_event_name(): + ci = get_current_ci() + if ci == GITHUB: + return os.environ["GITHUB_EVENT_NAME"] + elif ci == TRAVIS: + return os.environ["TRAVIS_EVENT_TYPE"] + return None + + +def get_repo_slug(): + ci = get_current_ci() + if ci == GITHUB: + return os.environ["GITHUB_REPOSITORY"] + elif ci == TRAVIS: + return os.environ["TRAVIS_REPO_SLUG"] + return None + + +def get_remote_url(remote): + return git("ls-remote", "--get-url", remote) + + +def git_remote_branch_info(): + """Obtains the remote branch name, remote name, and commit hash that + correspond to the local HEAD. + + Returns: + ("refs/heads/mybranch", "origin", "1A2B3C4...") + """ + ref = None + remote = None + expected_sha = git("rev-parse", "--verify", "HEAD") + + try: + # Try to get the local branch ref. (e.g. refs/heads/mybranch) + head = git("symbolic-ref", "-q", "HEAD") + # Try to get the remotely tracked ref, if any. (e.g. origin/mybranch) + ref = git("for-each-ref", "--format=%(upstream:short)", head) + except subprocess.CalledProcessError: + pass + + if ref: + (remote, ref) = ref.split("/", 1) + ref = "refs/heads/" + ref + else: + remote = git("remote", "show", "-n").splitlines()[0] + ci = get_current_ci() + if ci == TRAVIS: + travis_pr = os.getenv("TRAVIS_PULL_REQUEST") + if travis_pr is not None: + ref = "refs/pull/{}/merge".format(travis_pr) + else: + ref = "refs/heads/{}".format(os.environ["TRAVIS_BRANCH"]) + expected_sha = os.getenv("TRAVIS_COMMIT") + elif ci == GITHUB: + ref = os.getenv("GITHUB_REF") + + if not remote or not ref: + raise ValueError("Invalid remote {!r} or ref {!r}".format(remote, ref)) + + return (ref, remote, expected_sha) + + +def get_commit_metadata(hash): + """Get the commit info (content hash, parents, message, etc.) as a list of + key-value pairs. + """ + info = git("cat-file", "-p", hash) + parts = info.split("\n\n", 1) # Split off the commit message + records = parts[0] + message = parts[1] if len(parts) > 1 else None + result = [] + records = records.replace("\n ", "\0 ") # Join multiple lines into one + for record in records.splitlines(True): + (key, value) = record.split(" ", 1) + value = value.replace("\0 ", "\n ") # Re-split lines + result.append((key, value)) + result.append(("message", message)) + return result + + +def terminate_my_process_group(): + result = 0 + timeout = 15 + try: + logger.warning("Attempting kill...") + if sys.platform == "win32": + os.kill(0, signal.CTRL_BREAK_EVENT) # This might get ignored. + time.sleep(timeout) + os.kill(os.getppid(), signal.SIGTERM) + else: + # This SIGTERM seems to be needed to prevent jobs from lingering. + os.kill(os.getppid(), signal.SIGTERM) + time.sleep(timeout) + os.kill(0, signal.SIGKILL) + except OSError as ex: + if ex.errno not in (errno.EBADF, errno.ESRCH): + raise + logger.error("Kill error %s: %s", ex.errno, ex.strerror) + result = ex.errno + return result + + +def yield_poll_schedule(): + schedule = [0, 5, 5, 10, 20, 40, 40] + [60] * 5 + [120] * 10 + [300, 600] + for item in schedule: + yield item + while True: + yield schedule[-1] + + +def detect_spurious_commit(actual, expected, remote): + """GitHub sometimes spuriously generates commits multiple times with + different dates but identical contents. See here: + https://github.com/travis-ci/travis-ci/issues/7459#issuecomment-601346831 + We need to detect whether this might be the case, and we do so by + comparing the commits' contents ("tree" objects) and their parents. + + Args: + actual: The commit line on the remote from git ls-remote, e.g.: + da39a3ee5e6b4b0d3255bfef95601890afd80709 refs/heads/master + expected: The commit line initially expected. + + Returns: + The new (actual) commit line, if it is suspected to be spurious. + Otherwise, the previously expected commit line. + """ + actual_hash = actual.split(None, 1)[0] + expected_hash = expected.split(None, 1)[0] + relevant = ["tree", "parent"] # relevant parts of a commit for comparison + if actual != expected: + git("fetch", "-q", remote, actual_hash) + actual_info = get_commit_metadata(actual_hash) + expected_info = get_commit_metadata(expected_hash) + a = [pair for pair in actual_info if pair[0] in relevant] + b = [pair for pair in expected_info if pair[0] in relevant] + if a == b: + expected = actual + return expected + + +def should_keep_alive(commit_msg): + result = False + ci = get_current_ci() or "" + for line in commit_msg.splitlines(): + parts = line.strip("# ").split(":", 1) + (key, val) = parts if len(parts) > 1 else (parts[0], "") + if key == "CI_KEEP_ALIVE": + ci_names = val.replace(",", " ").lower().split() if val else [] + if len(ci_names) == 0 or ci.lower() in ci_names: + result = True + return result + + +def monitor(): + (ref, remote, expected_sha) = git_remote_branch_info() + expected_line = "{}\t{}".format(expected_sha, ref) + + if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")): + logger.info("Not monitoring %s on %s due to keep-alive on: %s", + ref, remote, expected_line) + return + + logger.info("Monitoring %s (%s) for changes in %s: %s", remote, + get_remote_url(remote), ref, expected_line) + + for to_wait in yield_poll_schedule(): + time.sleep(to_wait) + status = 0 + line = None + try: + # Query the commit on the remote ref (without fetching the commit). + line = git("ls-remote", "--exit-code", remote, ref) + except subprocess.CalledProcessError as ex: + status = ex.returncode + + if status == 2: + logger.info("Terminating job as %s has been deleted on %s: %s", + ref, remote, expected_line) + break + elif status != 0: + logger.error("Error %d: unable to check %s on %s: %s", + status, ref, remote, expected_line) + else: + expected_line = detect_spurious_commit(line, expected_line, remote) + if expected_line != line: + logger.info("Terminating job as %s has been updated on %s\n" + " from:\t%s\n" + " to: \t%s", ref, remote, expected_line, line) + time.sleep(1) # wait for CI to flush output + break + + return terminate_my_process_group() + + +def main(program, *args): + p = argparse.ArgumentParser() + p.add_argument("--skip_repo", action="append", help="Repo to exclude.") + parsed_args = p.parse_args(args) + skipped_repos = parsed_args.skip_repo or [] + repo_slug = get_repo_slug() + event_name = get_ci_event_name() + if repo_slug not in skipped_repos or event_name == "pull_request": + result = monitor() + else: + logger.info("Skipping monitoring %s %s build", repo_slug, event_name) + return result + + +if __name__ == "__main__": + logging.basicConfig(format="%(levelname)s: %(message)s", + stream=sys.stderr, level=logging.DEBUG) + try: + raise SystemExit(main(*sys.argv) or 0) + except KeyboardInterrupt: + pass