ray/scripts/pytest_checker.py

126 lines
4.1 KiB
Python

import json
import re
import sys
from pathlib import Path
def check_file(file_contents: str) -> bool:
    """Return whether ``file_contents`` contains a top-level main guard.

    A line must start with ``if __name__ == "__main__":``. The single-quoted
    spelling ``if __name__ == '__main__':`` is accepted too, since it is
    equally valid Python; the two quotes around ``__main__`` must match.

    Args:
        file_contents: Full text of a Python source file.

    Returns:
        True if the guard appears at the start of some line, False otherwise.
    """
    # re.M makes ``^`` match at every line start; the backreference ``\1``
    # requires the closing quote to be the same character as the opening one.
    return bool(
        re.search(r"^if __name__ == (['\"])__main__\1:", file_contents, re.M)
    )
def parse_json(data: str) -> dict:
    """Decode the JSON document in ``data`` and return the resulting object."""
    decoded = json.loads(data)
    return decoded
def treat_path(path: str) -> Path:
    """Convert a Bazel label into a filesystem path.

    For example, ``//python/ray/tests:aws/conftest.py`` becomes
    ``python/ray/tests/aws/conftest.py``. The leading ``//`` is dropped, and
    ``:`` (the package/target separator) becomes a path separator.

    Args:
        path: A Bazel label string.

    Returns:
        The corresponding relative ``Path``.
    """
    # Only strip a real "//" prefix. Unconditionally slicing off two
    # characters would corrupt labels that lack it.
    if path.startswith("//"):
        path = path[2:]
    return Path(path.replace(":", "/"))
def get_paths_from_parsed_data(parsed_data: dict) -> list:
    """Return ``(rule_name, path)`` pairs for every rule in the query output.

    Args:
        parsed_data: ``bazel query --output xml`` output converted to JSON
            (see the example below).

    Returns:
        A list of ``(bazel_rule_name, Path)`` tuples. Each path points to the
        test's entry-point source file.

    Raises:
        ValueError: if the entry-point file of a rule cannot be determined.
    """
    # Example JSON input:
    # "rule": [
    #     {
    #         "@class": "py_test",
    #         "@location": "/home/ubuntu/ray/python/ray/tests/BUILD:345:8",
    #         "@name": "//python/ray/tests:test_tracing",
    #         "string": [
    #             {
    #                 "@name": "name",
    #                 "@value": "test_tracing"
    #             },
    #         ],
    #         "list": [
    #             {
    #                 "@name": "srcs",
    #                 "label": [
    #                     {
    #                         "@value": "//python/ray/tests:aws/conftest.py"
    #                     },
    #                     {
    #                         "@value": "//python/ray/tests:conftest.py"
    #                     },
    #                     {
    #                         "@value": "//python/ray/tests:test_tracing.py"
    #                     }
    #                 ]
    #             }
    #         ],
    #         ... other fields ...
    #         "label": {
    #             "@name": "main",
    #             "@value": "//python/ray/tests:test_runtime_env_working_dir_remote_uri.py"
    #         },
    #         ... other fields ...
    #     }
    # ]
    #
    # We want to get the location of the actual test file.
    # This can be, in order of priority:
    # 1. Specified as the "main" label
    # 2. Specified as the ONLY "srcs" label
    # 3. Specified as the "srcs" label matching the "name" of the test
    # https://docs.bazel.build/versions/main/be/python.html#py_test
    #
    # As the example shows, an element may appear either as a single dict or
    # as a list of dicts. "rule", "label", "list" and "string" are therefore
    # normalized with _as_list before use.
    return [
        (rule["@name"], treat_path(_get_main_label(rule)))
        for rule in _as_list(parsed_data["query"]["rule"])
    ]


def _as_list(value) -> list:
    """Return ``value`` as a list: None -> [], dict -> [dict], list unchanged."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def _label_basename(label: str) -> str:
    """Return the file name of a label, e.g. ``//a/b:c/d.py`` -> ``d.py``."""
    target = label.rpartition(":")[2]
    return target.rpartition("/")[2]


def _get_main_label(rule: dict) -> str:
    """Return the Bazel label of the entry-point source file of ``rule``."""
    # 1. An explicit "main" label attribute wins.
    for label in _as_list(rule.get("label")):
        if label.get("@name") == "main":
            return label["@value"]

    list_args = {e["@name"]: e for e in _as_list(rule.get("list"))}
    srcs = [src["@value"] for src in _as_list(list_args["srcs"].get("label"))]

    # 2. A single source file is the entry point.
    if len(srcs) == 1:
        return srcs[0]

    # 3. The source file named after the test. If there is no "name" string
    # attribute, use the target part of the rule label.
    test_name = next(
        (
            x["@value"]
            for x in _as_list(rule.get("string"))
            if x.get("@name") == "name"
        ),
        rule["@name"].rpartition(":")[2],
    )
    expected_file = f"{test_name}.py"
    # Prefer an exact file-name match. A substring test alone would let e.g.
    # "test_foo_utils.py" shadow "test_foo.py" when it is listed first.
    for src in srcs:
        if _label_basename(src) == expected_file:
            return src
    # Fall back to the looser substring match for unconventional layouts.
    for src in srcs:
        if test_name in src:
            return src
    raise ValueError(
        f"Could not determine the main source file of test '{rule['@name']}'."
    )
def main(data: str):
    """Check that every queried py_test entry-point file has a main guard.

    Args:
        data: JSON text describing the py_test rules, as produced by the
            bazel query | xq pipeline shown at the bottom of this script.

    Raises:
        RuntimeError: if any test file lacks the snippet or does not exist.
    """
    print("Checking files for the pytest snippet...")
    parsed_data = parse_json(data)
    paths = get_paths_from_parsed_data(parsed_data)
    # Human-readable entries, one per offending file.
    bad_paths = []
    for name, path in paths:
        # Special case for myst doc checker
        if "test_myst_doc" in str(path):
            continue
        print(f"Checking test '{name}' | file '{path}'...")
        try:
            # Python source is UTF-8 by default (PEP 3120). Don't depend on
            # the locale's preferred encoding.
            with open(path, "r", encoding="utf-8") as f:
                contents = f.read()
        except FileNotFoundError:
            print(f"File '{path}' is missing.")
            bad_paths.append(f"{path} (path is missing!)")
            continue
        if not check_file(contents):
            print(f"File '{path}' is missing the pytest snippet.")
            bad_paths.append(str(path))
    if bad_paths:
        formatted_bad_paths = "\n".join(bad_paths)
        raise RuntimeError(
            'Found py_test files without `if __name__ == "__main__":` snippet:'
            f"\n{formatted_bad_paths}\n"
            "If this is intentional, please add a `no_main` tag to bazel BUILD "
            "entry for those files."
        )
if __name__ == "__main__":
    # Reads the JSON description of the py_test rules from stdin.
    # Typical invocation from the workspace root:
    #   bazel query 'kind(py_test.*, tests(python/...) intersect
    #   attr(tags, "\bteam:ml\b", python/...) except attr(tags, "\bno_main\b",
    #   python/...))' --output xml | xq | python scripts/pytest_checker.py
    main(sys.stdin.read())