ray/scripts/pytest_checker.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

127 lines
4.1 KiB
Python
Raw Normal View History

import json
import re
import sys
from pathlib import Path
def check_file(file_contents: str) -> bool:
    """Report whether the text contains a top-level ``__main__`` guard.

    Args:
        file_contents: Full text of a Python source file.

    Returns:
        True if some line starts, at column 0, with
        ``if __name__ == "__main__":``; False otherwise.
    """
    # MULTILINE makes ``^`` anchor at the start of every line, not just the
    # start of the whole string, so the guard may appear anywhere in the file.
    main_guard = re.compile(r"^if __name__ == \"__main__\":", re.MULTILINE)
    return main_guard.search(file_contents) is not None
def parse_json(data: str) -> dict:
    """Decode a JSON document given as text.

    Args:
        data: The JSON text to decode.

    Returns:
        The decoded value, exactly as produced by ``json.loads``.
    """
    decoded = json.loads(data)
    return decoded
def treat_path(path: str) -> Path:
    """Convert a Bazel label into a workspace-relative filesystem path.

    ``//python/ray/tests:aws/conftest.py`` becomes
    ``python/ray/tests/aws/conftest.py``.

    Args:
        path: A Bazel label, normally of the form ``//package:target``.

    Returns:
        The relative path obtained by dropping the leading ``//`` and
        turning the package/target separator ``:`` into ``/``.
    """
    # Only drop the workspace-root marker when it is really there, instead of
    # blindly discarding the first two characters of whatever was passed in.
    if path.startswith("//"):
        path = path[2:]
    # A root-package label such as ``//:test_x.py`` would otherwise turn into
    # ``/test_x.py``, i.e. an absolute path at the filesystem root. Strip any
    # leading slash so the result stays relative to the workspace.
    relative = path.replace(":", "/").lstrip("/")
    return Path(relative)
def _as_list(value) -> list:
    """Normalize a converted-XML value that may be absent, single, or a list."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def get_paths_from_parsed_data(parsed_data: dict) -> list:
    """Collect the main test file of every ``py_test`` rule in a query result.

    Args:
        parsed_data: The decoded JSON form of ``bazel query --output xml``
            (see the example below).

    Returns:
        A list of ``(rule name, path)`` tuples, one per rule, where the path
        is the result of ``treat_path`` on the selected label.

    Raises:
        KeyError: If a rule lacks the attributes needed to locate its file
            (for example, it has no ``srcs`` list).
        ValueError: If a rule has several ``srcs`` and none matches its name.
    """
    # Example JSON input:
    # "rule": [
    #   {
    #     "@class": "py_test",
    #     "@location": "/home/ubuntu/ray/python/ray/tests/BUILD:345:8",
    #     "@name": "//python/ray/tests:test_tracing",
    #     "string": [
    #       {
    #         "@name": "name",
    #         "@value": "test_tracing"
    #       },
    #     ],
    #     "list": [
    #       {
    #         "@name": "srcs",
    #         "label": [
    #           {
    #             "@value": "//python/ray/tests:aws/conftest.py"
    #           },
    #           {
    #             "@value": "//python/ray/tests:conftest.py"
    #           },
    #           {
    #             "@value": "//python/ray/tests:test_tracing.py"
    #           }
    #         ]
    #       }
    #     ],
    #     ... other fields ...
    #     "label": {
    #       "@name": "main",
    #       "@value": "//python/ray/tests:test_runtime_env_working_dir_remote_uri.py"
    #     },
    #     ... other fields ...
    #   }
    # ]
    #
    # We want to get the location of the actual test file.
    # This can be, in order of priority:
    # 1. Specified as the "main" label
    # 2. Specified as the ONLY "srcs" label
    # 3. Specified as the "srcs" label matching the "name" of the test
    # https://docs.bazel.build/versions/main/be/python.html#py_test
    #
    # NOTE: the XML-to-JSON conversion yields a bare object for a single child
    # element and a list for repeated ones, so every repeated field below is
    # normalized with _as_list() before it is iterated.
    query = parsed_data["query"] or {}
    paths = []
    # A query that matched nothing has no "rule" entry at all; a query that
    # matched exactly one rule has a single object instead of a list.
    for rule in _as_list(query.get("rule")):
        name = rule["@name"]

        # 1. An explicit "main" label wins.
        main_label = next(
            (
                entry["@value"]
                for entry in _as_list(rule.get("label"))
                if entry.get("@name") == "main"
            ),
            None,
        )
        if main_label is not None:
            paths.append((name, treat_path(main_label)))
            continue

        list_args = {e["@name"]: e for e in _as_list(rule["list"])}
        srcs = [x["@value"] for x in _as_list(list_args["srcs"]["label"])]

        # 2. A single source file is necessarily the test file.
        if len(srcs) == 1:
            paths.append((name, treat_path(srcs[0])))
            continue

        # 3. Otherwise pick the source named after the test. Fall back to the
        # target part of the rule label if the "name" attribute is absent.
        string_name = next(
            (
                x["@value"]
                for x in _as_list(rule.get("string"))
                if x["@name"] == "name"
            ),
            name.rpartition(":")[2],
        )
        expected_file = f"{string_name}.py"
        # Prefer the file called exactly "<name>.py"; a plain substring test
        # would let "test_foo" select "test_foo_utils.py" if it came first.
        main_path = next(
            (
                value
                for value in srcs
                if value.replace(":", "/").rsplit("/", 1)[-1] == expected_file
            ),
            None,
        )
        if main_path is None:
            # Keep the historical substring match as a fallback.
            main_path = next(
                (value for value in srcs if string_name in value), None
            )
        if main_path is None:
            raise ValueError(
                f"Could not determine the main source file for '{name}': "
                f"no 'srcs' entry matches '{string_name}'."
            )
        paths.append((name, treat_path(main_path)))
    return paths
def main(data: str) -> None:
    """Verify that every queried test file contains the ``__main__`` snippet.

    Args:
        data: JSON text describing the ``py_test`` rules to check; it is
            decoded with ``parse_json`` and resolved to file paths with
            ``get_paths_from_parsed_data``.

    Raises:
        RuntimeError: If at least one test file is missing or lacks the
            ``if __name__ == "__main__":`` snippet. The message lists every
            offending path, one per line.
    """
    print("Checking files for the pytest snippet...")
    parsed_data = parse_json(data)
    paths = get_paths_from_parsed_data(parsed_data)
    # Collected as ready-to-print strings so that both failure kinds are
    # reported uniformly in the final error message.
    bad_paths = []
    for name, path in paths:
        # Special case for myst doc checker
        if "test_myst_doc" in str(path):
            continue
        print(f"Checking test '{name}' | file '{path}'...")
        try:
            # Explicit encoding so the result does not depend on the locale.
            with open(path, "r", encoding="utf-8") as f:
                contents = f.read()
        except FileNotFoundError:
            print(f"File '{path}' is missing.")
            bad_paths.append(f"{path} (path is missing!)")
            continue
        if not check_file(contents):
            print(f"File '{path}' is missing the pytest snippet.")
            bad_paths.append(str(path))
    if bad_paths:
        formatted_bad_paths = "\n".join(bad_paths)
        raise RuntimeError(
            'Found py_test files without `if __name__ == "__main__":` snippet:'
            f"\n{formatted_bad_paths}\n"
            "If this is intentional, please add a `no_main` tag to bazel BUILD "
            "entry for those files."
        )
if __name__ == "__main__":
    # Script entry point: the JSON description of the tests is read from
    # standard input.
    #
    # Example invocation, run from the workspace root:
    #   bazel query 'kind(py_test.*, tests(python/...) intersect
    #   attr(tags, "\bteam:ml\b", python/...) except attr(tags, "\bno_main\b",
    #   python/...))' --output xml | xq | python scripts/pytest_checker.py
    main(sys.stdin.read())