# Mirror of https://github.com/vale981/ray
# Synced 2025-03-09 21:06:39 -04:00
# 14 lines, 385 B, Python
"""Minimal Ray Train example: fit an XGBoost regressor on a toy dataset.

A 32-row in-memory Ray Dataset is built in which each row maps feature
``x`` to label ``y = x + 1``. ``XGBoostTrainer`` then trains on it with
three workers, using the ``"reg:squarederror"`` objective. The outcome of
``trainer.fit()`` is kept in ``result``.
"""
import ray
from ray.train.xgboost import XGBoostTrainer
from ray.air.config import ScalingConfig

# Toy regression rows: the label is always the feature plus one.
toy_rows = [{"x": value, "y": value + 1} for value in range(32)]
train_dataset = ray.data.from_items(toy_rows)

# XGBoost booster parameters forwarded to the trainer unchanged.
xgb_params = {"objective": "reg:squarederror"}

# Distribute training over three Ray Train workers.
worker_scaling = ScalingConfig(num_workers=3)

trainer = XGBoostTrainer(
    label_column="y",
    params=xgb_params,
    scaling_config=worker_scaling,
    datasets={"train": train_dataset},
)
result = trainer.fit()