# ray/doc/source/ray-air/doc_code/sklearn_trainer.py

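import ray
from ray.air.config import ScalingConfig
from ray.train.sklearn import SklearnTrainer
from sklearn.ensemble import RandomForestRegressor

# Toy regression data: 32 rows with one feature "x" and the label "y" = x + 1.
train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])

trainer = SklearnTrainer(
    estimator=RandomForestRegressor(),
    # Column of the "train" dataset to use as the regression target.
    label_column="y",
    # Reserve 4 CPUs for the trainer that runs the scikit-learn fit.
    scaling_config=ScalingConfig(trainer_resources={"CPU": 4}),
    datasets={"train": train_dataset},
)
# Run training on Ray; fit() returns a Result holding metrics and a checkpoint.
result = trainer.fit()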
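
# --- Hedged illustration (not part of the Ray API) -------------------------
# As we understand it (an assumption, not verified against Ray's source),
# SklearnTrainer converts the "train" dataset to pandas, splits off
# ``label_column`` as the target, and calls ``estimator.fit(X, y)`` on a Ray
# worker. The Ray-free sketch below does roughly the same on the same toy
# data so the two can be compared. The ``local_*`` names are hypothetical,
# and ``random_state=0`` is added here only to make the sketch repeatable.
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

local_df = pd.DataFrame([{"x": x, "y": x + 1} for x in range(32)])
local_y = local_df.pop("y")  # target column, mirroring label_column="y"
local_X = local_df  # remaining column "x" is the only feature
local_model = RandomForestRegressor(random_state=0)
local_model.fit(local_X, local_y)
local_predictions = local_model.predict(local_X)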