2022-07-15 13:21:10 -07:00
|
|
|
from functools import wraps
|
|
|
|
import json
|
|
|
|
from multiprocessing import Process
|
|
|
|
import os
|
|
|
|
import time
|
|
|
|
import xgboost as xgb
|
|
|
|
|
|
|
|
import ray
|
|
|
|
from ray import data
|
|
|
|
from ray.train.xgboost import (
|
|
|
|
XGBoostTrainer,
|
2022-07-20 19:33:27 -07:00
|
|
|
XGBoostCheckpoint,
|
2022-07-15 13:21:10 -07:00
|
|
|
XGBoostPredictor,
|
|
|
|
)
|
|
|
|
from ray.train.batch_predictor import BatchPredictor
|
2022-07-18 18:46:58 -04:00
|
|
|
from ray.air.config import ScalingConfig
|
2022-07-15 13:21:10 -07:00
|
|
|
|
|
|
|
_XGB_MODEL_PATH = "model.json"
|
|
|
|
_TRAINING_TIME_THRESHOLD = 1000
|
|
|
|
_PREDICTION_TIME_THRESHOLD = 450
|
|
|
|
|
2022-07-15 15:33:48 -07:00
|
|
|
_EXPERIMENT_PARAMS = {
|
|
|
|
"10G": {
|
|
|
|
"data": "s3://air-example-data-2/10G-xgboost-data.parquet/",
|
|
|
|
"num_workers": 1,
|
|
|
|
},
|
|
|
|
"100G": {
|
|
|
|
"data": "s3://air-example-data-2/100G-xgboost-data.parquet/",
|
|
|
|
"num_workers": 10,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
2022-07-15 13:21:10 -07:00
|
|
|
|
|
|
|
def run_and_time_it(f):
    """Decorate ``f`` so each call runs it in a child process and is timed.

    The decorated function starts a ``multiprocessing.Process`` targeting
    ``f``, waits for it to finish, prints the elapsed time (measured with
    ``time.monotonic()``) and returns that time in seconds. The return value
    of ``f`` itself is not propagated back to the caller.

    Args:
        f: Callable to execute in the child process.

    Returns:
        A wrapper accepting the same positional and keyword arguments as
        ``f`` and returning the elapsed time in seconds as a float.

    Raises:
        RuntimeError: Raised by the wrapper when the child process exits
            with a non-zero exit code.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        # Forward keyword arguments as well; passing only ``args`` would
        # silently drop them.
        p = Process(target=f, args=args, kwargs=kwargs)
        start = time.monotonic()
        p.start()
        p.join()
        time_taken = time.monotonic() - start
        # A crashed child must not be reported as a successful, timed run.
        if p.exitcode != 0:
            raise RuntimeError(
                f"{f.__name__} failed in a subprocess with exit code {p.exitcode}."
            )
        print(f"{f.__name__} takes {time_taken} seconds.")
        return time_taken

    return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
@run_and_time_it
def run_xgboost_training(data_path: str, num_workers: int):
    """Train an XGBoost model on a parquet dataset and save it locally.

    Because of the ``run_and_time_it`` decorator, calling this name runs the
    body in a separate process and yields the elapsed time in seconds.

    Args:
        data_path: Location of the parquet data passed to
            ``data.read_parquet``.
        num_workers: Number of workers given to the ``ScalingConfig``.
    """
    train_dataset = data.read_parquet(data_path)

    scaling = ScalingConfig(
        num_workers=num_workers,
        resources_per_worker={"CPU": 12},
    )
    xgb_trainer = XGBoostTrainer(
        scaling_config=scaling,
        label_column="labels",
        params={
            "objective": "binary:logistic",
            "eval_metric": ["logloss", "error"],
        },
        datasets={"train": train_dataset},
    )
    fit_result = xgb_trainer.fit()

    # Pull the trained booster out of the resulting checkpoint and write it
    # to the shared model path.
    booster = XGBoostCheckpoint.from_checkpoint(fit_result.checkpoint).get_model()
    booster.save_model(_XGB_MODEL_PATH)
    ray.shutdown()
|
|
|
|
|
|
|
|
|
|
|
|
@run_and_time_it
def run_xgboost_prediction(model_path: str, data_path: str):
    """Run batch prediction with a saved XGBoost model over a parquet dataset.

    Because of the ``run_and_time_it`` decorator, calling this name runs the
    body in a separate process and yields the elapsed time in seconds rather
    than the prediction output returned below.

    Args:
        model_path: Path of the saved model file, loaded via
            ``Booster.load_model``.
        data_path: Location of the parquet data passed to
            ``data.read_parquet``.

    Returns:
        The value produced by ``BatchPredictor.predict``.
    """
    booster = xgb.Booster()
    booster.load_model(model_path)

    dataset = data.read_parquet(data_path)
    predictor = BatchPredictor.from_checkpoint(
        XGBoostCheckpoint.from_model(".", booster),
        XGBoostPredictor,
    )
    # The "labels" column is removed before the data is handed to predict().
    return predictor.predict(dataset.drop_columns(["labels"]))
|
|
|
|
|
|
|
|
|
2022-07-15 15:33:48 -07:00
|
|
|
def main(args):
    """Run the training and prediction benchmarks and record their timings.

    The timings are printed, written as JSON to the path named by the
    ``TEST_OUTPUT_JSON`` environment variable (default ``/tmp/result.json``),
    and then checked against the module-level thresholds.

    Args:
        args: Object with a ``size`` attribute selecting an entry of
            ``_EXPERIMENT_PARAMS``.

    Raises:
        RuntimeError: If the training time or the prediction time exceeds
            its threshold. The training check runs first.
    """
    config = _EXPERIMENT_PARAMS[args.size]
    data_path = config["data"]
    num_workers = config["num_workers"]

    print("Running xgboost training benchmark...")
    training_time = run_xgboost_training(data_path, num_workers)

    print("Running xgboost prediction benchmark...")
    prediction_time = run_xgboost_prediction(_XGB_MODEL_PATH, data_path)

    result = {"training_time": training_time, "prediction_time": prediction_time}
    print("Results:", result)

    # Persist the timings before the threshold checks so they are written
    # even when a check fails.
    output_path = os.environ.get("TEST_OUTPUT_JSON", "/tmp/result.json")
    with open(output_path, "wt") as out_file:
        json.dump(result, out_file)

    if training_time > _TRAINING_TIME_THRESHOLD:
        training_message = (
            f"Training on XGBoost is taking {training_time} seconds, "
            f"which is longer than expected ({_TRAINING_TIME_THRESHOLD} seconds)."
        )
        raise RuntimeError(training_message)

    if prediction_time > _PREDICTION_TIME_THRESHOLD:
        prediction_message = (
            f"Batch prediction on XGBoost is taking {prediction_time} seconds, "
            f"which is longer than expected ({_PREDICTION_TIME_THRESHOLD} seconds)."
        )
        raise RuntimeError(prediction_message)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # ``--size`` selects which predefined benchmark configuration to run.
    cli = argparse.ArgumentParser()
    cli.add_argument("--size", type=str, choices=["10G", "100G"], default="100G")
    main(cli.parse_args())
|