mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00

The package "ml" should be renamed to "air". Main question: Keep a `ml.py` with `from ray.air import *` for some level of backwards compatibility? I'd go for no to force people to use the new structure.
513 lines
No EOL
25 KiB
Text
513 lines
No EOL
25 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5fb89b3d",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"# Training a model with distributed XGBoost\n",
|
|
"In this example we will train a model in Ray AIR using distributed XGBoost."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "53d57c1f",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"Let's start by installing our dependencies:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "41f20cc1",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install -qU \"ray[tune]\" xgboost_ray"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d2fe8d4a",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"Then we need some imports:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "7232303d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import argparse\n",
|
|
"from typing import Tuple\n",
|
|
"\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"import ray\n",
|
|
"from ray.air.batch_predictor import BatchPredictor\n",
|
|
"from ray.air.predictors.integrations.xgboost import XGBoostPredictor\n",
|
|
"from ray.air.train.integrations.xgboost import XGBoostTrainer\n",
|
|
"from ray.data.dataset import Dataset\n",
|
|
"from ray.air.result import Result\n",
|
|
"from ray.air.preprocessors import StandardScaler\n",
|
|
"from sklearn.datasets import load_breast_cancer\n",
|
|
"from sklearn.model_selection import train_test_split"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1c75b5ca",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
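|
"Next we define a function to load our train, validation, and test datasets. The validation and test datasets are both built from the same held-out 30% split; the test dataset simply has the `target` label column removed."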
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "37c4f38f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def prepare_data() -> Tuple[Dataset, Dataset, Dataset]:\n",
|
|
" data_raw = load_breast_cancer()\n",
|
|
" dataset_df = pd.DataFrame(data_raw[\"data\"], columns=data_raw[\"feature_names\"])\n",
|
|
" dataset_df[\"target\"] = data_raw[\"target\"]\n",
|
|
" train_df, test_df = train_test_split(dataset_df, test_size=0.3)\n",
|
|
" train_dataset = ray.data.from_pandas(train_df)\n",
|
|
" valid_dataset = ray.data.from_pandas(test_df)\n",
|
|
" test_dataset = ray.data.from_pandas(test_df.drop(\"target\", axis=1))\n",
|
|
" return train_dataset, valid_dataset, test_dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9b2850dd",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"The following function will create an XGBoost trainer, train it, and return the result."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "dae8998d",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def train_xgboost(num_workers: int, use_gpu: bool = False) -> Result:\n",
|
|
" train_dataset, valid_dataset, _ = prepare_data()\n",
|
|
"\n",
|
|
" # Scale some random columns\n",
|
|
" columns_to_scale = [\"mean radius\", \"mean texture\"]\n",
|
|
" preprocessor = StandardScaler(columns=columns_to_scale)\n",
|
|
"\n",
|
|
" # XGBoost specific params\n",
|
|
" params = {\n",
|
|
" \"tree_method\": \"approx\",\n",
|
|
" \"objective\": \"binary:logistic\",\n",
|
|
" \"eval_metric\": [\"logloss\", \"error\"],\n",
|
|
" }\n",
|
|
"\n",
|
|
" trainer = XGBoostTrainer(\n",
|
|
" scaling_config={\n",
|
|
" \"num_workers\": num_workers,\n",
|
|
" \"use_gpu\": use_gpu,\n",
|
|
" },\n",
|
|
" label_column=\"target\",\n",
|
|
" params=params,\n",
|
|
" datasets={\"train\": train_dataset, \"valid\": valid_dataset},\n",
|
|
" preprocessor=preprocessor,\n",
|
|
" num_boost_round=100,\n",
|
|
" )\n",
|
|
" result = trainer.fit()\n",
|
|
" print(result.metrics)\n",
|
|
"\n",
|
|
" return result"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ce05af87",
|
|
"metadata": {},
|
|
"source": [
|
|
"Once we have the result, we can do batch inference on the obtained model. Let's define a utility function for this."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "5b8076d3",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def predict_xgboost(result: Result):\n",
|
|
" _, _, test_dataset = prepare_data()\n",
|
|
"\n",
|
|
" batch_predictor = BatchPredictor.from_checkpoint(\n",
|
|
" result.checkpoint, XGBoostPredictor\n",
|
|
" )\n",
|
|
"\n",
|
|
" predicted_labels = (\n",
|
|
" batch_predictor.predict(test_dataset)\n",
|
|
" .map_batches(lambda df: (df > 0.5).astype(int), batch_format=\"pandas\")\n",
|
|
" .to_pandas(limit=float(\"inf\"))\n",
|
|
" )\n",
|
|
" print(f\"PREDICTED LABELS\\n{predicted_labels}\")\n",
|
|
"\n",
|
|
" shap_values = batch_predictor.predict(test_dataset, pred_contribs=True).to_pandas(\n",
|
|
" limit=float(\"inf\")\n",
|
|
" )\n",
|
|
" print(f\"SHAP VALUES\\n{shap_values}\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7e172f66",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now we can run the training:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "0f96d62b",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-05-19 11:44:42,413\tINFO services.py:1483 -- View the Ray dashboard at \u001B[1m\u001B[32mhttp://127.0.0.1:8265\u001B[39m\u001B[22m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"== Status ==<br>Current time: 2022-05-19 11:45:00 (running for 00:00:13.93)<br>Memory usage on this node: 10.3/16.0 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.5 GiB heap, 0.0/2.0 GiB objects<br>Result logdir: /Users/kai/ray_results/XGBoostTrainer_2022-05-19_11-44-45<br>Number of trials: 1/1 (1 TERMINATED)<br><table>\n",
|
|
"<thead>\n",
|
|
"<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> iter</th><th style=\"text-align: right;\"> total time (s)</th><th style=\"text-align: right;\"> train-logloss</th><th style=\"text-align: right;\"> train-error</th><th style=\"text-align: right;\"> valid-logloss</th></tr>\n",
|
|
"</thead>\n",
|
|
"<tbody>\n",
|
|
"<tr><td>XGBoostTrainer_b273b_00000</td><td>TERMINATED</td><td>127.0.0.1:11036</td><td style=\"text-align: right;\"> 100</td><td style=\"text-align: right;\"> 9.03935</td><td style=\"text-align: right;\"> 0.005949</td><td style=\"text-align: right;\"> 0</td><td style=\"text-align: right;\"> 0.07483</td></tr>\n",
|
|
"</tbody>\n",
|
|
"</table><br><br>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001B[2m\u001B[33m(raylet)\u001B[0m 2022-05-19 11:44:47,554\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=54067 --object-store-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=61242 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:61017 --redis-password=5241590000000000 --startup-token=16 --runtime-env-hash=-2010331134\n",
|
|
"\u001B[2m\u001B[33m(raylet)\u001B[0m 2022-05-19 11:44:51,603\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=54067 --object-store-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=61242 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:61017 --redis-password=5241590000000000 --startup-token=17 --runtime-env-hash=-2010331069\n",
|
|
"\u001B[2m\u001B[36m(GBDTTrainable pid=11036)\u001B[0m UserWarning: Dataset 'train' has 1 blocks, which is less than the `num_workers` 2. This dataset will be automatically repartitioned to 2 blocks.\n",
|
|
"\u001B[2m\u001B[36m(GBDTTrainable pid=11036)\u001B[0m UserWarning: Dataset 'valid' has 1 blocks, which is less than the `num_workers` 2. This dataset will be automatically repartitioned to 2 blocks.\n",
|
|
"\u001B[2m\u001B[36m(GBDTTrainable pid=11036)\u001B[0m 2022-05-19 11:44:53,035\tINFO main.py:980 -- [RayXGBoost] Created 2 new actors (2 total actors). Waiting until actors are ready for training.\n",
|
|
"\u001B[2m\u001B[33m(raylet)\u001B[0m 2022-05-19 11:44:54,085\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=54067 --object-store-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=61242 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:61017 --redis-password=5241590000000000 --startup-token=18 --runtime-env-hash=-2010331069\n",
|
|
"\u001B[2m\u001B[33m(raylet)\u001B[0m 2022-05-19 11:44:54,106\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=54067 --object-store-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=61242 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:61017 --redis-password=5241590000000000 --startup-token=19 --runtime-env-hash=-2010331069\n",
|
|
"\u001B[2m\u001B[33m(raylet)\u001B[0m 2022-05-19 11:44:54,252\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=54067 --object-store-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=61242 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:61017 --redis-password=5241590000000000 --startup-token=21 --runtime-env-hash=-2010331134\n",
|
|
"\u001B[2m\u001B[33m(raylet)\u001B[0m 2022-05-19 11:44:54,266\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=54067 --object-store-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=61242 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:61017 --redis-password=5241590000000000 --startup-token=23 --runtime-env-hash=-2010331134\n",
|
|
"\u001B[2m\u001B[33m(raylet)\u001B[0m 2022-05-19 11:44:54,266\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=54067 --object-store-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=61242 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:61017 --redis-password=5241590000000000 --startup-token=20 --runtime-env-hash=-2010331134\n",
|
|
"\u001B[2m\u001B[33m(raylet)\u001B[0m 2022-05-19 11:44:54,271\tINFO context.py:70 -- Exec'ing worker with command: exec /Users/kai/.pyenv/versions/3.7.7/bin/python3.7 /Users/kai/coding/ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1 --node-manager-port=54067 --object-store-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/plasma_store --raylet-name=/tmp/ray/session_2022-05-19_11-44-39_813259_10959/sockets/raylet --redis-address=None --storage=None --temp-dir=/tmp/ray --metrics-agent-port=61242 --logging-rotate-bytes=536870912 --logging-rotate-backup-count=5 --gcs-address=127.0.0.1:61017 --redis-password=5241590000000000 --startup-token=22 --runtime-env-hash=-2010331134\n",
|
|
"\u001B[2m\u001B[36m(GBDTTrainable pid=11036)\u001B[0m 2022-05-19 11:44:56,874\tINFO main.py:1025 -- [RayXGBoost] Starting XGBoost training.\n",
|
|
"\u001B[2m\u001B[36m(_RemoteRayXGBoostActor pid=11104)\u001B[0m [11:44:56] task [xgboost.ray]:4517180944 got new rank 1\n",
|
|
"\u001B[2m\u001B[36m(_RemoteRayXGBoostActor pid=11103)\u001B[0m [11:44:56] task [xgboost.ray]:4655847056 got new rank 0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Result for XGBoostTrainer_b273b_00000:\n",
|
|
" date: 2022-05-19_11-44-57\n",
|
|
" done: false\n",
|
|
" experiment_id: 991235d8b76649398688695ca70a08e4\n",
|
|
" hostname: Kais-MacBook-Pro.local\n",
|
|
" iterations_since_restore: 1\n",
|
|
" node_ip: 127.0.0.1\n",
|
|
" pid: 11036\n",
|
|
" should_checkpoint: true\n",
|
|
" time_since_restore: 7.17207407951355\n",
|
|
" time_this_iter_s: 7.17207407951355\n",
|
|
" time_total_s: 7.17207407951355\n",
|
|
" timestamp: 1652957097\n",
|
|
" timesteps_since_restore: 0\n",
|
|
" train-error: 0.020101\n",
|
|
" train-logloss: 0.465715\n",
|
|
" training_iteration: 1\n",
|
|
" trial_id: b273b_00000\n",
|
|
" valid-error: 0.052632\n",
|
|
" valid-logloss: 0.480831\n",
|
|
" warmup_time: 0.003935098648071289\n",
|
|
" \n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001B[2m\u001B[36m(GBDTTrainable pid=11036)\u001B[0m 2022-05-19 11:44:59,796\tINFO main.py:1519 -- [RayXGBoost] Finished XGBoost training on training data with total N=398 in 6.80 seconds (2.92 pure XGBoost training time).\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Result for XGBoostTrainer_b273b_00000:\n",
|
|
" date: 2022-05-19_11-44-59\n",
|
|
" done: true\n",
|
|
" experiment_id: 991235d8b76649398688695ca70a08e4\n",
|
|
" experiment_tag: '0'\n",
|
|
" hostname: Kais-MacBook-Pro.local\n",
|
|
" iterations_since_restore: 100\n",
|
|
" node_ip: 127.0.0.1\n",
|
|
" pid: 11036\n",
|
|
" should_checkpoint: true\n",
|
|
" time_since_restore: 9.03934907913208\n",
|
|
" time_this_iter_s: 0.018042802810668945\n",
|
|
" time_total_s: 9.03934907913208\n",
|
|
" timestamp: 1652957099\n",
|
|
" timesteps_since_restore: 0\n",
|
|
" train-error: 0.0\n",
|
|
" train-logloss: 0.005949\n",
|
|
" training_iteration: 100\n",
|
|
" trial_id: b273b_00000\n",
|
|
" valid-error: 0.017544\n",
|
|
" valid-logloss: 0.07483\n",
|
|
" warmup_time: 0.003935098648071289\n",
|
|
" \n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2022-05-19 11:45:00,535\tINFO tune.py:753 -- Total run time: 15.30 seconds (13.91 seconds for the tuning loop).\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{'train-logloss': 0.005949, 'train-error': 0.0, 'valid-logloss': 0.07483, 'valid-error': 0.017544, 'time_this_iter_s': 0.018042802810668945, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 100, 'trial_id': 'b273b_00000', 'experiment_id': '991235d8b76649398688695ca70a08e4', 'date': '2022-05-19_11-44-59', 'timestamp': 1652957099, 'time_total_s': 9.03934907913208, 'pid': 11036, 'hostname': 'Kais-MacBook-Pro.local', 'node_ip': '127.0.0.1', 'config': {}, 'time_since_restore': 9.03934907913208, 'timesteps_since_restore': 0, 'iterations_since_restore': 100, 'warmup_time': 0.003935098648071289, 'experiment_tag': '0'}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"result = train_xgboost(num_workers=2, use_gpu=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7055ad1b",
|
|
"metadata": {},
|
|
"source": [
|
|
"And perform inference on the obtained model:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "283b1dba",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 1.96s/it]\n",
|
|
"Map_Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 87.81it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"PREDICTED LABELS\n",
|
|
" predictions\n",
|
|
"0 0\n",
|
|
"1 0\n",
|
|
"2 1\n",
|
|
"3 1\n",
|
|
"4 0\n",
|
|
".. ...\n",
|
|
"166 1\n",
|
|
"167 1\n",
|
|
"168 0\n",
|
|
"169 1\n",
|
|
"170 0\n",
|
|
"\n",
|
|
"[171 rows x 1 columns]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 1.78s/it]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"SHAP VALUES\n",
|
|
" predictions_0 predictions_1 predictions_2 predictions_3 \\\n",
|
|
"0 -0.139882 -0.748878 0.0 -1.143079 \n",
|
|
"1 0.013840 -1.053747 0.0 0.361219 \n",
|
|
"2 -0.082575 0.952107 0.0 0.396908 \n",
|
|
"3 0.016314 0.916166 0.0 0.535740 \n",
|
|
"4 -0.087534 1.317693 0.0 -0.631737 \n",
|
|
".. ... ... ... ... \n",
|
|
"166 0.016314 1.006091 0.0 0.535740 \n",
|
|
"167 0.010002 0.948294 0.0 0.529942 \n",
|
|
"168 -0.084481 0.766085 0.0 -0.582221 \n",
|
|
"169 0.010002 0.846374 0.0 0.502846 \n",
|
|
"170 -0.108186 -1.032712 0.0 -0.737255 \n",
|
|
"\n",
|
|
" predictions_4 predictions_5 predictions_6 predictions_7 \\\n",
|
|
"0 0.228545 0.074653 -0.033109 -1.680274 \n",
|
|
"1 -0.386373 0.030964 -0.026341 -1.796480 \n",
|
|
"2 0.294464 0.142708 0.151952 1.859482 \n",
|
|
"3 0.224681 -0.013640 0.062032 0.909347 \n",
|
|
"4 -0.123310 -0.008267 -0.081633 -1.907682 \n",
|
|
".. ... ... ... ... \n",
|
|
"166 0.224681 -0.013640 0.062032 0.890978 \n",
|
|
"167 -0.107441 0.143260 0.062032 1.149335 \n",
|
|
"168 -0.164466 0.088426 -0.081633 -1.767637 \n",
|
|
"169 -0.112530 0.029944 -0.074865 0.963479 \n",
|
|
"170 -0.250381 0.034186 -0.033109 -1.654185 \n",
|
|
"\n",
|
|
" predictions_8 predictions_9 ... predictions_21 predictions_22 \\\n",
|
|
"0 -0.173504 -0.027610 ... -0.373735 -1.117443 \n",
|
|
"1 0.153518 0.018295 ... -0.798841 0.277471 \n",
|
|
"2 0.153518 0.029338 ... 1.314059 -0.455756 \n",
|
|
"3 0.153518 0.015659 ... 0.816392 0.683619 \n",
|
|
"4 -0.173504 0.009200 ... 1.207632 -0.945986 \n",
|
|
".. ... ... ... ... ... \n",
|
|
"166 -0.173504 0.015659 ... 0.856858 0.704448 \n",
|
|
"167 0.153518 0.010089 ... 1.203512 0.708437 \n",
|
|
"168 0.153518 0.014880 ... -0.418931 -1.201489 \n",
|
|
"169 0.153518 0.010089 ... 1.211174 0.600757 \n",
|
|
"170 0.153518 0.016329 ... -0.556651 -1.009517 \n",
|
|
"\n",
|
|
" predictions_23 predictions_24 predictions_25 predictions_26 \\\n",
|
|
"0 -1.207984 0.349734 0.018222 -0.725013 \n",
|
|
"1 0.075934 -0.990557 -0.012509 -0.863824 \n",
|
|
"2 0.137665 0.668639 -0.042249 -0.684045 \n",
|
|
"3 0.766776 0.575949 0.022816 1.013024 \n",
|
|
"4 -0.577419 -0.454616 0.051755 -0.861906 \n",
|
|
".. ... ... ... ... \n",
|
|
"166 0.754576 0.573718 0.022816 0.948516 \n",
|
|
"167 1.066871 0.487933 0.056155 -0.601421 \n",
|
|
"168 -1.310177 -0.386367 0.018222 -0.837832 \n",
|
|
"169 1.009837 0.694783 -0.042249 -0.626939 \n",
|
|
"170 -1.149971 -0.386467 -0.006737 -0.750287 \n",
|
|
"\n",
|
|
" predictions_27 predictions_28 predictions_29 predictions_30 \n",
|
|
"0 -1.149301 0.374839 0.0 1.046286 \n",
|
|
"1 -2.501725 -0.492608 0.0 1.046286 \n",
|
|
"2 0.077563 -0.106669 0.0 1.046286 \n",
|
|
"3 0.757272 0.341423 0.0 1.046286 \n",
|
|
"4 -0.800213 0.400311 0.0 1.046286 \n",
|
|
".. ... ... ... ... \n",
|
|
"166 0.757272 0.061695 0.0 1.046286 \n",
|
|
"167 0.610080 -0.339797 0.0 1.046286 \n",
|
|
"168 -1.300907 -0.474622 0.0 1.046286 \n",
|
|
"169 0.238948 -0.361304 0.0 1.046286 \n",
|
|
"170 -1.241549 -0.370570 0.0 1.046286 \n",
|
|
"\n",
|
|
"[171 rows x 31 columns]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"predict_xgboost(result)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"main_language": "python",
|
|
"notebook_metadata_filter": "-all"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |