ray/scripts/pytest_checker.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

60 lines
1.8 KiB
Python
Raw Normal View History

import json
import re
import sys
from pathlib import Path
def check_file(file_contents: str) -> bool:
    """Return True if ``file_contents`` contains a top-level ``__main__`` guard.

    The guard must start at the very beginning of a line (an indented guard
    does not count). Both the double-quoted and the single-quoted spelling of
    ``__main__`` are accepted, since they are equivalent Python.

    Args:
        file_contents: Full text of a Python source file.

    Returns:
        True if some line starts with ``if __name__ == "__main__":`` (or the
        single-quoted equivalent), False otherwise.
    """
    # re.M makes ``^`` match at the start of every line, not only at the start
    # of the string. The backreference forces the closing quote to be the same
    # character as the opening one.
    return bool(
        re.search(r"""^if __name__ == (["'])__main__\1:""", file_contents, re.M)
    )
def parse_json(data: str) -> dict:
    """Decode the JSON document in ``data``.

    Args:
        data: JSON text (in this script, whatever was read from stdin).

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If ``data`` is empty/whitespace-only or is not
            valid JSON.
    """
    if not data.strip():
        # An empty input usually means an earlier stage of the shell pipeline
        # failed; say so instead of surfacing the generic "Expecting value"
        # message. The exception type is the one json.loads would raise.
        raise json.JSONDecodeError(
            "No input received; expected a JSON document", data, 0
        )
    return json.loads(data)
def treat_path(path: str) -> Path:
    """Convert a Bazel-style label into a relative filesystem path.

    ``//pkg/dir:file.py`` becomes ``pkg/dir/file.py``: a leading ``//`` is
    dropped and every ``:`` is turned into a path separator.

    Args:
        path: Label string, e.g. ``//python/ray/tests:test_basic.py``.

    Returns:
        The corresponding relative ``Path``.
    """
    # Only strip the "//" prefix when it is really there, instead of blindly
    # discarding the first two characters.
    if path.startswith("//"):
        path = path[2:]
    # A root-package label ("//:file.py") would otherwise turn into
    # "/file.py", i.e. an absolute path at the filesystem root.
    return Path(path.replace(":", "/").lstrip("/"))
def get_paths_from_parsed_data(parsed_data: dict) -> list:
    """Collect the entry-point file of every rule in a parsed query result.

    For each rule, the file named by its ``main`` label attribute is used when
    present; otherwise the file is taken from the rule's ``srcs`` list.

    Args:
        parsed_data: Decoded JSON with the rules under ``["query"]["rule"]``.
            Each rule may carry ``label`` entries (``@name``/``@value``) and
            ``list`` entries (``@name`` plus nested ``label`` entries).

    Returns:
        A list with one path (as produced by ``treat_path``) per rule; empty
        if the query contains no rules.

    Raises:
        KeyError: If ``parsed_data`` has no ``query`` key, or a rule has
            neither a ``main`` label nor any ``srcs`` entry.
    """

    def as_list(node) -> list:
        # XML-to-JSON converters in the xmltodict style emit a lone child
        # element as a dict, repeated children as a list, and omit absent
        # ones entirely; normalise all three shapes to a list.
        if node is None:
            return []
        return node if isinstance(node, list) else [node]

    paths = []
    query = parsed_data["query"] or {}
    for rule in as_list(query.get("rule")):
        # A rule can carry several label attributes; pick the one named "main".
        main_label = next(
            (
                label
                for label in as_list(rule.get("label"))
                if label.get("@name") == "main"
            ),
            None,
        )
        if main_label is not None:
            paths.append(treat_path(main_label["@value"]))
            continue

        list_args = {e["@name"]: e for e in as_list(rule.get("list"))}
        srcs = as_list(list_args["srcs"].get("label"))
        if not srcs:
            raise KeyError(
                f"rule {rule.get('@name')!r} has neither a main label nor srcs"
            )
        # Without an explicit main, Bazel uses "<target name>.py" as the entry
        # point, so prefer that source when several are listed; fall back to
        # the first one. Assumes the rule's "@name" is its label
        # ("//pkg:target") -- TODO confirm against real xq output.
        default_main = rule.get("@name", "").rpartition(":")[2] + ".py"
        chosen = next(
            (s for s in srcs if s["@value"].rpartition(":")[2] == default_main),
            srcs[0],
        )
        paths.append(treat_path(chosen["@value"]))
    return paths
def main(data: str):
    """Check that every file referenced by ``data`` has a ``__main__`` guard.

    ``data`` is decoded with ``parse_json``, turned into file paths with
    ``get_paths_from_parsed_data``, and each file is read and tested with
    ``check_file``. Progress is printed to stdout.

    Args:
        data: JSON text describing the rules to check.

    Raises:
        RuntimeError: If at least one file lacks the
            ``if __name__ == "__main__":`` snippet; the message lists all
            offending files.
    """
    print("Checking files for the pytest snippet...")
    parsed_data = parse_json(data)
    paths = get_paths_from_parsed_data(parsed_data)
    bad_paths = []
    for path in paths:
        print(f"Checking file {path}...")
        # Read as UTF-8 explicitly (Python's default source encoding) rather
        # than relying on the machine's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            if not check_file(f.read()):
                print(f"File {path} is missing the pytest snippet.")
                bad_paths.append(path)
    # Report all offenders at once instead of failing on the first one.
    if bad_paths:
        raise RuntimeError(
            'Found py_test files without `if __name__ == "__main__":` snippet:'
            f" {[str(x) for x in bad_paths]}\n"
            "If this is intentional, please add a `no_main` tag to bazel BUILD "
            "entry for that file."
        )
if __name__ == "__main__":
    # Expects a JSON document on stdin: the XML output of `bazel query`
    # converted to JSON by `xq`.
    # Invocation from workspace root:
    # bazel query 'kind(py_test.*, tests(python/...) intersect
    # attr(tags, "\bteam:ml\b", python/...) except attr(tags, "\bno_main\b",
    # python/...))' --output xml | xq | python scripts/pytest_checker.py
    data = sys.stdin.read()
    main(data)