ray/doc/source/tune/examples/lightgbm_example.ipynb
Max Pumperla 372c620f58
[docs] Tune overhaul part II (#22656)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
2022-02-26 23:07:34 -08:00

128 lines
No EOL
3.6 KiB
Text

{
"cells": [
{
"cell_type": "markdown",
"source": [
"(tune-lightgbm-example)=\n",
"\n",
"# Using LightGBM with Tune\n",
"\n",
"```{image} /images/lightgbm_logo.png\n",
":align: center\n",
":alt: LightGBM Logo\n",
":height: 120px\n",
":target: https://lightgbm.readthedocs.io\n",
"```\n",
"\n",
"```{contents}\n",
":backlinks: none\n",
":local: true\n",
"```\n",
"\n",
"## Example"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"import sklearn.datasets\n",
"import sklearn.metrics\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from ray import tune\n",
"from ray.tune.schedulers import ASHAScheduler\n",
"from ray.tune.integration.lightgbm import TuneReportCheckpointCallback\n",
"\n",
"\n",
"def train_breast_cancer(config):\n",
"    \"\"\"Train a LightGBM model on the breast-cancer dataset for one Tune trial.\n",
"\n",
"    Intermediate validation metrics are reported (and checkpoints saved) through\n",
"    ``TuneReportCheckpointCallback``; the final test accuracy is reported once\n",
"    training has finished.\n",
"\n",
"    Args:\n",
"        config: LightGBM training parameters for this trial; passed unchanged\n",
"            as the ``params`` argument of ``lgb.train``.\n",
"    \"\"\"\n",
"    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)\n",
"    # No random_state is passed, so every call draws a different 75/25 split.\n",
"    train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)\n",
"    train_set = lgb.Dataset(train_x, label=train_y)\n",
"    test_set = lgb.Dataset(test_x, label=test_y)\n",
"    # NOTE: `verbose_eval` is deliberately not passed. The keyword was deprecated\n",
"    # in LightGBM 3.3 and removed in 4.0, where passing it raises a TypeError.\n",
"    # Log verbosity is expected to come from the \"verbose\" entry of `config`.\n",
"    gbm = lgb.train(\n",
"        config,\n",
"        train_set,\n",
"        valid_sets=[test_set],\n",
"        valid_names=[\"eval\"],\n",
"        callbacks=[\n",
"            # Keys are the metric names reported to Tune; values name the\n",
"            # evaluation results of the validation set registered as \"eval\".\n",
"            TuneReportCheckpointCallback(\n",
"                {\n",
"                    \"binary_error\": \"eval-binary_error\",\n",
"                    \"binary_logloss\": \"eval-binary_logloss\",\n",
"                }\n",
"            )\n",
"        ],\n",
"    )\n",
"    preds = gbm.predict(test_x)\n",
"    # Round predictions to the nearest integer to obtain class labels\n",
"    # (assumes a binary objective, so predictions lie in [0, 1]).\n",
"    pred_labels = np.rint(preds)\n",
"    tune.report(\n",
"        mean_accuracy=sklearn.metrics.accuracy_score(test_y, pred_labels), done=True\n",
"    )\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    import argparse\n",
"\n",
"    # Optional Ray Client address. parse_known_args() returns unrecognised\n",
"    # command-line arguments separately instead of raising an error on them.\n",
"    arg_parser = argparse.ArgumentParser()\n",
"    arg_parser.add_argument(\n",
"        \"--server-address\",\n",
"        type=str,\n",
"        default=None,\n",
"        required=False,\n",
"        help=\"The address of server to connect to if using Ray Client.\",\n",
"    )\n",
"    cli_args, _ = arg_parser.parse_known_args()\n",
"\n",
"    # Connect through Ray Client only when an address was supplied.\n",
"    if cli_args.server_address:\n",
"        import ray\n",
"\n",
"        ray.init(f\"ray://{cli_args.server_address}\")\n",
"\n",
"    # Fixed LightGBM parameters plus the hyperparameters Tune searches over.\n",
"    search_space = {\n",
"        \"objective\": \"binary\",\n",
"        \"metric\": [\"binary_error\", \"binary_logloss\"],\n",
"        \"verbose\": -1,\n",
"        \"boosting_type\": tune.grid_search([\"gbdt\", \"dart\"]),\n",
"        \"num_leaves\": tune.randint(10, 1000),\n",
"        \"learning_rate\": tune.loguniform(1e-8, 1e-1),\n",
"    }\n",
"\n",
"    # Minimise the reported \"binary_error\" metric, using ASHA as the scheduler.\n",
"    scheduler = ASHAScheduler()\n",
"    analysis = tune.run(\n",
"        train_breast_cancer,\n",
"        config=search_space,\n",
"        metric=\"binary_error\",\n",
"        mode=\"min\",\n",
"        num_samples=2,\n",
"        scheduler=scheduler,\n",
"    )\n",
"\n",
"    print(\"Best hyperparameters found were: \", analysis.best_config)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"orphan": true
},
"nbformat": 4,
"nbformat_minor": 5
}