{ "cells": [ { "cell_type": "markdown", "id": "5fb89b3d", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Training a model with distributed XGBoost\n", "In this example we will train a model in Ray AIR using distributed XGBoost." ] }, { "cell_type": "markdown", "id": "53d57c1f", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Let's start by installing our dependencies:" ] }, { "cell_type": "code", "execution_count": null, "id": "41f20cc1", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "!pip install -qU \"ray[tune]\" xgboost_ray" ] }, { "cell_type": "markdown", "id": "d2fe8d4a", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Then we need some imports:" ] }, { "cell_type": "code", "execution_count": 1, "id": "7232303d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/ray/venv/lib/python3.8/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", " from pandas import MultiIndex, Int64Index\n", "FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "FutureWarning: pandas.UInt64Index is deprecated and will be removed from pandas in a future version. 
Use pandas.Index with the appropriate dtype instead.\n" ] } ], "source": [ "from typing import Tuple\n", "\n", "import ray\n", "from ray.train.batch_predictor import BatchPredictor\n", "from ray.train.xgboost import XGBoostPredictor\n", "from ray.train.xgboost import XGBoostTrainer\n", "from ray.air.config import ScalingConfig\n", "from ray.data.dataset import Dataset\n", "from ray.air.result import Result\n", "from ray.data.preprocessors import StandardScaler" ] }, { "cell_type": "markdown", "id": "1c75b5ca", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Next, we define a function to load our train, validation, and test datasets." ] }, { "cell_type": "code", "execution_count": 2, "id": "37c4f38f", "metadata": {}, "outputs": [], "source": [ "def prepare_data() -> Tuple[Dataset, Dataset, Dataset]:\n", "    dataset = ray.data.read_csv(\"s3://anonymous@air-example-data/breast_cancer.csv\")\n", "    train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)\n", "    test_dataset = valid_dataset.drop_columns([\"target\"])\n", "    return train_dataset, valid_dataset, test_dataset" ] }, { "cell_type": "markdown", "id": "9b2850dd", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "The following function creates an XGBoost trainer, trains it, and returns the result." 
] }, { "cell_type": "code", "execution_count": 3, "id": "dae8998d", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def train_xgboost(num_workers: int, use_gpu: bool = False) -> Result:\n", "    train_dataset, valid_dataset, _ = prepare_data()\n", "\n", "    # Scale some random columns\n", "    columns_to_scale = [\"mean radius\", \"mean texture\"]\n", "    preprocessor = StandardScaler(columns=columns_to_scale)\n", "\n", "    # XGBoost specific params\n", "    params = {\n", "        \"tree_method\": \"approx\",\n", "        \"objective\": \"binary:logistic\",\n", "        \"eval_metric\": [\"logloss\", \"error\"],\n", "    }\n", "\n", "    trainer = XGBoostTrainer(\n", "        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),\n", "        label_column=\"target\",\n", "        params=params,\n", "        datasets={\"train\": train_dataset, \"valid\": valid_dataset},\n", "        preprocessor=preprocessor,\n", "        num_boost_round=100,\n", "    )\n", "    result = trainer.fit()\n", "    print(result.metrics)\n", "\n", "    return result" ] }, { "cell_type": "markdown", "id": "ce05af87", "metadata": {}, "source": [ "Once we have the result, we can do batch inference on the obtained model. Let's define a utility function for this." 
] }, { "cell_type": "code", "execution_count": 4, "id": "5b8076d3", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def predict_xgboost(result: Result):\n", "    _, _, test_dataset = prepare_data()\n", "\n", "    batch_predictor = BatchPredictor.from_checkpoint(\n", "        result.checkpoint, XGBoostPredictor\n", "    )\n", "\n", "    predicted_labels = (\n", "        batch_predictor.predict(test_dataset)\n", "        .map_batches(lambda df: (df > 0.5).astype(int), batch_format=\"pandas\")\n", "    )\n", "    print(\"PREDICTED LABELS\")\n", "    predicted_labels.show()\n", "\n", "    shap_values = batch_predictor.predict(test_dataset, pred_contribs=True)\n", "    print(\"SHAP VALUES\")\n", "    shap_values.show()\n" ] }, { "cell_type": "markdown", "id": "7e172f66", "metadata": {}, "source": [ "Now we can run the training:" ] }, { "cell_type": "code", "execution_count": 5, "id": "0f96d62b", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-06-22 17:28:55,841\tINFO services.py:1477 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8270\u001b[39m\u001b[22m\n", "2022-06-22 17:28:58,044\tWARNING read_api.py:260 -- The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n", "Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 40.28it/s]\n" ] }, { "data": { "text/html": [ "== Status ==
Current time: 2022-06-22 17:29:15 (running for 00:00:16.11)
Memory usage on this node: 11.5/31.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/12.35 GiB heap, 0.0/6.18 GiB objects
Result logdir: /home/ubuntu/ray_results/XGBoostTrainer_2022-06-22_17-28-58
Number of trials: 1/1 (1 TERMINATED)
\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name | status | loc | iter | total time (s) | train-logloss | train-error | valid-logloss
XGBoostTrainer_cc863_00000 | TERMINATED | 172.31.43.110:1493910 | 100 | 12.5164 | 0.005874 | 0 | 0.078188


" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(pid=1493910)\u001b[0m /home/ubuntu/ray/venv/lib/python3.8/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1493910)\u001b[0m from pandas import MultiIndex, Int64Index\n", "\u001b[2m\u001b[36m(pid=1493910)\u001b[0m FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1493910)\u001b[0m FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1493910)\u001b[0m FutureWarning: pandas.UInt64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(XGBoostTrainer pid=1493910)\u001b[0m UserWarning: Dataset 'train' has 1 blocks, which is less than the `num_workers` 2. This dataset will be automatically repartitioned to 2 blocks.\n", "\u001b[2m\u001b[36m(XGBoostTrainer pid=1493910)\u001b[0m UserWarning: Dataset 'valid' has 1 blocks, which is less than the `num_workers` 2. This dataset will be automatically repartitioned to 2 blocks.\n", "\u001b[2m\u001b[36m(XGBoostTrainer pid=1493910)\u001b[0m 2022-06-22 17:29:04,073\tINFO main.py:980 -- [RayXGBoost] Created 2 new actors (2 total actors). Waiting until actors are ready for training.\n", "\u001b[2m\u001b[36m(pid=1494007)\u001b[0m /home/ubuntu/ray/venv/lib/python3.8/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. 
Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494007)\u001b[0m from pandas import MultiIndex, Int64Index\n", "\u001b[2m\u001b[36m(pid=1494008)\u001b[0m /home/ubuntu/ray/venv/lib/python3.8/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494008)\u001b[0m from pandas import MultiIndex, Int64Index\n", "\u001b[2m\u001b[36m(pid=1494009)\u001b[0m /home/ubuntu/ray/venv/lib/python3.8/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494009)\u001b[0m from pandas import MultiIndex, Int64Index\n", "\u001b[2m\u001b[36m(pid=1494007)\u001b[0m FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494007)\u001b[0m FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494007)\u001b[0m FutureWarning: pandas.UInt64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494008)\u001b[0m FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494008)\u001b[0m FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. 
Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494008)\u001b[0m FutureWarning: pandas.UInt64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(_RemoteRayXGBoostActor pid=1494008)\u001b[0m 2022-06-22 17:29:07,324\tWARNING __init__.py:190 -- DeprecationWarning: `ray.worker.get_resource_ids` is a private attribute and access will be removed in a future Ray version.\n", "\u001b[2m\u001b[36m(pid=1494009)\u001b[0m FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494009)\u001b[0m FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(pid=1494009)\u001b[0m FutureWarning: pandas.UInt64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(_RemoteRayXGBoostActor pid=1494009)\u001b[0m 2022-06-22 17:29:07,421\tWARNING __init__.py:190 -- DeprecationWarning: `ray.worker.get_resource_ids` is a private attribute and access will be removed in a future Ray version.\n", "\u001b[2m\u001b[36m(XGBoostTrainer pid=1493910)\u001b[0m 2022-06-22 17:29:07,874\tINFO main.py:1025 -- [RayXGBoost] Starting XGBoost training.\n", "\u001b[2m\u001b[36m(_RemoteRayXGBoostActor pid=1494008)\u001b[0m [17:29:07] task [xgboost.ray]:139731353900128 got new rank 0\n", "\u001b[2m\u001b[36m(_RemoteRayXGBoostActor pid=1494008)\u001b[0m FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. 
Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(_RemoteRayXGBoostActor pid=1494009)\u001b[0m [17:29:07] task [xgboost.ray]:140076138558608 got new rank 1\n", "\u001b[2m\u001b[36m(_RemoteRayXGBoostActor pid=1494009)\u001b[0m FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(_QueueActor pid=1494006)\u001b[0m /home/ubuntu/ray/venv/lib/python3.8/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(_QueueActor pid=1494006)\u001b[0m from pandas import MultiIndex, Int64Index\n", "\u001b[2m\u001b[36m(_QueueActor pid=1494006)\u001b[0m FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(_QueueActor pid=1494006)\u001b[0m FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "\u001b[2m\u001b[36m(_QueueActor pid=1494006)\u001b[0m FutureWarning: pandas.UInt64Index is deprecated and will be removed from pandas in a future version. 
Use pandas.Index with the appropriate dtype instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Result for XGBoostTrainer_cc863_00000:\n", " date: 2022-06-22_17-29-09\n", " done: false\n", " experiment_id: dc3dac01a34043cfb5751907e2bc648e\n", " hostname: ip-172-31-43-110\n", " iterations_since_restore: 1\n", " node_ip: 172.31.43.110\n", " pid: 1493910\n", " should_checkpoint: true\n", " time_since_restore: 7.967940330505371\n", " time_this_iter_s: 7.967940330505371\n", " time_total_s: 7.967940330505371\n", " timestamp: 1655918949\n", " timesteps_since_restore: 0\n", " train-error: 0.017588\n", " train-logloss: 0.464648\n", " training_iteration: 1\n", " trial_id: cc863_00000\n", " valid-error: 0.081871\n", " valid-logloss: 0.496374\n", " warmup_time: 0.004768848419189453\n", " \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(XGBoostTrainer pid=1493910)\u001b[0m 2022-06-22 17:29:14,546\tINFO main.py:1516 -- [RayXGBoost] Finished XGBoost training on training data with total N=398 in 10.52 seconds (6.66 pure XGBoost training time).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Result for XGBoostTrainer_cc863_00000:\n", " date: 2022-06-22_17-29-14\n", " done: true\n", " experiment_id: dc3dac01a34043cfb5751907e2bc648e\n", " experiment_tag: '0'\n", " hostname: ip-172-31-43-110\n", " iterations_since_restore: 100\n", " node_ip: 172.31.43.110\n", " pid: 1493910\n", " should_checkpoint: true\n", " time_since_restore: 12.516392230987549\n", " time_this_iter_s: 0.03008890151977539\n", " time_total_s: 12.516392230987549\n", " timestamp: 1655918954\n", " timesteps_since_restore: 0\n", " train-error: 0.0\n", " train-logloss: 0.005874\n", " training_iteration: 100\n", " trial_id: cc863_00000\n", " valid-error: 0.040936\n", " valid-logloss: 0.078188\n", " warmup_time: 0.004768848419189453\n", " \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-06-22 17:29:15,362\tINFO tune.py:734 -- Total 
run time: 16.94 seconds (16.08 seconds for the tuning loop).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'train-logloss': 0.005874, 'train-error': 0.0, 'valid-logloss': 0.078188, 'valid-error': 0.040936, 'time_this_iter_s': 0.03008890151977539, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 100, 'trial_id': 'cc863_00000', 'experiment_id': 'dc3dac01a34043cfb5751907e2bc648e', 'date': '2022-06-22_17-29-14', 'timestamp': 1655918954, 'time_total_s': 12.516392230987549, 'pid': 1493910, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}, 'time_since_restore': 12.516392230987549, 'timesteps_since_restore': 0, 'iterations_since_restore': 100, 'warmup_time': 0.004768848419189453, 'experiment_tag': '0'}\n" ] } ], "source": [ "result = train_xgboost(num_workers=2, use_gpu=False)" ] }, { "cell_type": "markdown", "id": "7055ad1b", "metadata": {}, "source": [ "And perform inference on the obtained model:" ] }, { "cell_type": "code", "execution_count": 6, "id": "283b1dba", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-06-22 17:29:16,463\tWARNING read_api.py:260 -- The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n", "Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 46.14it/s]\n", "Map_Batches: 0%| | 0/1 [00:00