mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00

The package "ml" should be renamed to "air". Main question: Keep a `ml.py` with `from ray.air import *` for some level of backwards compatibility? I'd go for no to force people to use the new structure.
496 lines
27 KiB
Text
496 lines
27 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0d385409",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Training a model with distributed LightGBM\n",
|
|
"In this example, we will train a model in Ray AIR using distributed LightGBM."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "07d92cee",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's start by installing our dependencies:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "86131abe",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install -qU \"ray[tune]\" lightgbm_ray"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "135fc884",
|
|
"metadata": {},
|
|
"source": [
|
|
"Then we need some imports:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "102ef1ac",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import argparse\n",
|
|
"import math\n",
|
|
"from typing import Tuple\n",
|
|
"\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"import ray\n",
|
|
"from ray.air.batch_predictor import BatchPredictor\n",
|
|
"from ray.air.predictors.integrations.lightgbm import LightGBMPredictor\n",
|
|
"from ray.air.preprocessors.chain import Chain\n",
|
|
"from ray.air.preprocessors.encoder import Categorizer\n",
|
|
"from ray.air.train.integrations.lightgbm import LightGBMTrainer\n",
|
|
"from ray.data.dataset import Dataset\n",
|
|
"from ray.air.result import Result\n",
|
|
"from ray.air.preprocessors import StandardScaler\n",
|
|
"from sklearn.datasets import load_breast_cancer\n",
|
|
"from sklearn.model_selection import train_test_split"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c7d102bd",
|
|
"metadata": {},
|
|
"source": [
|
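|
"Next, we define a function to load our train, validation, and test datasets. Note that the validation and test datasets are both built from the same held-out 30% split; the test dataset is that split with the `target` column dropped."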
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "f1f35cd7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def prepare_data() -> Tuple[Dataset, Dataset, Dataset]:\n",
|
|
" data_raw = load_breast_cancer()\n",
|
|
" dataset_df = pd.DataFrame(data_raw[\"data\"], columns=data_raw[\"feature_names\"])\n",
|
|
" dataset_df[\"target\"] = data_raw[\"target\"]\n",
|
|
" # add a random categorical column\n",
|
|
" num_samples = len(dataset_df)\n",
|
|
" dataset_df[\"categorical_column\"] = pd.Series(\n",
|
|
" ([\"A\", \"B\"] * math.ceil(num_samples / 2))[:num_samples]\n",
|
|
" )\n",
|
|
" train_df, test_df = train_test_split(dataset_df, test_size=0.3)\n",
|
|
" train_dataset = ray.data.from_pandas(train_df)\n",
|
|
" valid_dataset = ray.data.from_pandas(test_df)\n",
|
|
" test_dataset = ray.data.from_pandas(test_df.drop(\"target\", axis=1))\n",
|
|
" return train_dataset, valid_dataset, test_dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8f7afbce",
|
|
"metadata": {},
|
|
"source": [
|
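|
"The following function will create a LightGBM trainer with a preprocessor chain (categorizing `categorical_column` and standard-scaling two feature columns), train it, and return the result."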
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "fefcbc8a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def train_lightgbm(num_workers: int, use_gpu: bool = False) -> Result:\n",
|
|
" train_dataset, valid_dataset, _ = prepare_data()\n",
|
|
"\n",
|
|
" # Scale some random columns, and categorify the categorical_column,\n",
|
|
" # allowing LightGBM to use its built-in categorical feature support\n",
|
|
" columns_to_scale = [\"mean radius\", \"mean texture\"]\n",
|
|
" preprocessor = Chain(\n",
|
|
" Categorizer([\"categorical_column\"]), StandardScaler(columns=columns_to_scale)\n",
|
|
" )\n",
|
|
"\n",
|
|
" # LightGBM specific params\n",
|
|
" params = {\n",
|
|
" \"objective\": \"binary\",\n",
|
|
" \"metric\": [\"binary_logloss\", \"binary_error\"],\n",
|
|
" }\n",
|
|
"\n",
|
|
" trainer = LightGBMTrainer(\n",
|
|
" scaling_config={\n",
|
|
" \"num_workers\": num_workers,\n",
|
|
" \"use_gpu\": use_gpu,\n",
|
|
" },\n",
|
|
" label_column=\"target\",\n",
|
|
" params=params,\n",
|
|
" datasets={\"train\": train_dataset, \"valid\": valid_dataset},\n",
|
|
" preprocessor=preprocessor,\n",
|
|
" num_boost_round=100,\n",
|
|
" )\n",
|
|
" result = trainer.fit()\n",
|
|
" print(result.metrics)\n",
|
|
"\n",
|
|
" return result"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "04d278ae",
|
|
"metadata": {},
|
|
"source": [
|
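|
"Once we have the result, we can do batch inference on the obtained model. Let's define a utility function for this. It prints the predicted labels (predictions thresholded at 0.5) and the per-feature contributions (SHAP values) obtained with `pred_contrib=True`."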
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "3f1d0c19",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def predict_lightgbm(result: Result):\n",
|
|
" _, _, test_dataset = prepare_data()\n",
|
|
" batch_predictor = BatchPredictor.from_checkpoint(\n",
|
|
" result.checkpoint, LightGBMPredictor\n",
|
|
" )\n",
|
|
"\n",
|
|
" predicted_labels = (\n",
|
|
" batch_predictor.predict(test_dataset)\n",
|
|
" .map_batches(lambda df: (df > 0.5).astype(int), batch_format=\"pandas\")\n",
|
|
" .to_pandas(limit=float(\"inf\"))\n",
|
|
" )\n",
|
|
" print(f\"PREDICTED LABELS\\n{predicted_labels}\")\n",
|
|
"\n",
|
|
" shap_values = batch_predictor.predict(test_dataset, pred_contrib=True).to_pandas(\n",
|
|
" limit=float(\"inf\")\n",
|
|
" )\n",
|
|
" print(f\"SHAP VALUES\\n{shap_values}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2bb0e5df",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now we can run the training:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "8244ff3c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-05-19 11:18:27,652\tINFO services.py:1483 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"== Status ==<br>Current time: 2022-05-19 11:18:47 (running for 00:00:15.19)<br>Memory usage on this node: 10.2/16.0 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.86 GiB heap, 0.0/2.0 GiB objects<br>Result logdir: /Users/kai/ray_results/LightGBMTrainer_2022-05-19_11-18-30<br>Number of trials: 1/1 (1 TERMINATED)<br><table>\n",
|
|
"<thead>\n",
|
|
"<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> iter</th><th style=\"text-align: right;\"> total time (s)</th><th style=\"text-align: right;\"> train-binary_logloss</th><th style=\"text-align: right;\"> train-binary_error</th><th style=\"text-align: right;\"> valid-binary_logloss</th></tr>\n",
|
|
"</thead>\n",
|
|
"<tbody>\n",
|
|
"<tr><td>LightGBMTrainer_07bf3_00000</td><td>TERMINATED</td><td>127.0.0.1:9219</td><td style=\"text-align: right;\"> 100</td><td style=\"text-align: right;\"> 10.4622</td><td style=\"text-align: right;\"> 0.000197893</td><td style=\"text-align: right;\"> 0</td><td style=\"text-align: right;\"> 0.289033</td></tr>\n",
|
|
"</tbody>\n",
|
|
"</table><br><br>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"UserWarning: cpus_per_actor is set to less than 2. Distributed LightGBM needs at least 2 CPUs per actor to train efficiently. This may lead to a degradation of performance during training.\n",
|
|
"\u001b[2m\u001b[33m(raylet)\u001b[0m 2022-05-19 11:18:32,940\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51840 --object-store-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=56443 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:58688 --redis-password=5241590000000000 --startup-token=16 --runtime-env-hash=-2010331134\n",
|
|
"\u001b[2m\u001b[33m(raylet)\u001b[0m 2022-05-19 11:18:36,664\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51840 --object-store-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=56443 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:58688 --redis-password=5241590000000000 --startup-token=17 --runtime-env-hash=-2010331069\n",
|
|
"\u001b[2m\u001b[36m(GBDTTrainable pid=9219)\u001b[0m UserWarning: Dataset 'train' has 1 blocks, which is less than the `num_workers` 2. This dataset will be automatically repartitioned to 2 blocks.\n",
|
|
"\u001b[2m\u001b[36m(GBDTTrainable pid=9219)\u001b[0m UserWarning: Dataset 'valid' has 1 blocks, which is less than the `num_workers` 2. This dataset will be automatically repartitioned to 2 blocks.\n",
|
|
"\u001b[2m\u001b[36m(GBDTTrainable pid=9219)\u001b[0m UserWarning: cpus_per_actor is set to less than 2. Distributed LightGBM needs at least 2 CPUs per actor to train efficiently. This may lead to a degradation of performance during training.\n",
|
|
"\u001b[2m\u001b[33m(raylet)\u001b[0m 2022-05-19 11:18:38,980\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51840 --object-store-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=56443 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:58688 --redis-password=5241590000000000 --startup-token=18 --runtime-env-hash=-2010331069\n",
|
|
"\u001b[2m\u001b[33m(raylet)\u001b[0m 2022-05-19 11:18:38,997\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51840 --object-store-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=56443 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:58688 --redis-password=5241590000000000 --startup-token=19 --runtime-env-hash=-2010331069\n",
|
|
"\u001b[2m\u001b[33m(raylet)\u001b[0m 2022-05-19 11:18:39,091\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51840 --object-store-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=56443 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:58688 --redis-password=5241590000000000 --startup-token=21 --runtime-env-hash=-2010331134\n",
|
|
"\u001b[2m\u001b[33m(raylet)\u001b[0m 2022-05-19 11:18:39,095\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51840 --object-store-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=56443 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:58688 --redis-password=5241590000000000 --startup-token=20 --runtime-env-hash=-2010331134\n",
|
|
"\u001b[2m\u001b[33m(raylet)\u001b[0m 2022-05-19 11:18:39,107\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51840 --object-store-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=56443 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:58688 --redis-password=5241590000000000 --startup-token=23 --runtime-env-hash=-2010331134\n",
|
|
"\u001b[2m\u001b[33m(raylet)\u001b[0m 2022-05-19 11:18:39,107\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=51840 --object-store-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-18-25_114449_9132/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=56443 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:58688 --redis-password=5241590000000000 --startup-token=22 --runtime-env-hash=-2010331134\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m [LightGBM] [Info] Trying to bind port 52127...\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m [LightGBM] [Info] Binding port 52127 succeeded\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m [LightGBM] [Info] Listening...\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m [LightGBM] [Warning] Connecting to rank 1 failed, waiting for 200 milliseconds\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9243)\u001b[0m [LightGBM] [Info] Trying to bind port 52128...\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9243)\u001b[0m [LightGBM] [Info] Binding port 52128 succeeded\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9243)\u001b[0m [LightGBM] [Info] Listening...\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m [LightGBM] [Info] Connected to rank 1\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m [LightGBM] [Info] Local rank: 0, total number of machines: 2\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m [LightGBM] [Warning] num_threads is set=1, n_jobs=-1 will be ignored. Current value: num_threads=1\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9243)\u001b[0m [LightGBM] [Info] Connected to rank 0\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9243)\u001b[0m [LightGBM] [Info] Local rank: 1, total number of machines: 2\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9243)\u001b[0m [LightGBM] [Warning] num_threads is set=1, n_jobs=-1 will be ignored. Current value: num_threads=1\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9243)\u001b[0m UserWarning: Overriding the parameters from Reference Dataset.\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9243)\u001b[0m UserWarning: categorical_column in param dict is overridden.\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m UserWarning: Overriding the parameters from Reference Dataset.\n",
|
|
"\u001b[2m\u001b[36m(_RemoteRayLightGBMActor pid=9242)\u001b[0m UserWarning: categorical_column in param dict is overridden.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Result for LightGBMTrainer_07bf3_00000:\n",
|
|
" date: 2022-05-19_11-18-44\n",
|
|
" done: false\n",
|
|
" experiment_id: 1d3640d1c3a743aeae7274a0ce253107\n",
|
|
" hostname: Kais-MacBook-Pro.local\n",
|
|
" iterations_since_restore: 1\n",
|
|
" node_ip: 127.0.0.1\n",
|
|
" pid: 9219\n",
|
|
" should_checkpoint: true\n",
|
|
" time_since_restore: 8.41084909439087\n",
|
|
" time_this_iter_s: 8.41084909439087\n",
|
|
" time_total_s: 8.41084909439087\n",
|
|
" timestamp: 1652955524\n",
|
|
" timesteps_since_restore: 0\n",
|
|
" train-binary_error: 0.36683417085427134\n",
|
|
" train-binary_logloss: 0.5804693664919086\n",
|
|
" training_iteration: 1\n",
|
|
" trial_id: 07bf3_00000\n",
|
|
" valid-binary_error: 0.36470588235294116\n",
|
|
" valid-binary_logloss: 0.5868466345817073\n",
|
|
" warmup_time: 0.004106044769287109\n",
|
|
" \n",
|
|
"Result for LightGBMTrainer_07bf3_00000:\n",
|
|
" date: 2022-05-19_11-18-46\n",
|
|
" done: true\n",
|
|
" experiment_id: 1d3640d1c3a743aeae7274a0ce253107\n",
|
|
" experiment_tag: '0'\n",
|
|
" hostname: Kais-MacBook-Pro.local\n",
|
|
" iterations_since_restore: 100\n",
|
|
" node_ip: 127.0.0.1\n",
|
|
" pid: 9219\n",
|
|
" should_checkpoint: true\n",
|
|
" time_since_restore: 10.46218204498291\n",
|
|
" time_this_iter_s: 0.025421857833862305\n",
|
|
" time_total_s: 10.46218204498291\n",
|
|
" timestamp: 1652955526\n",
|
|
" timesteps_since_restore: 0\n",
|
|
" train-binary_error: 0.0\n",
|
|
" train-binary_logloss: 0.00019789273681613937\n",
|
|
" training_iteration: 100\n",
|
|
" trial_id: 07bf3_00000\n",
|
|
" valid-binary_error: 0.058823529411764705\n",
|
|
" valid-binary_logloss: 0.2890328865004496\n",
|
|
" warmup_time: 0.004106044769287109\n",
|
|
" \n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-05-19 11:18:47,218\tINFO tune.py:753 -- Total run time: 16.87 seconds (15.17 seconds for the tuning loop).\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{'train-binary_logloss': 0.00019789273681613937, 'train-binary_error': 0.0, 'valid-binary_logloss': 0.2890328865004496, 'valid-binary_error': 0.058823529411764705, 'time_this_iter_s': 0.025421857833862305, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 100, 'trial_id': '07bf3_00000', 'experiment_id': '1d3640d1c3a743aeae7274a0ce253107', 'date': '2022-05-19_11-18-46', 'timestamp': 1652955526, 'time_total_s': 10.46218204498291, 'pid': 9219, 'hostname': 'Kais-MacBook-Pro.local', 'node_ip': '127.0.0.1', 'config': {}, 'time_since_restore': 10.46218204498291, 'timesteps_since_restore': 0, 'iterations_since_restore': 100, 'warmup_time': 0.004106044769287109, 'experiment_tag': '0'}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"result = train_lightgbm(num_workers=2, use_gpu=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d7155d9b",
|
|
"metadata": {},
|
|
"source": [
|
|
"And perform inference on the obtained model:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "871c9be6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00, 2.21s/it]\n",
|
|
"Map_Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 93.04it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"PREDICTED LABELS\n",
|
|
" predictions\n",
|
|
"0 1\n",
|
|
"1 1\n",
|
|
"2 1\n",
|
|
"3 1\n",
|
|
"4 1\n",
|
|
".. ...\n",
|
|
"166 1\n",
|
|
"167 0\n",
|
|
"168 1\n",
|
|
"169 1\n",
|
|
"170 1\n",
|
|
"\n",
|
|
"[171 rows x 1 columns]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00, 2.20s/it]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"SHAP VALUES\n",
|
|
" predictions_0 predictions_1 predictions_2 predictions_3 \\\n",
|
|
"0 0.089310 -0.909119 0.042819 0.002084 \n",
|
|
"1 0.080590 -0.961430 0.043946 -0.014364 \n",
|
|
"2 0.080606 -0.778903 0.067703 0.005561 \n",
|
|
"3 0.095129 -0.281614 0.083395 0.005946 \n",
|
|
"4 0.085822 0.470362 0.106340 -0.014601 \n",
|
|
".. ... ... ... ... \n",
|
|
"166 0.095845 0.217879 0.131259 0.045455 \n",
|
|
"167 -0.369657 -1.825973 -0.270361 -0.005203 \n",
|
|
"168 0.078703 0.142254 0.069414 0.002620 \n",
|
|
"169 0.069391 0.226548 0.035343 -0.014900 \n",
|
|
"170 0.088893 0.216221 0.068534 -0.015672 \n",
|
|
"\n",
|
|
" predictions_4 predictions_5 predictions_6 predictions_7 \\\n",
|
|
"0 -0.243459 0.536332 -1.275628 0.490998 \n",
|
|
"1 0.888434 -0.666081 0.541587 1.229128 \n",
|
|
"2 -0.444282 -0.886527 0.615167 1.113725 \n",
|
|
"3 0.292466 -0.205161 0.606795 0.971005 \n",
|
|
"4 0.549158 -0.405526 0.601660 0.438182 \n",
|
|
".. ... ... ... ... \n",
|
|
"166 -0.177095 0.418596 0.649067 0.853719 \n",
|
|
"167 0.248983 0.327855 0.028472 -0.396260 \n",
|
|
"168 0.398207 0.217577 0.655739 0.779261 \n",
|
|
"169 -0.395022 -1.123284 0.555980 0.367841 \n",
|
|
"170 0.478937 0.223053 -1.091246 0.506502 \n",
|
|
"\n",
|
|
" predictions_8 predictions_9 ... predictions_22 predictions_23 \\\n",
|
|
"0 0.245076 0.145549 ... 0.635400 1.961269 \n",
|
|
"1 -0.344460 0.013988 ... 0.676673 1.151546 \n",
|
|
"2 0.179933 0.109160 ... 0.689301 1.877165 \n",
|
|
"3 0.172837 0.195020 ... 0.718073 1.592259 \n",
|
|
"4 0.182619 -0.127179 ... 0.682158 1.121333 \n",
|
|
".. ... ... ... ... ... \n",
|
|
"166 0.243637 0.070938 ... 0.744151 1.771634 \n",
|
|
"167 0.124145 -0.493628 ... -0.541835 -2.201714 \n",
|
|
"168 0.179598 0.192144 ... 0.754423 1.300675 \n",
|
|
"169 -0.270785 0.072316 ... 0.531150 0.803271 \n",
|
|
"170 0.098758 0.163663 ... 0.545371 1.010722 \n",
|
|
"\n",
|
|
" predictions_24 predictions_25 predictions_26 predictions_27 \\\n",
|
|
"0 0.251278 0.089320 -0.598509 1.552746 \n",
|
|
"1 0.284339 0.053013 0.818982 1.595459 \n",
|
|
"2 0.564781 -0.043412 0.613467 1.649626 \n",
|
|
"3 0.532693 0.070353 0.544844 1.746157 \n",
|
|
"4 0.311470 0.115707 1.179717 1.670017 \n",
|
|
".. ... ... ... ... \n",
|
|
"166 -0.375872 0.127995 0.817884 1.686068 \n",
|
|
"167 0.625508 0.039757 -0.539761 0.908562 \n",
|
|
"168 0.398768 0.118736 0.736576 1.693371 \n",
|
|
"169 0.266529 -0.005190 0.850678 1.707049 \n",
|
|
"170 0.877337 0.080191 -0.588513 1.725231 \n",
|
|
"\n",
|
|
" predictions_28 predictions_29 predictions_30 predictions_31 \n",
|
|
"0 -0.025840 0.106845 -0.037509 0.636929 \n",
|
|
"1 0.035559 -0.119623 -0.016733 0.636929 \n",
|
|
"2 0.043654 0.070261 -0.012863 0.636929 \n",
|
|
"3 0.082365 0.098007 0.024749 0.636929 \n",
|
|
"4 0.079768 -0.008935 0.035418 0.636929 \n",
|
|
".. ... ... ... ... \n",
|
|
"166 0.043644 -0.044403 -0.038294 0.636929 \n",
|
|
"167 0.005756 -0.192442 -0.012280 0.636929 \n",
|
|
"168 0.054687 0.095692 -0.012863 0.636929 \n",
|
|
"169 0.086557 -0.104739 -0.016732 0.636929 \n",
|
|
"170 -0.319243 0.137260 0.023966 0.636929 \n",
|
|
"\n",
|
|
"[171 rows x 32 columns]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"predict_lightgbm(result)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"main_language": "python",
|
|
"notebook_metadata_filter": "-all"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|