{
"cells": [
{
"cell_type": "markdown",
"id": "5fb89b3d",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Training a model with distributed XGBoost\n",
"In this example we will train a model in Ray AIR using distributed XGBoost."
]
},
{
"cell_type": "markdown",
"id": "53d57c1f",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Let's start by installing our dependencies:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41f20cc1",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"!pip install -qU \"ray[tune]\" xgboost_ray"
]
},
{
"cell_type": "markdown",
"id": "d2fe8d4a",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Then we need some imports:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7232303d",
"metadata": {},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"import ray\n",
"from ray.train.batch_predictor import BatchPredictor\n",
"from ray.train.xgboost import XGBoostPredictor\n",
"from ray.train.xgboost import XGBoostTrainer\n",
"from ray.air.config import ScalingConfig\n",
"from ray.data.dataset import Dataset\n",
"from ray.air.result import Result\n",
"from ray.data.preprocessors import StandardScaler"
]
},
{
"cell_type": "markdown",
"id": "1c75b5ca",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
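"Next, we define a function that loads the breast cancer dataset and splits it into training (70%) and validation (30%) sets. The test set is the validation set with the `target` label column dropped, so we can use it for inference."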
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "37c4f38f",
"metadata": {},
"outputs": [],
"source": [
"def prepare_data() -> Tuple[Dataset, Dataset, Dataset]:\n",
"    dataset = ray.data.read_csv(\"s3://anonymous@air-example-data/breast_cancer.csv\")\n",
"    # Hold out 30% of the rows for validation\n",
"    train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)\n",
"    # The test set is the validation set without the label column\n",
"    test_dataset = valid_dataset.drop_columns([\"target\"])\n",
"    return train_dataset, valid_dataset, test_dataset"
]
},
{
"cell_type": "markdown",
"id": "9b2850dd",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The following function creates an XGBoost trainer, fits it, and returns the result."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dae8998d",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def train_xgboost(num_workers: int, use_gpu: bool = False) -> Result:\n",
"    train_dataset, valid_dataset, _ = prepare_data()\n",
"\n",
"    # Standardize two arbitrarily chosen feature columns\n",
"    columns_to_scale = [\"mean radius\", \"mean texture\"]\n",
"    preprocessor = StandardScaler(columns=columns_to_scale)\n",
"\n",
"    # XGBoost-specific params\n",
"    params = {\n",
"        \"tree_method\": \"approx\",\n",
"        \"objective\": \"binary:logistic\",\n",
"        \"eval_metric\": [\"logloss\", \"error\"],\n",
"    }\n",
"\n",
"    trainer = XGBoostTrainer(\n",
"        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),\n",
"        label_column=\"target\",\n",
"        params=params,\n",
"        datasets={\"train\": train_dataset, \"valid\": valid_dataset},\n",
"        preprocessor=preprocessor,\n",
"        num_boost_round=100,\n",
"    )\n",
"    result = trainer.fit()\n",
"    print(result.metrics)\n",
"\n",
"    return result"
]
},
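{
"cell_type": "markdown",
"id": "3f9c1a7e",
"metadata": {},
"source": [
"To make the preprocessing step more concrete, the next cell is a minimal pandas sketch of what standard scaling computes for the selected columns. It is not Ray's `StandardScaler` implementation: the toy values are made up, and the sketch assumes the population standard deviation (`ddof=0`). \"Fitting\" computes each column's mean and standard deviation. \"Transforming\" subtracts the mean and divides by the standard deviation, so each scaled column ends up with mean 0 and standard deviation 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d2e4b60",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical toy values standing in for two breast cancer feature columns\n",
"df = pd.DataFrame(\n",
"    {\n",
"        \"mean radius\": [10.0, 12.0, 14.0, 16.0],\n",
"        \"mean texture\": [20.0, 20.0, 22.0, 22.0],\n",
"    }\n",
")\n",
"columns_to_scale = [\"mean radius\", \"mean texture\"]\n",
"\n",
"# Fit: per-column mean and population (ddof=0) standard deviation\n",
"means = df[columns_to_scale].mean()\n",
"stds = df[columns_to_scale].std(ddof=0)\n",
"\n",
"# Transform: center each column, then divide by its standard deviation\n",
"scaled = df.copy()\n",
"scaled[columns_to_scale] = (df[columns_to_scale] - means) / stds\n",
"print(scaled)"
]
},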
{
"cell_type": "markdown",
"id": "ce05af87",
"metadata": {},
"source": [
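"Once we have the result, we can run batch inference with the trained model. Let's define a utility function that predicts binary labels for the test set and also computes SHAP values (per-feature contributions) for each prediction."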
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5b8076d3",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def predict_xgboost(result: Result):\n",
"    _, _, test_dataset = prepare_data()\n",
"\n",
"    batch_predictor = BatchPredictor.from_checkpoint(\n",
"        result.checkpoint, XGBoostPredictor\n",
"    )\n",
"\n",
"    # Threshold the predicted probabilities at 0.5 to get binary labels\n",
"    predicted_labels = (\n",
"        batch_predictor.predict(test_dataset)\n",
"        .map_batches(lambda df: (df > 0.5).astype(int), batch_format=\"pandas\")\n",
"    )\n",
"    print(\"PREDICTED LABELS\")\n",
"    predicted_labels.show()\n",
"\n",
"    # pred_contribs=True is passed through to XGBoost and returns SHAP values\n",
"    shap_values = batch_predictor.predict(test_dataset, pred_contribs=True)\n",
"    print(\"SHAP VALUES\")\n",
"    shap_values.show()"
]
},
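{
"cell_type": "markdown",
"id": "c47a9e15",
"metadata": {},
"source": [
"The SHAP values printed by this function are easier to read with two XGBoost properties in mind. First, for each row, the per-feature contributions plus a final bias column sum to the model's raw margin (log-odds). Second, for `binary:logistic` the predicted probability is the sigmoid of that margin. This is also why thresholding the probability at 0.5 is equivalent to checking whether the margin is positive. The next cell is a small sketch that checks both properties with plain `xgboost` outside of Ray. The synthetic data, parameters, and tolerances are illustrative assumptions, not part of the original example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5b1d023",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import xgboost as xgb\n",
"\n",
"# Hypothetical synthetic binary classification data (4 features)\n",
"rng = np.random.default_rng(0)\n",
"X = rng.normal(size=(200, 4))\n",
"y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)\n",
"dtrain = xgb.DMatrix(X, label=y)\n",
"\n",
"booster = xgb.train(\n",
"    {\"objective\": \"binary:logistic\", \"tree_method\": \"approx\"},\n",
"    dtrain,\n",
"    num_boost_round=20,\n",
")\n",
"\n",
"contribs = booster.predict(dtrain, pred_contribs=True)  # one column per feature + bias\n",
"margin = booster.predict(dtrain, output_margin=True)  # raw log-odds\n",
"proba = booster.predict(dtrain)  # probabilities\n",
"\n",
"print(contribs.shape)  # (200, 5)\n",
"# Contributions (including the bias column) sum to the raw margin\n",
"print(np.allclose(contribs.sum(axis=1), margin, atol=1e-4))\n",
"# binary:logistic maps the margin to a probability with the sigmoid\n",
"print(np.allclose(1.0 / (1.0 + np.exp(-margin)), proba, atol=1e-5))"
]
},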
{
"cell_type": "markdown",
"id": "7e172f66",
"metadata": {},
"source": [
"Now we can run the training:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0f96d62b",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-06-22 17:28:55,841\tINFO services.py:1477 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8270\u001b[39m\u001b[22m\n",
"2022-06-22 17:28:58,044\tWARNING read_api.py:260 -- The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n",
"Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 40.28it/s]\n"
]
},
{
"data": {
"text/html": [
"== Status ==
Current time: 2022-06-22 17:29:15 (running for 00:00:16.11)
Memory usage on this node: 11.5/31.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/12.35 GiB heap, 0.0/6.18 GiB objects
Result logdir: /home/ubuntu/ray_results/XGBoostTrainer_2022-06-22_17-28-58
Number of trials: 1/1 (1 TERMINATED)
Trial name | status | loc | iter | total time (s) | train-logloss | train-error | valid-logloss |
---|---|---|---|---|---|---|---|
XGBoostTrainer_cc863_00000 | TERMINATED | 172.31.43.110:1493910 | 100 | 12.5164 | 0.005874 | 0 | 0.078188 |