ray/doc/source/train/doc_code/xgboost_train_predict.py
Richard Liaw 4629a3a649
[air/docs] Update Trainer documentation (#27481)
Co-authored-by: xwjiang2010 <xwjiang2010@gmail.com>
Co-authored-by: Kai Fricke <kai@anyscale.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Eric Liang <ekhliang@gmail.com>
2022-08-05 11:21:19 -07:00

33 lines
980 B
Python

# flake8: noqa
# isort: skip_file
# __train_predict_start__
import numpy as np
import ray
from ray.train.xgboost import XGBoostTrainer, XGBoostPredictor
from ray.air.config import ScalingConfig
# Toy regression data: 32 rows where the label "y" is always "x" + 1.
train_dataset = ray.data.from_items(
    [{"x": value, "y": value + 1} for value in range(32)]
)

# Train XGBoost on the "train" split with three workers, using the
# squared-error regression objective and "y" as the label column.
trainer = XGBoostTrainer(
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=3),
    params={"objective": "reg:squarederror"},
    label_column="y",
)
result = trainer.fit()

# Rebuild a predictor from the checkpoint returned by training, then score
# the inputs 32..63 (outside the training range) as a (32, 1) column array.
predictor = XGBoostPredictor.from_checkpoint(result.checkpoint)
predictions = predictor.predict(np.arange(32, 64).reshape(-1, 1))
# __train_predict_end__
# __batch_predict_start__
from ray.train.batch_predictor import BatchPredictor
# Wrap the training checkpoint in a BatchPredictor that uses
# XGBoostPredictor to score whole datasets rather than in-memory arrays.
batch_predictor = BatchPredictor.from_checkpoint(
    result.checkpoint,
    XGBoostPredictor,
)

# Score a 32-row dataset that carries only the feature column "x".
# NOTE(review): batch_size=8 and min_scoring_workers=2 presumably mean
# "8 rows per batch" and "at least two scoring workers" -- confirm against
# the BatchPredictor.predict documentation.
predictions = batch_predictor.predict(
    min_scoring_workers=2,
    batch_size=8,
    data=ray.data.from_items([{"x": value} for value in range(32)]),
)
# __batch_predict_end__