mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[docs][ml][kuberay] Add a --disable-check flag to the XGBoost benchmark. (#27277)
This PR adds a flag --disable-check to the XGBoost benchmark script which disables the RuntimeError that comes up if training or prediction took too long. This is meant for non-CI exploratory use-cases. Specifically, the reason is this: We will include the XGBoost benchmark as an example workload for the KubeRay documentation. The actual performance of the workload is highly sensitive to infrastructure environment, so we won't want to raise an alarming RuntimeError if the workload took too long on the user's infrastructure. (When I tried the 100Gb benchmark on KubeRay, training ran just a couple of minutes longer than the 1000 second cutoff.)
This commit is contained in:
parent
1a10b53a61
commit
8bdeb30510
1 changed files with 19 additions and 10 deletions
|
@ -122,17 +122,18 @@ def main(args):
|
|||
with open(test_output_json, "wt") as f:
|
||||
json.dump(result, f)
|
||||
|
||||
if training_time > _TRAINING_TIME_THRESHOLD:
|
||||
raise RuntimeError(
|
||||
f"Training on XGBoost is taking {training_time} seconds, "
|
||||
f"which is longer than expected ({_TRAINING_TIME_THRESHOLD} seconds)."
|
||||
)
|
||||
if not args.disable_check:
|
||||
if training_time > _TRAINING_TIME_THRESHOLD:
|
||||
raise RuntimeError(
|
||||
f"Training on XGBoost is taking {training_time} seconds, "
|
||||
f"which is longer than expected ({_TRAINING_TIME_THRESHOLD} seconds)."
|
||||
)
|
||||
|
||||
if prediction_time > _PREDICTION_TIME_THRESHOLD:
|
||||
raise RuntimeError(
|
||||
f"Batch prediction on XGBoost is taking {prediction_time} seconds, "
|
||||
f"which is longer than expected ({_PREDICTION_TIME_THRESHOLD} seconds)."
|
||||
)
|
||||
if prediction_time > _PREDICTION_TIME_THRESHOLD:
|
||||
raise RuntimeError(
|
||||
f"Batch prediction on XGBoost is taking {prediction_time} seconds, "
|
||||
f"which is longer than expected ({_PREDICTION_TIME_THRESHOLD} seconds)."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -140,5 +141,13 @@ if __name__ == "__main__":
|
|||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--size", type=str, choices=["10G", "100G"], default="100G")
|
||||
# Add a flag for disabling the timeout error.
|
||||
# Use case: running the benchmark as a documented example, in infra settings
|
||||
# different from the formal benchmark's EC2 setup.
|
||||
parser.add_argument(
|
||||
"--disable-check",
|
||||
action="store_true",
|
||||
help="disable runtime error on benchmark timeout",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
|
Loading…
Add table
Reference in a new issue