#!/usr/bin/env python
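"""
This command must be run in a git repository, inside a supported CI service
(GitHub Actions or Travis).

It watches the remote branch (or pull request head) for changes. When it
detects that the remote ref has been deleted, or no longer points to the
expected commit, it kills its parent process and its own process group.

If a commit message of the tested change contains a line saying
"CI_KEEP_ALIVE", then the remote is not monitored at all and no killing occurs.
The line may also name specific CIs, e.g. "CI_KEEP_ALIVE: github, travis", in
which case it applies only to those CIs.

Repositories passed via --skip_repo are not monitored, except for pull request
builds.
"""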

# Prefer to keep this file Python 2-compatible so that it can easily run early
# in the CI process on any system.

import argparse
import errno
import logging
import os
import signal
import subprocess
import sys
import time

# Module-wide logger; its records propagate to whatever configures the root
# logger (the __main__ block calls logging.basicConfig).
logger = logging.getLogger(__name__)

# Display names of the CI services this script knows how to detect.
GITHUB, TRAVIS = "GitHub", "Travis"


def git(*args):
    """Run ``git`` with the given arguments and return its standard output.

    The output is decoded as UTF-8 and stripped of trailing whitespace.
    Raises subprocess.CalledProcessError if git exits with a non-zero status.
    """
    command = ["git"]
    command.extend(args)
    raw_output = subprocess.check_output(command)
    text = raw_output.decode("utf-8")
    return text.rstrip()


def get_current_ci():
    """Identify the CI service this process is running under.

    Returns GITHUB if ``GITHUB_WORKFLOW`` is set in the environment, otherwise
    TRAVIS if ``TRAVIS`` is set, otherwise None. GitHub wins if both are set.
    """
    detection_order = (("GITHUB_WORKFLOW", GITHUB), ("TRAVIS", TRAVIS))
    for marker_var, ci_name in detection_order:
        if marker_var in os.environ:
            return ci_name
    return None


def get_ci_event_name():
    """Return the event type that triggered the current CI build.

    Reads ``GITHUB_EVENT_NAME`` on GitHub or ``TRAVIS_EVENT_TYPE`` on Travis.
    Returns None when no supported CI is detected. Raises KeyError if the CI
    is detected but its variable is unset.
    """
    event_var = {
        GITHUB: "GITHUB_EVENT_NAME",
        TRAVIS: "TRAVIS_EVENT_TYPE",
    }.get(get_current_ci())
    if event_var is None:
        return None
    return os.environ[event_var]


def get_repo_slug():
    """Return the repository slug reported by the current CI's environment.

    Reads ``GITHUB_REPOSITORY`` on GitHub or ``TRAVIS_REPO_SLUG`` on Travis.
    Returns None when no supported CI is detected. Raises KeyError if the CI
    is detected but its variable is unset.
    """
    ci = get_current_ci()
    if ci == GITHUB:
        slug_var = "GITHUB_REPOSITORY"
    elif ci == TRAVIS:
        slug_var = "TRAVIS_REPO_SLUG"
    else:
        return None
    return os.environ[slug_var]


def get_remote_url(remote):
    """Return the URL of *remote*, as printed by ``git ls-remote --get-url``."""
    url = git("ls-remote", "--get-url", remote)
    return url


def replace_suffix(base, old_suffix, new_suffix=""):
    """Swap a trailing *old_suffix* of *base* for *new_suffix*.

    If *base* does not end with *old_suffix*, it is returned unchanged.
    (An empty *old_suffix* always matches, so *new_suffix* is appended.)
    """
    if not base.endswith(old_suffix):
        return base
    keep = len(base) - len(old_suffix)
    return base[:keep] + new_suffix


def git_branch_info_to_track():
    """Obtains the remote branch name, remote name, and commit hash that
    should be tracked for changes.

    The remote is the first one listed by ``git remote show -n``. The ref and
    hash come from the current CI's environment:

    * GitHub: ``GITHUB_HEAD_SHA`` (falling back to ``GITHUB_SHA``) and
      ``GITHUB_REF``. A trailing ``/merge`` on the ref is rewritten to
      ``/head``.
    * Travis, pull request builds: ``TRAVIS_PULL_REQUEST_SHA`` and
      ``refs/pull/<TRAVIS_PULL_REQUEST>/head``.
    * Travis, other builds: ``TRAVIS_COMMIT`` and
      ``refs/heads/<TRAVIS_BRANCH>``.

    Returns:
        ("refs/heads/mybranch", "origin", "1A2B3C4...")

    Raises:
        ValueError: if no remote is configured, no supported CI is detected,
            or any of the three values is empty.
        KeyError: if a required CI environment variable is missing.
    """
    expected_sha = None
    ref = None
    remotes = git("remote", "show", "-n").splitlines()
    # Tolerate a repository without remotes so that the ValueError below
    # reports it, rather than an unexplained IndexError.
    remote = remotes[0] if remotes else None

    ci = get_current_ci()
    if ci == GITHUB:
        expected_sha = os.getenv("GITHUB_HEAD_SHA") or os.environ["GITHUB_SHA"]
        ref = replace_suffix(os.environ["GITHUB_REF"], "/merge", "/head")
    elif ci == TRAVIS:
        pr = os.getenv("TRAVIS_PULL_REQUEST", "false")
        if pr != "false":
            expected_sha = os.environ["TRAVIS_PULL_REQUEST_SHA"]
            ref = "refs/pull/{}/head".format(pr)
        else:
            expected_sha = os.environ["TRAVIS_COMMIT"]
            ref = "refs/heads/{}".format(os.environ["TRAVIS_BRANCH"])

    result = (ref, remote, expected_sha)

    if not all(result):
        msg = "Invalid remote {!r}, ref {!r}, or hash {!r} for CI {!r}"
        raise ValueError(msg.format(remote, ref, expected_sha, ci))

    return result


def get_commit_metadata(hash):
    """Get the commit info (content hash, parents, message, etc.) as a list of
    key-value pairs.

    Args:
        hash: A commit (or other object) name accepted by
            ``git cat-file -p``.

    Returns:
        A list of ``(key, value)`` tuples, one per header line in the order
        git prints them, followed by ``("message", message)``. ``message`` is
        None if the output has no blank-line separator. Values have no
        trailing newline. Multi-line header values keep their internal
        ``"\\n "`` continuation breaks.
    """
    info = git("cat-file", "-p", hash)
    headers, separator, message = info.partition("\n\n")  # Split off message
    if not separator:
        message = None
    result = []
    # Continuation lines of a header start with a single space; temporarily
    # fold them into their header so that each header occupies one line.
    folded = headers.replace("\n ", "\0 ")
    # Split on "\n" only: str.splitlines() also breaks on characters such as
    # "\x0b", "\x1c" or "\u2028", which may legitimately occur in names.
    for record in folded.split("\n"):
        if not record:
            continue
        (key, value) = record.split(" ", 1)
        value = value.replace("\0 ", "\n ")  # Restore continuation lines
        result.append((key, value))
    result.append(("message", message))
    return result


def terminate_my_process_group():
    """Attempt to kill this job: the parent process and this process group.

    On POSIX, SIGTERM is sent to the parent process. After ``timeout``
    seconds, SIGKILL is sent to pid 0, i.e. every process in this process's
    group, including this one. A successful run is therefore not expected to
    reach the ``return``.

    On Windows, CTRL_BREAK_EVENT is sent to pid 0 (it may be ignored). After
    ``timeout`` seconds, SIGTERM is sent to the parent process.

    Returns:
        0 if every kill call succeeded and control returned here. Otherwise,
        the errno of a tolerated failure (EBADF or ESRCH), which is also
        logged.

    Raises:
        OSError: for any kill failure other than EBADF or ESRCH.
    """
    result = 0
    timeout = 15  # seconds between the first and the second kill attempt
    try:
        logger.warning("Attempting kill...")
        if sys.platform == "win32":
            os.kill(0, signal.CTRL_BREAK_EVENT)  # This might get ignored.
            time.sleep(timeout)
            os.kill(os.getppid(), signal.SIGTERM)
        else:
            # This SIGTERM seems to be needed to prevent jobs from lingering.
            os.kill(os.getppid(), signal.SIGTERM)
            time.sleep(timeout)
            os.kill(0, signal.SIGKILL)
    except OSError as ex:
        # Tolerate "bad descriptor" / "no such process" (e.g. the target is
        # already gone); anything else is unexpected and propagates.
        if ex.errno not in (errno.EBADF, errno.ESRCH):
            raise
        logger.error("Kill error %s: %s", ex.errno, ex.strerror)
        result = ex.errno
    return result


def yield_poll_schedule():
    """Yield an endless sequence of sleep intervals (seconds) between polls.

    Polls immediately and then backs off: 5, 5, 10, 20, 40, 40, then 60 five
    times, then 120 ten times, and finally 300 forever.
    """
    intervals = [0, 5, 5, 10, 20, 40, 40] + [60] * 5 + [120] * 10 + [300]
    for interval in intervals[:-1]:
        yield interval
    steady_state = intervals[-1]
    while True:
        yield steady_state


def detect_spurious_commit(actual, expected, remote):
    """GitHub sometimes spuriously generates commits multiple times with
    different dates but identical contents. See here:
    https://github.com/travis-ci/travis-ci/issues/7459#issuecomment-601346831
    We need to detect whether this might be the case, and we do so by
    comparing the commits' contents ("tree" objects) and their parents.

    Args:
        actual: The commit line on the remote from git ls-remote (hash, then
            whitespace, then ref), e.g.:
            da39a3ee5e6b4b0d3255bfef95601890afd80709    refs/heads/master
        expected: The commit line initially expected.
        remote: Name of the remote from which to fetch the actual commit.

    Returns:
        The new (actual) commit line, if it is suspected to be spurious (same
        tree and parents as the expected commit). Otherwise, the previously
        expected commit line. This includes the case where git fails to fetch
        or read either commit; the change is then treated as genuine.
    """
    if actual == expected:
        return expected

    actual_hash = actual.split(None, 1)[0]
    expected_hash = expected.split(None, 1)[0]
    relevant = ("tree", "parent")  # relevant parts of a commit for comparison
    try:
        # ls-remote only reports the hash; fetch the object so it can be read.
        git("fetch", "-q", remote, actual_hash)
        actual_info = get_commit_metadata(actual_hash)
        expected_info = get_commit_metadata(expected_hash)
    except subprocess.CalledProcessError as ex:
        # Without both commits the change cannot be shown to be spurious, so
        # report it as genuine instead of letting the error escape.
        logger.warning("Could not compare %s with %s (git exit status %s)",
                       actual_hash, expected_hash, ex.returncode)
        return expected

    actual_parts = [pair for pair in actual_info if pair[0] in relevant]
    expected_parts = [pair for pair in expected_info if pair[0] in relevant]
    return actual if actual_parts == expected_parts else expected


def should_keep_alive(commit_msg):
    """Return True if *commit_msg* asks for this CI job to be kept alive.

    A line reading ``CI_KEEP_ALIVE`` (after stripping surrounding ``#`` and
    spaces) applies to every CI. ``CI_KEEP_ALIVE: name1, name2`` applies only
    when the current CI's name, compared case-insensitively, is listed.
    """
    current_ci = (get_current_ci() or "").lower()

    def requests_keep_alive(line):
        key, _, value = line.strip("# ").partition(":")
        if key != "CI_KEEP_ALIVE":
            return False
        targets = value.replace(",", " ").lower().split()
        return not targets or current_ci in targets

    return any(requests_keep_alive(line) for line in commit_msg.splitlines())


def monitor():
    """Poll the tracked remote ref and kill this CI job once it changes.

    If the tested commit message(s) request CI_KEEP_ALIVE for this CI, this
    returns None immediately without polling. Otherwise the remote ref is
    queried on the yield_poll_schedule() cadence. Once the ref is deleted, or
    points to a commit that detect_spurious_commit() does not consider a
    spurious duplicate, terminate_my_process_group() is called and its result
    returned. Other ls-remote failures are logged and retried.
    """
    (ref, remote, expected_sha) = git_branch_info_to_track()
    # Same "<sha>\t<ref>" layout as a `git ls-remote` output line, so the two
    # can be compared directly.
    expected_line = "{}\t{}".format(expected_sha, ref)

    # HEAD^- is shorthand for "HEAD ^HEAD^1": the commits reachable from HEAD
    # but not from its first parent.
    if should_keep_alive(git("show", "-s", "--format=%B", "HEAD^-")):
        logger.info("Not monitoring %s on %s due to keep-alive on: %s", ref,
                    remote, expected_line)
        return

    logger.info("Monitoring %s (%s) for changes in %s: %s", remote,
                get_remote_url(remote), ref, expected_line)

    for to_wait in yield_poll_schedule():
        time.sleep(to_wait)
        status = 0
        line = None
        try:
            # Query the commit on the remote ref (without fetching the commit).
            line = git("ls-remote", "--exit-code", remote, ref)
        except subprocess.CalledProcessError as ex:
            status = ex.returncode

        # With --exit-code, status 2 means no matching ref was found.
        if status == 2:
            logger.info("Terminating job as %s has been deleted on %s: %s",
                        ref, remote, expected_line)
            break
        elif status != 0:
            # Any other failure: log it and try again on the next poll.
            logger.error("Error %d: unable to check %s on %s: %s", status, ref,
                         remote, expected_line)
        else:
            prev = expected_line
            # Adopt the new line as "expected" if it is judged a spurious
            # duplicate of the old commit; otherwise expected_line is unchanged.
            expected_line = detect_spurious_commit(line, expected_line, remote)
            if expected_line != line:
                logger.info(
                    "Terminating job as %s has been updated on %s\n"
                    "    from:\t%s\n"
                    "    to:  \t%s", ref, remote, expected_line, line)
                time.sleep(1)  # wait for CI to flush output
                break
            if expected_line != prev:
                logger.info(
                    "%s appeared to spuriously change on %s\n"
                    "    from:\t%s\n"
                    "    to:  \t%s", ref, remote, prev, expected_line)

    return terminate_my_process_group()


def main(program, *args):
    """Parse command-line options and, unless skipped, monitor the remote ref.

    Args:
        program: argv[0]; accepted so that ``main(*sys.argv)`` works, but
            unused.
        *args: The remaining command-line arguments (``--skip_repo REPO`` may
            be repeated).

    Returns:
        The result of monitor() when monitoring runs, otherwise None.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--skip_repo", action="append", help="Repo to exclude.")
    parsed_args = p.parse_args(args)
    skipped_repos = parsed_args.skip_repo or []
    repo_slug = get_repo_slug()
    event_name = get_ci_event_name()
    # Initialize so that the skip branch returns cleanly; previously it hit
    # an UnboundLocalError on `return result`.
    result = None
    # Pull request builds are monitored even for skipped repositories.
    if repo_slug not in skipped_repos or event_name == "pull_request":
        result = monitor()
    else:
        logger.info("Skipping monitoring %s %s build", repo_slug, event_name)
    return result


if __name__ == "__main__":
    # Emit every record at DEBUG level or above to stderr, prefixed by level.
    logging.basicConfig(level=logging.DEBUG,
                        stream=sys.stderr,
                        format="%(levelname)s: %(message)s")
    try:
        exit_status = main(*sys.argv)
    except KeyboardInterrupt:
        pass  # Interrupted: fall off the end of the script quietly.
    else:
        raise SystemExit(exit_status or 0)