{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "(tune-lightgbm-example)=\n",
    "\n",
    "# Using LightGBM with Tune\n",
    "\n",
    "```{image} /images/lightgbm_logo.png\n",
    ":align: center\n",
    ":alt: LightGBM Logo\n",
    ":height: 120px\n",
    ":target: https://lightgbm.readthedocs.io\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "import lightgbm as lgb\n",
    "import numpy as np\n",
    "import sklearn.datasets\n",
    "import sklearn.metrics\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import ray\n",
    "from ray import tune\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "from ray.tune.integration.lightgbm import TuneReportCheckpointCallback\n",
    "\n",
    "\n",
    "def train_breast_cancer(config):\n",
    "    \"\"\"Train one LightGBM booster on the breast-cancer data and report to Tune.\n",
    "\n",
    "    The dataset is split 75/25 into train and held-out parts, a booster is\n",
    "    trained with ``config``, and the accuracy on the held-out part is\n",
    "    reported as ``mean_accuracy`` together with ``done=True``.\n",
    "\n",
    "    Args:\n",
    "        config: LightGBM parameter dict for this trial; it is forwarded to\n",
    "            ``lgb.train`` unchanged.\n",
    "    \"\"\"\n",
    "    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)\n",
    "    train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)\n",
    "    train_set = lgb.Dataset(train_x, label=train_y)\n",
    "    test_set = lgb.Dataset(test_x, label=test_y)\n",
    "    gbm = lgb.train(\n",
    "        config,\n",
    "        train_set,\n",
    "        valid_sets=[test_set],\n",
    "        valid_names=[\"eval\"],\n",
    "        # `verbose_eval` is intentionally not passed: LightGBM deprecated the\n",
    "        # keyword in 3.3 and removed it in 4.0, where it raises a TypeError.\n",
    "        callbacks=[\n",
    "            # Keys are the metric names handed to Tune (the driver below\n",
    "            # optimises \"binary_error\"); values are presumably LightGBM's\n",
    "            # \"<valid_name>-<metric>\" result names -- confirm against the\n",
    "            # callback's documentation.\n",
    "            TuneReportCheckpointCallback(\n",
    "                {\n",
    "                    \"binary_error\": \"eval-binary_error\",\n",
    "                    \"binary_logloss\": \"eval-binary_logloss\",\n",
    "                }\n",
    "            )\n",
    "        ],\n",
    "    )\n",
    "    preds = gbm.predict(test_x)\n",
    "    # Round each prediction to the nearest integer to obtain class labels.\n",
    "    pred_labels = np.rint(preds)\n",
    "    tune.report(\n",
    "        mean_accuracy=sklearn.metrics.accuracy_score(test_y, pred_labels), done=True\n",
    "    )\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\n",
    "        \"--server-address\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"The address of server to connect to if using Ray Client.\",\n",
    "    )\n",
    "    # parse_known_args() returns unrecognised arguments separately instead of\n",
    "    # exiting with an error, so extra launcher arguments do not break the run.\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    if args.server_address:\n",
    "        ray.init(f\"ray://{args.server_address}\")\n",
    "\n",
    "    # Fixed LightGBM settings plus the Tune search space.\n",
    "    config = {\n",
    "        \"objective\": \"binary\",\n",
    "        \"metric\": [\"binary_error\", \"binary_logloss\"],\n",
    "        \"verbose\": -1,\n",
    "        \"boosting_type\": tune.grid_search([\"gbdt\", \"dart\"]),\n",
    "        \"num_leaves\": tune.randint(10, 1000),\n",
    "        \"learning_rate\": tune.loguniform(1e-8, 1e-1),\n",
    "    }\n",
    "\n",
    "    analysis = tune.run(\n",
    "        train_breast_cancer,\n",
    "        metric=\"binary_error\",\n",
    "        mode=\"min\",\n",
    "        config=config,\n",
    "        num_samples=2,\n",
    "        scheduler=ASHAScheduler(),\n",
    "    )\n",
    "\n",
    "    print(\"Best hyperparameters found were: \", analysis.best_config)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}