mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
128 lines
No EOL
3.6 KiB
Text
128 lines
No EOL
3.6 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"(tune-lightgbm-example)=\n",
|
|
"\n",
|
|
"# Using LightGBM with Tune\n",
|
|
"\n",
|
|
"```{image} /images/lightgbm_logo.png\n",
|
|
":align: center\n",
|
|
":alt: LightGBM Logo\n",
|
|
":height: 120px\n",
|
|
":target: https://lightgbm.readthedocs.io\n",
|
|
"```\n",
|
|
"\n",
|
|
"```{contents}\n",
|
|
":backlinks: none\n",
|
|
":local: true\n",
|
|
"```\n",
|
|
"\n",
|
|
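"## Example\n",
"\n",
"The script below tunes a LightGBM binary classifier on scikit-learn's breast cancer dataset. Each trial reports the validation `binary_error` and `binary_logloss` to Tune through `TuneReportCheckpointCallback`, and Tune searches for the configuration with the lowest `binary_error`. Pass `--server-address <host:port>` to run the search on a remote cluster via Ray Client."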
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"import lightgbm as lgb\n",
|
|
"import numpy as np\n",
|
|
"import sklearn.datasets\n",
|
|
"import sklearn.metrics\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"\n",
|
|
"from ray import tune\n",
|
|
"from ray.tune.schedulers import ASHAScheduler\n",
|
|
"from ray.tune.integration.lightgbm import TuneReportCheckpointCallback\n",
|
|
"\n",
|
|
"\n",
|
|
"def train_breast_cancer(config):\n",
"    \"\"\"Train and evaluate one LightGBM model as a single Tune trial.\n",
"\n",
"    Args:\n",
"        config: LightGBM parameters sampled by Tune for this trial; passed\n",
"            unchanged as the ``params`` argument of ``lgb.train``.\n",
"\n",
"    Metrics reach Tune in two ways: ``TuneReportCheckpointCallback`` forwards\n",
"    the ``eval`` set's binary error and log loss during training, and a final\n",
"    ``mean_accuracy`` on the held-out split is reported once training ends.\n",
"    \"\"\"\n",
"    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)\n",
"    # NOTE(review): no random_state, so each trial evaluates on a different split.\n",
"    train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)\n",
"    train_set = lgb.Dataset(train_x, label=train_y)\n",
"    test_set = lgb.Dataset(test_x, label=test_y)\n",
"    # `verbose_eval` was removed from `lgb.train` in LightGBM 4.0 (passing it\n",
"    # raises TypeError); console output is controlled by `verbose` in config.\n",
"    gbm = lgb.train(\n",
"        config,\n",
"        train_set,\n",
"        valid_sets=[test_set],\n",
"        valid_names=[\"eval\"],\n",
"        callbacks=[\n",
"            # Maps Tune metric names to LightGBM's \"<valid_name>-<metric>\" keys.\n",
"            TuneReportCheckpointCallback(\n",
"                {\n",
"                    \"binary_error\": \"eval-binary_error\",\n",
"                    \"binary_logloss\": \"eval-binary_logloss\",\n",
"                }\n",
"            )\n",
"        ],\n",
"    )\n",
"    # For the \"binary\" objective used in this example, predict() returns\n",
"    # positive-class probabilities; rounding turns them into 0/1 labels.\n",
"    preds = gbm.predict(test_x)\n",
"    pred_labels = np.rint(preds)\n",
"    tune.report(\n",
"        mean_accuracy=sklearn.metrics.accuracy_score(test_y, pred_labels), done=True\n",
"    )\n",
|
|
"\n",
|
|
"\n",
|
|
"if __name__ == \"__main__\":\n",
"    import argparse\n",
"\n",
"    # Optional Ray Client address; when omitted, ray.init() is not called here.\n",
"    cli = argparse.ArgumentParser()\n",
"    cli.add_argument(\n",
"        \"--server-address\",\n",
"        type=str,\n",
"        required=False,\n",
"        default=None,\n",
"        help=\"The address of server to connect to if using Ray Client.\",\n",
"    )\n",
"    # parse_known_args() returns unrecognized arguments instead of erroring,\n",
"    # so extra argv entries (e.g. when run inside a Jupyter kernel) are tolerated.\n",
"    known_args, _ = cli.parse_known_args()\n",
"\n",
"    server_address = known_args.server_address\n",
"    if server_address:\n",
"        import ray\n",
"\n",
"        # Connect to an existing cluster through Ray Client.\n",
"        ray.init(\"ray://\" + server_address)\n",
"\n",
"    # Search space: boosting_type is grid-searched, num_leaves and learning_rate\n",
"    # are sampled; the fixed keys are handed to LightGBM as-is.\n",
"    config = dict(\n",
"        objective=\"binary\",\n",
"        metric=[\"binary_error\", \"binary_logloss\"],\n",
"        verbose=-1,\n",
"        boosting_type=tune.grid_search([\"gbdt\", \"dart\"]),\n",
"        num_leaves=tune.randint(10, 1000),\n",
"        learning_rate=tune.loguniform(1e-8, 1e-1),\n",
"    )\n",
"\n",
"    # Tune minimizes the \"binary_error\" metric reported by the callback in\n",
"    # train_breast_cancer; the grid is repeated num_samples times.\n",
"    analysis = tune.run(\n",
"        train_breast_cancer,\n",
"        config=config,\n",
"        num_samples=2,\n",
"        metric=\"binary_error\",\n",
"        mode=\"min\",\n",
"        scheduler=ASHAScheduler(),\n",
"    )\n",
"\n",
"    print(\"Best hyperparameters found were: \", analysis.best_config)\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"orphan": true
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |