ray/ci/travis/bazel-sharding.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

114 lines
3.6 KiB
Python

#!/usr/bin/env python3
#
# Copyright 2021 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# FROM https://github.com/philwo/bazel-utils/blob/main/sharding/sharding.py
import argparse
import os
import subprocess
import sys
def partition_targets(targets):
    """Split target patterns into inclusions and exclusions.

    A pattern that begins with "-" is an exclusion; it is stored with that
    leading dash removed. Every other pattern is an inclusion and is stored
    unchanged.

    Returns:
        A ``(included, excluded)`` pair of lists, each in input order.
    """
    keep, drop = [], []
    for label in targets:
        if label.startswith("-"):
            bucket, name = drop, label[1:]
        else:
            bucket, name = keep, label
        bucket.append(name)
    return keep, drop
def quote_targets(targets):
    """Return the targets single-quoted and joined by spaces.

    An empty (or otherwise falsy) ``targets`` yields the empty string.
    """
    if not targets:
        return ""
    return " ".join(map("'{}'".format, targets))
def get_target_expansion_query(targets, tests_only, exclude_manual):
    """Build the query expression that expands the given target patterns.

    Args:
        targets: Target patterns; entries prefixed with "-" are exclusions
            (see ``partition_targets``).
        tests_only: When true, both the included and the excluded set are
            wrapped in ``tests(...)``.
        exclude_manual: When true, ``tests(...)`` of the included patterns
            carrying the "manual" tag are subtracted from the result.

    Returns:
        The query expression as a string.
    """
    included, excluded = partition_targets(targets)
    included_expr = quote_targets(included)
    excluded_expr = quote_targets(excluded)

    def as_set(quoted):
        # Wrap the quoted patterns in set(...), and in tests(...) on request.
        expr = "set({})".format(quoted)
        return "tests({})".format(expr) if tests_only else expr

    query = as_set(included_expr)

    if excluded_expr:
        query = "{} except {}".format(query, as_set(excluded_expr))

    if exclude_manual:
        # The manual filter is always applied to the raw included patterns.
        query = '{} except tests(attr("tags", "manual", set({})))'.format(
            query, included_expr
        )

    return query
def run_bazel_query(query, debug):
    """Run ``bazel query`` and return the lines it prints.

    Args:
        query: Query expression, passed to ``bazel query`` as one argument.
        debug: When true, echo the command to stderr before running it.

    Returns:
        The lines of bazel's stdout (leading/trailing whitespace of the
        whole output is stripped first); an empty list if nothing is printed.

    Raises:
        subprocess.CalledProcessError: If bazel exits with a non-zero status.
    """
    args = ["bazel", "query", query]
    if debug:
        print("$ {}".format(" ".join(args)), file=sys.stderr)
        sys.stderr.flush()

    # Run exactly the command that was echoed above instead of rebuilding
    # the argument list a second time.
    p = subprocess.run(
        args,
        check=True,
        stdout=subprocess.PIPE,
        errors="replace",
        universal_newlines=True,
    )
    output = p.stdout.strip()
    return output.splitlines() if output else []
def get_targets_for_shard(targets, index, count):
    """Return the targets assigned to shard ``index`` out of ``count``.

    The targets are sorted, then every ``count``-th one starting at position
    ``index`` is selected.
    """
    # Deliberately simple sharding; a smarter scheme could, for example,
    # weigh targets by test size.
    ordered = list(targets)
    ordered.sort()
    return ordered[index::count]
def main(argv=None):
    """Expand the requested targets and print the ones for this shard.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, argparse reads ``sys.argv[1:]``.

    Returns:
        0 on success. Invalid ``--index`` / ``--count`` values end the
        process through ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="Expand and shard Bazel targets.")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--tests_only", action="store_true")
    parser.add_argument("--exclude_manual", action="store_true")
    # Shard indices are zero-based (see the range check below), so the
    # fallback has to be 0: with a fallback of 1 and a count of 1, a plain
    # invocation without the environment variables could never pass the check.
    parser.add_argument(
        "--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 0)
    )
    parser.add_argument(
        "--count", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1)
    )
    parser.add_argument("targets", nargs="+")

    # Arguments argparse does not recognise (such as "-"-prefixed patterns)
    # are appended to the positional targets rather than rejected.
    args, extra_args = parser.parse_known_args(argv)
    args.targets = list(args.targets) + list(extra_args)

    if args.count < 1:
        parser.error("--count must be at least 1")
    # Reject negative indices as well: they would otherwise slice from the
    # end of the target list and silently select the wrong targets.
    if not 0 <= args.index < args.count:
        parser.error("--index must be between 0 and {}".format(args.count - 1))

    query = get_target_expansion_query(
        args.targets, args.tests_only, args.exclude_manual
    )
    expanded_targets = run_bazel_query(query, args.debug)
    my_targets = get_targets_for_shard(expanded_targets, args.index, args.count)
    print(" ".join(my_targets))

    return 0
# Script entry point: exit with main()'s return value as the process status.
if __name__ == "__main__":
    raise SystemExit(main())