mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
114 lines
3.6 KiB
Python
114 lines
3.6 KiB
Python
#!/usr/bin/env python3
|
|
#
|
|
# Copyright 2021 Google Inc. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# FROM https://github.com/philwo/bazel-utils/blob/main/sharding/sharding.py
|
|
|
|
import argparse
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
|
|
|
|
def partition_targets(targets):
    """Split target patterns into inclusions and exclusions.

    A pattern that starts with "-" is an exclusion and is stored without
    that leading dash; any other pattern is an inclusion and is stored
    unchanged. The relative order of patterns is kept within each list.

    Args:
        targets: Iterable of target pattern strings.

    Returns:
        A tuple ``(included, excluded)`` of two lists.
    """
    # Keyed by "is this pattern an exclusion?".
    buckets = {False: [], True: []}
    for pattern in targets:
        is_exclusion = pattern.startswith("-")
        buckets[is_exclusion].append(pattern[1:] if is_exclusion else pattern)
    return buckets[False], buckets[True]
|
|
|
|
|
|
def quote_targets(targets):
    """Render targets as a space-separated string of single-quoted names.

    Quote characters that appear inside a target are not escaped.

    Args:
        targets: Collection of target strings; may be empty or ``None``.

    Returns:
        The quoted targets joined by single spaces, or ``""`` when
        ``targets`` is falsy.
    """
    if not targets:
        return ""
    return " ".join(map("'{}'".format, targets))
|
|
|
|
|
|
def get_target_expansion_query(targets, tests_only, exclude_manual):
    """Build the Bazel query expression that expands ``targets``.

    The expression is a chain of clauses joined by ``except``:

    1. ``set(<included>)`` -- wrapped in ``tests(...)`` when ``tests_only``.
    2. If any pattern was an exclusion (leading "-"), ``set(<excluded>)``,
       wrapped in ``tests(...)`` under the same condition.
    3. If ``exclude_manual``, the tests among the included patterns whose
       ``tags`` attribute matches ``manual``.

    Args:
        targets: Target patterns; a leading "-" marks an exclusion.
        tests_only: Wrap the included/excluded sets in ``tests(...)``.
        exclude_manual: Subtract tests tagged ``manual``.

    Returns:
        The query expression as a string.
    """
    wanted, unwanted = partition_targets(targets)
    wanted_quoted = quote_targets(wanted)
    unwanted_quoted = quote_targets(unwanted)

    def as_set(quoted):
        # "set(...)" over the quoted patterns, narrowed to tests on request.
        expr = "set({})".format(quoted)
        return "tests({})".format(expr) if tests_only else expr

    clauses = [as_set(wanted_quoted)]
    if unwanted_quoted:
        clauses.append(as_set(unwanted_quoted))
    if exclude_manual:
        clauses.append(
            'tests(attr("tags", "manual", set({})))'.format(wanted_quoted)
        )

    return " except ".join(clauses)
|
|
|
|
|
|
def run_bazel_query(query, debug):
    """Run ``bazel query`` on ``query`` and return its output lines.

    Args:
        query: Query expression passed to Bazel as a single argument.
        debug: When true, echo the command line to stderr before running.

    Returns:
        The lines of Bazel's stdout after stripping surrounding whitespace,
        or an empty list when there is no output.

    Raises:
        subprocess.CalledProcessError: If Bazel exits with a non-zero status
            (the process is run with ``check=True``).
    """
    command = ["bazel", "query", query]

    if debug:
        # The echo goes to stderr; stdout is reserved for the result that
        # the caller prints.
        print("$ {}".format(" ".join(command)), file=sys.stderr)
        sys.stderr.flush()

    completed = subprocess.run(
        command,
        check=True,
        stdout=subprocess.PIPE,
        errors="replace",
        universal_newlines=True,
    )

    stripped = completed.stdout.strip()
    if not stripped:
        return []
    return stripped.splitlines()
|
|
|
|
|
|
def get_targets_for_shard(targets, index, count):
    """Return the slice of ``targets`` assigned to shard ``index``.

    Targets are sorted and then dealt out round-robin: the shard receives
    every ``count``-th target, starting at position ``index``.

    Args:
        targets: Iterable of target labels.
        index: Position of this shard.
        count: Total number of shards (the slice step).

    Returns:
        A sorted list containing this shard's targets.
    """
    # This is a very simple way of sharding targets. A more sophisticated
    # approach might want to take test sizes into account, for example.
    ordered = list(targets)
    ordered.sort()
    return ordered[index::count]
|
|
|
|
|
|
def main(argv=None):
    """Expand Bazel target patterns and print the ones owned by one shard.

    The shard position and shard total come from ``--index`` / ``--count``.
    When a flag is omitted, its value is read from the
    ``BUILDKITE_PARALLEL_JOB`` / ``BUILDKITE_PARALLEL_JOB_COUNT`` environment
    variables, and finally falls back to a single shard (index 0 of 1).

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        ``0`` on success. Invalid arguments terminate through
        ``parser.error`` (``SystemExit`` with status 2).
    """
    parser = argparse.ArgumentParser(description="Expand and shard Bazel targets.")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--tests_only", action="store_true")
    parser.add_argument("--exclude_manual", action="store_true")
    # Environment values arrive as strings; argparse runs string defaults
    # through ``type``, so they become ints just like command-line values.
    # The index is zero-based, so the fallback must be 0: a fallback of 1
    # combined with the default count of 1 can never pass validation.
    parser.add_argument(
        "--index", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB", 0)
    )
    parser.add_argument(
        "--count", type=int, default=os.getenv("BUILDKITE_PARALLEL_JOB_COUNT", 1)
    )
    parser.add_argument("targets", nargs="+")

    # Unrecognised arguments are treated as additional targets. Presumably
    # this is how exclusion patterns such as "-//foo/..." get through, since
    # argparse would otherwise reject them as unknown options.
    args, extra_args = parser.parse_known_args(argv)
    args.targets = list(args.targets) + list(extra_args)

    if args.count < 1:
        parser.error("--count must be at least 1")
    # Reject negative indices as well: they would silently select a slice
    # counted from the end of the target list.
    if not 0 <= args.index < args.count:
        parser.error("--index must be between 0 and {}".format(args.count - 1))

    query = get_target_expansion_query(
        args.targets, args.tests_only, args.exclude_manual
    )
    expanded_targets = run_bazel_query(query, args.debug)
    my_targets = get_targets_for_shard(expanded_targets, args.index, args.count)
    print(" ".join(my_targets))

    return 0
|
|
|
|
|
|
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|