Mirror of https://github.com/vale981/ray, synced 2025-03-06 18:41:40 -05:00.

Co-authored-by: xwjiang2010 <xwjiang2010@gmail.com> Co-authored-by: Kai Fricke <kai@anyscale.com> Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Co-authored-by: Eric Liang <ekhliang@gmail.com>
33 lines · 980 B · Python
# flake8: noqa
# isort: skip_file

# __train_predict_start__
import numpy as np

import ray
from ray.train.xgboost import XGBoostTrainer, XGBoostPredictor
from ray.air.config import ScalingConfig

# Toy regression data: 32 rows with a single feature "x" and label y = x + 1.
train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])

# Train an XGBoost regressor (squared-error objective) on the "y" column,
# distributed across 3 workers.
trainer = XGBoostTrainer(
    label_column="y",
    params={"objective": "reg:squarederror"},
    scaling_config=ScalingConfig(num_workers=3),
    datasets={"train": train_dataset},
)
result = trainer.fit()

# Build a predictor from the checkpoint returned by ``fit()`` and score the
# inputs x = 32..63, which lie outside the training range 0..31.
# ``expand_dims(..., 1)`` turns the 1-D range into a (32, 1) array:
# one row per sample, one column for the single feature.
predictor = XGBoostPredictor.from_checkpoint(result.checkpoint)
predictions = predictor.predict(np.expand_dims(np.arange(32, 64), 1))
# __train_predict_end__

# __batch_predict_start__
from ray.train.batch_predictor import BatchPredictor

# Reuses ``result``, ``ray`` and ``XGBoostPredictor`` from the training
# section above: the same checkpoint is paired with the predictor class to
# use for scoring a Ray Dataset.
batch_predictor = BatchPredictor.from_checkpoint(result.checkpoint, XGBoostPredictor)

# Score a 32-row dataset of "x" values in batches of 8 rows, with at least
# 2 scoring workers. Note: this rebinds ``predictions`` from the section above.
predictions = batch_predictor.predict(
    data=ray.data.from_items([{"x": x} for x in range(32)]),
    batch_size=8,
    min_scoring_workers=2,
)
# __batch_predict_end__