{ "cells": [ { "cell_type": "markdown", "id": "c3192ac4", "metadata": {}, "source": [ "# Training a model with Sklearn\n", "In this example we will train a model in Ray AIR using a Sklearn classifier." ] }, { "cell_type": "markdown", "id": "5a4823bf", "metadata": {}, "source": [ "Let's start with installing our dependencies:" ] }, { "cell_type": "code", "execution_count": null, "id": "88f4bb39", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "!pip install -qU \"ray[tune]\" sklearn" ] }, { "cell_type": "markdown", "id": "c049c692", "metadata": {}, "source": [ "Then we need some imports:" ] }, { "cell_type": "code", "execution_count": 9, "id": "c02eb5cd", "metadata": {}, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "import ray\n", "from ray.data.dataset import Dataset\n", "from ray.train.batch_predictor import BatchPredictor\n", "from ray.train.sklearn import SklearnPredictor\n", "from ray.data.preprocessors import Chain, OrdinalEncoder, StandardScaler\n", "from ray.air.result import Result\n", "from ray.air.util.datasets import train_test_split\n", "from ray.train.sklearn import SklearnTrainer\n", "\n", "from sklearn.ensemble import RandomForestClassifier\n", "\n", "try:\n", " from cuml.ensemble import RandomForestClassifier as cuMLRandomForestClassifier\n", "except ImportError:\n", " cuMLRandomForestClassifier = None" ] }, { "cell_type": "markdown", "id": "52e017f1", "metadata": {}, "source": [ "Next we define a function to load our train, validation, and test datasets." ] }, { "cell_type": "code", "execution_count": 10, "id": "3631ed1e", "metadata": {}, "outputs": [], "source": [ "def prepare_data() -> Tuple[Dataset, Dataset, Dataset]:\n", " dataset = ray.data.read_csv(\"s3://air-example-data/breast_cancer_with_categorical.csv\")\n", " train_dataset, valid_dataset = train_test_split(dataset, test_size=0.3)\n", " test_dataset = valid_dataset.map_batches(lambda df: df.drop(\"target\", axis=1), batch_format=\"pandas\")\n", " return train_dataset, valid_dataset, test_dataset" ] }, { "cell_type": "markdown", "id": "8d6c6d17", "metadata": {}, "source": [ "The following function will create a Sklearn trainer, train it, and return the result." ] }, { "cell_type": "code", "execution_count": 11, "id": "0fd39e42", "metadata": {}, "outputs": [], "source": [ "def train_sklearn(num_cpus: int, use_gpu: bool = False) -> Result:\n", " if use_gpu and not cuMLRandomForestClassifier:\n", " raise RuntimeError(\"cuML must be installed for GPU enabled sklearn estimators.\")\n", "\n", " train_dataset, valid_dataset, _ = prepare_data()\n", "\n", " # Scale some random columns\n", " columns_to_scale = [\"mean radius\", \"mean texture\"]\n", " preprocessor = Chain(\n", " OrdinalEncoder([\"categorical_column\"]), StandardScaler(columns=columns_to_scale)\n", " )\n", "\n", " if use_gpu:\n", " trainer_resources = {\"CPU\": 1, \"GPU\": 1}\n", " estimator = cuMLRandomForestClassifier()\n", " else:\n", " trainer_resources = {\"CPU\": num_cpus}\n", " estimator = RandomForestClassifier()\n", "\n", " trainer = SklearnTrainer(\n", " estimator=estimator,\n", " label_column=\"target\",\n", " datasets={\"train\": train_dataset, \"valid\": valid_dataset},\n", " preprocessor=preprocessor,\n", " cv=5,\n", " scaling_config={\n", " \"trainer_resources\": trainer_resources,\n", " },\n", " )\n", " result = trainer.fit()\n", " print(result.metrics)\n", "\n", " return result" ] }, { "cell_type": "markdown", "id": "7a2efb9d", "metadata": {}, "source": [ "Once we have the result, we can do batch inference on the obtained model. Let's define a utility function for this." ] }, { "cell_type": "code", "execution_count": 12, "id": "59eeadd8", "metadata": {}, "outputs": [], "source": [ "def predict_sklearn(result: Result, use_gpu: bool = False):\n", " _, _, test_dataset = prepare_data()\n", "\n", " batch_predictor = BatchPredictor.from_checkpoint(\n", " result.checkpoint, SklearnPredictor\n", " )\n", "\n", " predicted_labels = (\n", " batch_predictor.predict(\n", " test_dataset,\n", " num_gpus_per_worker=int(use_gpu),\n", " )\n", " .map_batches(lambda df: (df > 0.5).astype(int), batch_format=\"pandas\")\n", " )\n", " print(f\"PREDICTED LABELS\")\n", " predicted_labels.show()" ] }, { "cell_type": "markdown", "id": "7d073994", "metadata": {}, "source": [ "Now we can run the training:" ] }, { "cell_type": "code", "execution_count": 13, "id": "43f9170a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-06-22 17:27:37,741\tINFO services.py:1477 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://\u001b[39m\u001b[22m\n", "2022-06-22 17:27:39,822\tWARNING read_api.py:260 -- The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n", "Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 44.05it/s]\n" ] }, { "data": { "text/html": [ "== Status ==
Current time: 2022-06-22 17:27:59 (running for 00:00:18.31)
Memory usage on this node: 10.7/31.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/12.9 GiB heap, 0.0/6.45 GiB objects
Result logdir: /home/ubuntu/ray_results/SklearnTrainer_2022-06-22_17-27-40
Number of trials: 1/1 (1 TERMINATED)
\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc iter total time (s) fit_time
SklearnTrainer_9dec8_00000TERMINATED172.31.43.110:1492629 1 15.6842 2.31571

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(SklearnTrainer pid=1492629)\u001b[0m 2022-06-22 17:27:45,647\tWARNING pool.py:591 -- The 'context' argument is not supported using ray. Please refer to the documentation for how to control ray initialization.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Result for SklearnTrainer_9dec8_00000:\n", " cv:\n", " fit_time:\n", " - 2.221003770828247\n", " - 2.215489387512207\n", " - 2.2075674533843994\n", " - 2.222351312637329\n", " - 2.312389612197876\n", " fit_time_mean: 2.235760307312012\n", " fit_time_std: 0.03866614559685742\n", " score_time:\n", " - 0.022464990615844727\n", " - 0.0230865478515625\n", " - 0.02564835548400879\n", " - 0.029137849807739258\n", " - 0.021221637725830078\n", " score_time_mean: 0.02431187629699707\n", " score_time_std: 0.0028120522003997595\n", " test_score:\n", " - 0.9625\n", " - 0.9125\n", " - 0.9875\n", " - 1.0\n", " - 0.9367088607594937\n", " test_score_mean: 0.9598417721518986\n", " test_score_std: 0.032128186960552516\n", " date: 2022-06-22_17-27-59\n", " done: false\n", " experiment_id: f8215019c10e4a81ba2187c38e875365\n", " fit_time: 2.3157050609588623\n", " hostname: ip-172-31-43-110\n", " iterations_since_restore: 1\n", " node_ip:\n", " pid: 1492629\n", " should_checkpoint: true\n", " time_since_restore: 15.684244871139526\n", " time_this_iter_s: 15.684244871139526\n", " time_total_s: 15.684244871139526\n", " timestamp: 1655918879\n", " timesteps_since_restore: 0\n", " training_iteration: 1\n", " trial_id: 9dec8_00000\n", " valid:\n", " score_time: 0.03549623489379883\n", " test_score: 0.9532163742690059\n", " warmup_time: 0.0057866573333740234\n", " \n", "Result for SklearnTrainer_9dec8_00000:\n", " cv:\n", " fit_time:\n", " - 2.221003770828247\n", " - 2.215489387512207\n", " - 2.2075674533843994\n", " - 2.222351312637329\n", " - 2.312389612197876\n", " fit_time_mean: 2.235760307312012\n", " fit_time_std: 0.03866614559685742\n", " score_time:\n", " - 0.022464990615844727\n", " - 0.0230865478515625\n", " - 0.02564835548400879\n", " - 0.029137849807739258\n", " - 0.021221637725830078\n", " score_time_mean: 0.02431187629699707\n", " score_time_std: 0.0028120522003997595\n", " test_score:\n", " - 0.9625\n", " - 0.9125\n", " - 0.9875\n", " - 1.0\n", " - 0.9367088607594937\n", " test_score_mean: 0.9598417721518986\n", " test_score_std: 0.032128186960552516\n", " date: 2022-06-22_17-27-59\n", " done: true\n", " experiment_id: f8215019c10e4a81ba2187c38e875365\n", " experiment_tag: '0'\n", " fit_time: 2.3157050609588623\n", " hostname: ip-172-31-43-110\n", " iterations_since_restore: 1\n", " node_ip:\n", " pid: 1492629\n", " should_checkpoint: true\n", " time_since_restore: 15.684244871139526\n", " time_this_iter_s: 15.684244871139526\n", " time_total_s: 15.684244871139526\n", " timestamp: 1655918879\n", " timesteps_since_restore: 0\n", " training_iteration: 1\n", " trial_id: 9dec8_00000\n", " valid:\n", " score_time: 0.03549623489379883\n", " test_score: 0.9532163742690059\n", " warmup_time: 0.0057866573333740234\n", " \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-06-22 17:27:59,333\tINFO tune.py:734 -- Total run time: 19.09 seconds (18.31 seconds for the tuning loop).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'valid': {'score_time': 0.03549623489379883, 'test_score': 0.9532163742690059}, 'cv': {'fit_time': array([2.22100377, 2.21548939, 2.20756745, 2.22235131, 2.31238961]), 'score_time': array([0.02246499, 0.02308655, 0.02564836, 0.02913785, 0.02122164]), 'test_score': array([0.9625 , 0.9125 , 0.9875 , 1. , 0.93670886]), 'fit_time_mean': 2.235760307312012, 'fit_time_std': 0.03866614559685742, 'score_time_mean': 0.02431187629699707, 'score_time_std': 0.0028120522003997595, 'test_score_mean': 0.9598417721518986, 'test_score_std': 0.032128186960552516}, 'fit_time': 2.3157050609588623, 'time_this_iter_s': 15.684244871139526, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 1, 'trial_id': '9dec8_00000', 'experiment_id': 'f8215019c10e4a81ba2187c38e875365', 'date': '2022-06-22_17-27-59', 'timestamp': 1655918879, 'time_total_s': 15.684244871139526, 'pid': 1492629, 'hostname': 'ip-172-31-43-110', 'node_ip': '', 'config': {}, 'time_since_restore': 15.684244871139526, 'timesteps_since_restore': 0, 'iterations_since_restore': 1, 'warmup_time': 0.0057866573333740234, 'experiment_tag': '0'}\n" ] } ], "source": [ "result = train_sklearn(num_cpus=2, use_gpu=False)" ] }, { "cell_type": "markdown", "id": "0daba603", "metadata": {}, "source": [ "And perform inference on the obtained model:" ] }, { "cell_type": "code", "execution_count": 14, "id": "24b16ede", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-06-22 17:27:59,658\tWARNING read_api.py:260 -- The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n", "Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 64.73it/s]\n", "Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:01<00:00, 1.60s/it]\n", "Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 71.41it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "PREDICTED LABELS\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 0}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 0}\n", "{'predictions': 1}\n", "{'predictions': 0}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 0}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 1}\n", "{'predictions': 0}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "predict_sklearn(result, use_gpu=False)" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" }, "kernelspec": { "display_name": "Python 3.8.10 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "vscode": { "interpreter": { "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f" } } }, "nbformat": 4, "nbformat_minor": 5 }