{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "12ada6c3",
   "metadata": {},
   "source": [
    "(tune-lightgbm-example)=\n",
    "\n",
    "# Using LightGBM with Tune\n",
    "\n",
    "```{image} /images/lightgbm_logo.png\n",
    ":align: center\n",
    ":alt: LightGBM Logo\n",
    ":height: 120px\n",
    ":target: https://lightgbm.readthedocs.io\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b4c3f1e1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-07-22 15:30:02,623\tINFO services.py:1483 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n",
      "2022-07-22 15:30:05,042\tWARNING function_trainable.py:619 -- Function checkpointing is disabled. This may result in unexpected behavior when using checkpointing features or certain schedulers. To enable, set the train function arguments to be `func(config, checkpoint_dir=None)`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "== Status ==<br>Current time: 2022-07-22 15:30:18 (running for 00:00:12.88)<br>Memory usage on this node: 10.1/16.0 GiB<br>Using AsyncHyperBand: num_stopped=4\n",
       "Bracket: Iter 64.000: -0.32867132867132864 | Iter 16.000: -0.32867132867132864 | Iter 4.000: -0.32867132867132864 | Iter 1.000: -0.35664335664335667<br>Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/5.3 GiB heap, 0.0/2.0 GiB objects<br>Current best trial: c7534_00003 with binary_error=0.3146853146853147 and parameters={'objective': 'binary', 'metric': ['binary_error', 'binary_logloss'], 'verbose': -1, 'boosting_type': 'dart', 'num_leaves': 702, 'learning_rate': 4.858514533326432e-08}<br>Result logdir: /Users/kai/ray_results/train_breast_cancer_2022-07-22_15-29-59<br>Number of trials: 4/4 (4 TERMINATED)<br><table>\n",
       "<thead>\n",
       "<tr><th>Trial name                     </th><th>status    </th><th>loc            </th><th>boosting_type  </th><th style=\"text-align: right;\">  learning_rate</th><th style=\"text-align: right;\">  num_leaves</th><th style=\"text-align: right;\">  iter</th><th style=\"text-align: right;\">  total time (s)</th><th style=\"text-align: right;\">  binary_error</th><th style=\"text-align: right;\">  binary_logloss</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>train_breast_cancer_c7534_00000</td><td>TERMINATED</td><td>127.0.0.1:46947</td><td>gbdt           </td><td style=\"text-align: right;\">    1.09528e-08</td><td style=\"text-align: right;\">         926</td><td style=\"text-align: right;\">   100</td><td style=\"text-align: right;\">       4.04621  </td><td style=\"text-align: right;\">      0.370629</td><td style=\"text-align: right;\">        0.659303</td></tr>\n",
       "<tr><td>train_breast_cancer_c7534_00001</td><td>TERMINATED</td><td>127.0.0.1:46965</td><td>dart           </td><td style=\"text-align: right;\">    9.07058e-05</td><td style=\"text-align: right;\">         512</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">       0.0379331</td><td style=\"text-align: right;\">      0.391608</td><td style=\"text-align: right;\">        0.670769</td></tr>\n",
       "<tr><td>train_breast_cancer_c7534_00002</td><td>TERMINATED</td><td>127.0.0.1:46987</td><td>gbdt           </td><td style=\"text-align: right;\">    0.00110605 </td><td style=\"text-align: right;\">         186</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">       0.0196211</td><td style=\"text-align: right;\">      0.405594</td><td style=\"text-align: right;\">        0.678443</td></tr>\n",
       "<tr><td>train_breast_cancer_c7534_00003</td><td>TERMINATED</td><td>127.0.0.1:46988</td><td>dart           </td><td style=\"text-align: right;\">    4.85851e-08</td><td style=\"text-align: right;\">         702</td><td style=\"text-align: right;\">   100</td><td style=\"text-align: right;\">       0.417179 </td><td style=\"text-align: right;\">      0.314685</td><td style=\"text-align: right;\">        0.655626</td></tr>\n",
       "</tbody>\n",
       "</table><br><br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-07-22 15:30:06,224\tINFO plugin_schema_manager.py:52 -- Loading the default runtime env schemas: ['/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/working_dir_schema.json', '/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/pip_schema.json'].\n",
      "\u001b[2m\u001b[36m(train_breast_cancer pid=46947)\u001b[0m /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n",
      "\u001b[2m\u001b[36m(train_breast_cancer pid=46947)\u001b[0m   _log_warning(\"'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result for train_breast_cancer_c7534_00000:\n",
      "  binary_error: 0.3706293706293706\n",
      "  binary_logloss: 0.6593043583564255\n",
      "  date: 2022-07-22_15-30-11\n",
      "  done: false\n",
      "  experiment_id: 9fbbf2cd94b24a14aa5ef2d552e78b70\n",
      "  hostname: Kais-MacBook-Pro.local\n",
      "  iterations_since_restore: 1\n",
      "  node_ip: 127.0.0.1\n",
      "  pid: 46947\n",
      "  time_since_restore: 0.10576009750366211\n",
      "  time_this_iter_s: 0.10576009750366211\n",
      "  time_total_s: 0.10576009750366211\n",
      "  timestamp: 1658500211\n",
      "  timesteps_since_restore: 0\n",
      "  training_iteration: 1\n",
      "  trial_id: c7534_00000\n",
      "  warmup_time: 0.0033888816833496094\n",
      "  \n",
      "Result for train_breast_cancer_c7534_00001:\n",
      "  binary_error: 0.3916083916083916\n",
      "  binary_logloss: 0.670769405026208\n",
      "  date: 2022-07-22_15-30-14\n",
      "  done: true\n",
      "  experiment_id: 10df796f3d2e4627ba7526014b21f426\n",
      "  hostname: Kais-MacBook-Pro.local\n",
      "  iterations_since_restore: 1\n",
      "  node_ip: 127.0.0.1\n",
      "  pid: 46965\n",
      "  time_since_restore: 0.0379331111907959\n",
      "  time_this_iter_s: 0.0379331111907959\n",
      "  time_total_s: 0.0379331111907959\n",
      "  timestamp: 1658500214\n",
      "  timesteps_since_restore: 0\n",
      "  training_iteration: 1\n",
      "  trial_id: c7534_00001\n",
      "  warmup_time: 0.0033578872680664062\n",
      "  \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(train_breast_cancer pid=46965)\u001b[0m /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n",
      "\u001b[2m\u001b[36m(train_breast_cancer pid=46965)\u001b[0m   _log_warning(\"'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result for train_breast_cancer_c7534_00000:\n",
      "  binary_error: 0.3706293706293706\n",
      "  binary_logloss: 0.6593034612409915\n",
      "  date: 2022-07-22_15-30-15\n",
      "  done: true\n",
      "  experiment_id: 9fbbf2cd94b24a14aa5ef2d552e78b70\n",
      "  hostname: Kais-MacBook-Pro.local\n",
      "  iterations_since_restore: 100\n",
      "  node_ip: 127.0.0.1\n",
      "  pid: 46947\n",
      "  time_since_restore: 4.046205043792725\n",
      "  time_this_iter_s: 0.002338886260986328\n",
      "  time_total_s: 4.046205043792725\n",
      "  timestamp: 1658500215\n",
      "  timesteps_since_restore: 0\n",
      "  training_iteration: 100\n",
      "  trial_id: c7534_00000\n",
      "  warmup_time: 0.0033888816833496094\n",
      "  \n",
      "Result for train_breast_cancer_c7534_00003:\n",
      "  binary_error: 0.3146853146853147\n",
      "  binary_logloss: 0.635705942279978\n",
      "  date: 2022-07-22_15-30-18\n",
      "  done: false\n",
      "  experiment_id: d370b87343ea4a8e994bcf99a4f6f28d\n",
      "  hostname: Kais-MacBook-Pro.local\n",
      "  iterations_since_restore: 1\n",
      "  node_ip: 127.0.0.1\n",
      "  pid: 46988\n",
      "  time_since_restore: 0.04007911682128906\n",
      "  time_this_iter_s: 0.04007911682128906\n",
      "  time_total_s: 0.04007911682128906\n",
      "  timestamp: 1658500218\n",
      "  timesteps_since_restore: 0\n",
      "  training_iteration: 1\n",
      "  trial_id: c7534_00003\n",
      "  warmup_time: 0.0032351016998291016\n",
      "  \n",
      "Result for train_breast_cancer_c7534_00002:\n",
      "  binary_error: 0.40559440559440557\n",
      "  binary_logloss: 0.6784426899984863\n",
      "  date: 2022-07-22_15-30-18\n",
      "  done: true\n",
      "  experiment_id: 96e95ab236aa40aea3e9a1218293b562\n",
      "  hostname: Kais-MacBook-Pro.local\n",
      "  iterations_since_restore: 1\n",
      "  node_ip: 127.0.0.1\n",
      "  pid: 46987\n",
      "  time_since_restore: 0.01962113380432129\n",
      "  time_this_iter_s: 0.01962113380432129\n",
      "  time_total_s: 0.01962113380432129\n",
      "  timestamp: 1658500218\n",
      "  timesteps_since_restore: 0\n",
      "  training_iteration: 1\n",
      "  trial_id: c7534_00002\n",
      "  warmup_time: 0.0026988983154296875\n",
      "  \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(train_breast_cancer pid=46987)\u001b[0m /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n",
      "\u001b[2m\u001b[36m(train_breast_cancer pid=46987)\u001b[0m   _log_warning(\"'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. \"\n",
      "\u001b[2m\u001b[36m(train_breast_cancer pid=46988)\u001b[0m /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n",
      "\u001b[2m\u001b[36m(train_breast_cancer pid=46988)\u001b[0m   _log_warning(\"'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result for train_breast_cancer_c7534_00003:\n",
      "  binary_error: 0.3146853146853147\n",
      "  binary_logloss: 0.6556262981958247\n",
      "  date: 2022-07-22_15-30-18\n",
      "  done: true\n",
      "  experiment_id: d370b87343ea4a8e994bcf99a4f6f28d\n",
      "  hostname: Kais-MacBook-Pro.local\n",
      "  iterations_since_restore: 100\n",
      "  node_ip: 127.0.0.1\n",
      "  pid: 46988\n",
      "  time_since_restore: 0.4171791076660156\n",
      "  time_this_iter_s: 0.0024061203002929688\n",
      "  time_total_s: 0.4171791076660156\n",
      "  timestamp: 1658500218\n",
      "  timesteps_since_restore: 0\n",
      "  training_iteration: 100\n",
      "  trial_id: c7534_00003\n",
      "  warmup_time: 0.0032351016998291016\n",
      "  \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-07-22 15:30:18,873\tINFO tune.py:738 -- Total run time: 13.83 seconds (12.87 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best hyperparameters found were:  {'objective': 'binary', 'metric': ['binary_error', 'binary_logloss'], 'verbose': -1, 'boosting_type': 'dart', 'num_leaves': 702, 'learning_rate': 4.858514533326432e-08}\n"
     ]
    }
   ],
   "source": [
    "import lightgbm as lgb\n",
    "import numpy as np\n",
    "import sklearn.datasets\n",
    "import sklearn.metrics\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from ray import tune\n",
    "from ray.air import session\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "from ray.tune.integration.lightgbm import TuneReportCheckpointCallback\n",
    "\n",
    "\n",
    "def train_breast_cancer(config):\n",
    "    \"\"\"Train LightGBM on the breast-cancer dataset and report metrics to Tune.\n",
    "\n",
    "    Args:\n",
    "        config: Parameter dict handed unchanged to ``lgb.train``.\n",
    "\n",
    "    During training, ``TuneReportCheckpointCallback`` reports the validation\n",
    "    metrics as ``binary_error`` and ``binary_logloss``. After training\n",
    "    returns, the accuracy on the held-out split is reported as\n",
    "    ``mean_accuracy`` together with ``done=True``.\n",
    "    \"\"\"\n",
    "    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)\n",
    "    # Hold out 25% of the samples; they serve as the validation set below.\n",
    "    train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)\n",
    "    train_set = lgb.Dataset(train_x, label=train_y)\n",
    "    test_set = lgb.Dataset(test_x, label=test_y)\n",
    "    # The deprecated ``verbose_eval`` argument is deliberately not passed:\n",
    "    # LightGBM warns about it and newer releases reject it. No\n",
    "    # ``log_evaluation`` callback is registered, so evaluation output\n",
    "    # stays silent.\n",
    "    gbm = lgb.train(\n",
    "        config,\n",
    "        train_set,\n",
    "        valid_sets=[test_set],\n",
    "        valid_names=[\"eval\"],\n",
    "        callbacks=[\n",
    "            # Maps Tune metric names to LightGBM's \"<valid_name>-<metric>\" keys.\n",
    "            TuneReportCheckpointCallback(\n",
    "                {\n",
    "                    \"binary_error\": \"eval-binary_error\",\n",
    "                    \"binary_logloss\": \"eval-binary_logloss\",\n",
    "                }\n",
    "            )\n",
    "        ],\n",
    "    )\n",
    "    preds = gbm.predict(test_x)\n",
    "    # Round the predictions to the nearest integer to obtain class labels.\n",
    "    pred_labels = np.rint(preds)\n",
    "    session.report({\n",
    "        \"mean_accuracy\": sklearn.metrics.accuracy_score(test_y, pred_labels), \"done\": True\n",
    "    })\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import argparse\n",
    "\n",
    "    # Optional Ray Client address. parse_known_args() is used so that\n",
    "    # unrecognized command-line arguments are ignored instead of raising.\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\n",
    "        \"--server-address\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"The address of server to connect to if using Ray Client.\",\n",
    "    )\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    # Connect through Ray Client only when an address was supplied.\n",
    "    if args.server_address:\n",
    "        import ray\n",
    "\n",
    "        ray.init(f\"ray://{args.server_address}\")\n",
    "\n",
    "    # Fixed LightGBM settings plus the hyperparameters Tune searches over.\n",
    "    config = {\n",
    "        \"objective\": \"binary\",\n",
    "        \"metric\": [\"binary_error\", \"binary_logloss\"],\n",
    "        \"verbose\": -1,\n",
    "        \"boosting_type\": tune.grid_search([\"gbdt\", \"dart\"]),\n",
    "        \"num_leaves\": tune.randint(10, 1000),\n",
    "        \"learning_rate\": tune.loguniform(1e-8, 1e-1),\n",
    "    }\n",
    "\n",
    "    # Minimize the reported validation error, using ASHA as the scheduler.\n",
    "    tune_config = tune.TuneConfig(\n",
    "        metric=\"binary_error\",\n",
    "        mode=\"min\",\n",
    "        scheduler=ASHAScheduler(),\n",
    "        num_samples=2,\n",
    "    )\n",
    "    tuner = tune.Tuner(\n",
    "        train_breast_cancer,\n",
    "        tune_config=tune_config,\n",
    "        param_space=config,\n",
    "    )\n",
    "    results = tuner.fit()\n",
    "\n",
    "    best_result = results.get_best_result()\n",
    "    print(\"Best hyperparameters found were: \", best_result.config)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}