ray/release/air_tests/air_benchmarks/workloads/xgboost_benchmark.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

103 lines
2.8 KiB
Python
Raw Normal View History

from functools import wraps
import json
from multiprocessing import Process
import os
import time
import xgboost as xgb
import ray
from ray import data
from ray.train.xgboost import (
XGBoostTrainer,
load_checkpoint,
to_air_checkpoint,
XGBoostPredictor,
)
from ray.train.batch_predictor import BatchPredictor
# File the training step saves the fitted model to; main() passes the same
# path to the prediction step, which loads the model back from it.
_XGB_MODEL_PATH = "model.json"
# Upper bounds, in seconds, on the measured training / prediction times.
# main() raises RuntimeError when a measured time exceeds its bound.
_TRAINING_TIME_THRESHOLD = 1000
_PREDICTION_TIME_THRESHOLD = 450
def run_and_time_it(f):
    """Run ``f`` in a separate process and report how long it took.

    The returned wrapper does not return ``f``'s result: ``f`` executes in a
    child ``multiprocessing.Process`` and its return value is discarded.
    Callers get the wall-clock time, in seconds, between starting the child
    and the child exiting.

    Args:
        f: Callable to execute in the child process.

    Returns:
        A wrapper taking the same positional and keyword arguments as ``f``
        and returning the elapsed time in seconds as a float.

    Raises:
        RuntimeError: Raised by the wrapper when the child process ends with
            a non-zero exit code (for example because ``f`` raised).
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        # Forward keyword arguments too; passing only ``args`` would silently
        # drop them and call ``f`` with a different signature in the child.
        p = Process(target=f, args=args, kwargs=kwargs)
        start = time.monotonic()
        p.start()
        p.join()
        time_taken = time.monotonic() - start
        # An exception inside the child only shows up here as a non-zero exit
        # code. Without this check a crashed run would be reported as a
        # (typically very short) successful timing.
        if p.exitcode != 0:
            raise RuntimeError(
                f"{f.__name__} failed in its subprocess "
                f"(exit code {p.exitcode}) after {time_taken} seconds."
            )
        print(f"{f.__name__} takes {time_taken} seconds.")
        return time_taken

    return wrapper
@run_and_time_it
def run_xgboost_training():
    """Fit an XGBoost model on the benchmark dataset and save it to disk.

    Reads the parquet dataset from S3, fits an ``XGBoostTrainer`` on it with
    ``labels`` as the label column, writes the resulting model to
    ``_XGB_MODEL_PATH`` and finally shuts Ray down.

    Because of the ``run_and_time_it`` decorator, calling this name runs the
    body in a separate process and returns the elapsed time in seconds.
    """
    xgb_params = {
        "objective": "binary:logistic",
        "eval_metric": ["logloss", "error"],
    }
    scaling = {"num_workers": 10, "resources_per_worker": {"CPU": 12}}

    # silver tier
    train_ds = data.read_parquet(
        "s3://air-example-data-2/100G-xgboost-data.parquet/"
    )

    fit_result = XGBoostTrainer(
        scaling_config=scaling,
        label_column="labels",
        params=xgb_params,
        datasets={"train": train_ds},
    ).fit()

    # The first element of what load_checkpoint returns is used as the model.
    booster = load_checkpoint(fit_result.checkpoint)[0]
    booster.save_model(_XGB_MODEL_PATH)
    ray.shutdown()
@run_and_time_it
def run_xgboost_prediction(model_path: str):
    """Run batch prediction over the benchmark dataset with a saved model.

    Loads the model stored at ``model_path`` into an ``xgb.Booster``, wraps it
    in an AIR checkpoint, and predicts on the parquet dataset with the
    ``labels`` column dropped.

    Args:
        model_path: Path of the model file to load.

    Returns:
        Whatever ``BatchPredictor.predict`` returns. Because of the
        ``run_and_time_it`` decorator the body runs in a separate process, so
        callers of the decorated name receive the elapsed seconds instead.
    """
    booster = xgb.Booster()
    booster.load_model(model_path)

    # silver tier
    source_ds = data.read_parquet(
        "s3://air-example-data-2/100G-xgboost-data.parquet/"
    )

    checkpoint = to_air_checkpoint(".", booster)
    predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)
    features = source_ds.drop_columns(["labels"])
    return predictor.predict(features)
def main():
    """Run both benchmarks, persist the timings, and enforce the thresholds.

    Runs training first and prediction second, then writes both timings as
    JSON to the path in the ``TEST_OUTPUT_JSON`` environment variable
    (``/tmp/result.json`` when unset).

    Raises:
        RuntimeError: If the training time exceeds
            ``_TRAINING_TIME_THRESHOLD`` or, checked afterwards, the
            prediction time exceeds ``_PREDICTION_TIME_THRESHOLD``.
    """
    print("Running xgboost training benchmark...")
    training_time = run_xgboost_training()

    print("Running xgboost prediction benchmark...")
    prediction_time = run_xgboost_prediction(_XGB_MODEL_PATH)

    result = dict(training_time=training_time, prediction_time=prediction_time)
    print("Results:", result)

    # Results are written before the threshold checks so they are kept even
    # when one of the checks below raises.
    output_path = os.environ.get("TEST_OUTPUT_JSON", "/tmp/result.json")
    with open(output_path, "wt") as out_file:
        out_file.write(json.dumps(result))

    if _TRAINING_TIME_THRESHOLD < training_time:
        message = (
            f"Training on XGBoost is taking {training_time} seconds, "
            f"which is longer than expected ({_TRAINING_TIME_THRESHOLD} seconds)."
        )
        raise RuntimeError(message)

    if _PREDICTION_TIME_THRESHOLD < prediction_time:
        message = (
            f"Batch prediction on XGBoost is taking {prediction_time} seconds, "
            f"which is longer than expected ({_PREDICTION_TIME_THRESHOLD} seconds)."
        )
        raise RuntimeError(message)
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()