# Training a model with distributed LightGBM
In this example, we will train a binary classifier in Ray AIR using distributed LightGBM.

Let's start by installing our dependencies:

In [1]:
!pip install -qU "ray[tune]" lightgbm_ray

Then we need some imports:

In [4]:
from typing import Tuple

import ray
from ray.train.batch_predictor import BatchPredictor
from ray.train.lightgbm import LightGBMPredictor
from ray.data.preprocessors.chain import Chain
from ray.data.preprocessors.encoder import Categorizer
from ray.train.lightgbm import LightGBMTrainer
from ray.air.config import ScalingConfig
from ray.data.dataset import Dataset
from ray.air.result import Result
from ray.data.preprocessors import StandardScaler

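Next, we define a function to load our train, validation, and test datasets. Note that the test dataset is simply the validation split with the `target` column dropped, so it stands in for unlabeled data at inference time.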

In [13]:
def prepare_data() -> Tuple[Dataset, Dataset, Dataset]:
    """Load the breast-cancer CSV and split it into train/valid/test datasets.

    30% of the rows are held out as the validation split. The returned "test"
    dataset is that same validation split with the ``target`` label column
    removed, so it covers the same rows as ``valid`` — it is meant to
    demonstrate inference, not to be an independent evaluation set.

    Returns:
        A ``(train, valid, test)`` tuple of Ray datasets.
    """
    source_uri = "s3://anonymous@air-example-data/breast_cancer_with_categorical.csv"
    full_ds = ray.data.read_csv(source_uri)
    train_ds, valid_ds = full_ds.train_test_split(test_size=0.3)

    def _without_label(batch):
        # Strip the label so the predictor only sees feature columns.
        return batch.drop("target", axis=1)

    unlabeled_ds = valid_ds.map_batches(_without_label, batch_format="pandas")
    return train_ds, valid_ds, unlabeled_ds

The following function creates a LightGBM trainer, fits it, and returns the training result.

In [14]:
def train_lightgbm(num_workers: int, use_gpu: bool = False) -> Result:
    """Fit a distributed LightGBM binary classifier and return its result.

    Args:
        num_workers: Number of Ray training workers to launch.
        use_gpu: Whether the workers should request GPUs.

    Returns:
        The ``Result`` returned by ``LightGBMTrainer.fit()``. Its metrics are
        also printed.
    """
    train_ds, valid_ds, _ = prepare_data()

    # Standardize two hand-picked numeric columns, and categorify
    # "categorical_column" so LightGBM can use its built-in categorical
    # feature support.
    categorizer = Categorizer(["categorical_column"])
    scaler = StandardScaler(columns=["mean radius", "mean texture"])
    preprocessor = Chain(categorizer, scaler)

    # Booster parameters passed straight through to LightGBM.
    booster_params = dict(
        objective="binary",
        metric=["binary_logloss", "binary_error"],
    )
    scaling = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    trainer = LightGBMTrainer(
        datasets={"train": train_ds, "valid": valid_ds},
        label_column="target",
        params=booster_params,
        preprocessor=preprocessor,
        scaling_config=scaling,
        num_boost_round=100,
    )
    fit_result = trainer.fit()
    print(fit_result.metrics)
    return fit_result

Once we have the result, we can run batch inference with the trained model. Let's define a utility function for this.

In [15]:
def predict_lightgbm(result: Result, threshold: float = 0.5) -> None:
    """Run batch inference with the checkpoint from a training ``Result``.

    Prints two things for the test dataset from ``prepare_data()``:
    hard 0/1 labels obtained by thresholding the model's predictions, and
    per-feature contribution (SHAP) values requested via ``pred_contrib=True``.

    Args:
        result: Training result whose ``checkpoint`` is loaded into a
            ``LightGBMPredictor``.
        threshold: Predictions strictly greater than this value are labeled
            1; all others are labeled 0. Defaults to 0.5.
    """
    _, _, test_dataset = prepare_data()
    batch_predictor = BatchPredictor.from_checkpoint(
        result.checkpoint, LightGBMPredictor
    )

    predicted_labels = batch_predictor.predict(test_dataset).map_batches(
        lambda df: (df > threshold).astype(int), batch_format="pandas"
    )
    print("PREDICTED LABELS")
    predicted_labels.show()

    # Extra kwargs are presumably forwarded to LightGBM's predict; pred_contrib
    # asks for per-feature contributions instead of scores.
    shap_values = batch_predictor.predict(test_dataset, pred_contrib=True)
    print("SHAP VALUES")
    shap_values.show()

Now we can run the training:

In [16]:
# Train on two CPU-only Ray workers.
result = train_lightgbm(2, use_gpu=False)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 46.26it/s]


Trial name,status,loc,iter,total time (s),train-binary_logloss,train-binary_error,valid-binary_logloss
LightGBMTrainer_7b049_00000,TERMINATED,172.31.43.110:1491578,100,10.9726,0.000574522,0,0.171898


[2m[36m(pid=1491578)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491578)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491578)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491578)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(pid=1491651)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(pid=1491653)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491653)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491653)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491653)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(pid=1491652)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491652)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64In

[2m[36m(_RemoteRayLightGBMActor pid=1491653)[0m [LightGBM] [Info] Trying to bind port 59039...
[2m[36m(_RemoteRayLightGBMActor pid=1491653)[0m [LightGBM] [Info] Binding port 59039 succeeded
[2m[36m(_RemoteRayLightGBMActor pid=1491653)[0m [LightGBM] [Info] Listening...
[2m[36m(_RemoteRayLightGBMActor pid=1491652)[0m [LightGBM] [Info] Trying to bind port 46955...
[2m[36m(_RemoteRayLightGBMActor pid=1491652)[0m [LightGBM] [Info] Binding port 46955 succeeded
[2m[36m(_RemoteRayLightGBMActor pid=1491652)[0m [LightGBM] [Info] Listening...




[2m[36m(_RemoteRayLightGBMActor pid=1491653)[0m [LightGBM] [Info] Connected to rank 0
[2m[36m(_RemoteRayLightGBMActor pid=1491653)[0m [LightGBM] [Info] Local rank: 1, total number of machines: 2
[2m[36m(_RemoteRayLightGBMActor pid=1491652)[0m [LightGBM] [Info] Connected to rank 1
[2m[36m(_RemoteRayLightGBMActor pid=1491652)[0m [LightGBM] [Info] Local rank: 0, total number of machines: 2


[2m[36m(_QueueActor pid=1491650)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(_QueueActor pid=1491650)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(_QueueActor pid=1491650)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(_QueueActor pid=1491650)[0m   from pandas import MultiIndex, Int64Index


Result for LightGBMTrainer_7b049_00000:
  date: 2022-06-22_17-26-53
  done: false
  experiment_id: b4a87c26a7604a43baf895755d4f16b3
  hostname: ip-172-31-43-110
  iterations_since_restore: 1
  node_ip: 172.31.43.110
  pid: 1491578
  should_checkpoint: true
  time_since_restore: 8.369545459747314
  time_this_iter_s: 8.369545459747314
  time_total_s: 8.369545459747314
  timestamp: 1655918813
  timesteps_since_restore: 0
  train-binary_error: 0.5175879396984925
  train-binary_logloss: 0.6302848981539763
  training_iteration: 1
  trial_id: 7b049_00000
  valid-binary_error: 0.2
  valid-binary_logloss: 0.558752017793943
  warmup_time: 0.008721590042114258
  
Result for LightGBMTrainer_7b049_00000:
  date: 2022-06-22_17-26-56
  done: true
  experiment_id: b4a87c26a7604a43baf895755d4f16b3
  experiment_tag: '0'
  hostname: ip-172-31-43-110
  iterations_since_restore: 100
  node_ip: 172.31.43.110
  pid: 1491578
  should_checkpoint: true
  time_since_restore: 10.972588300704956
  time_this_iter_s

2022-06-22 17:26:56,406	INFO tune.py:734 -- Total run time: 14.73 seconds (14.06 seconds for the tuning loop).


{'train-binary_logloss': 0.0005745220956391456, 'train-binary_error': 0.0, 'valid-binary_logloss': 0.17189847605331432, 'valid-binary_error': 0.058823529411764705, 'time_this_iter_s': 0.027977466583251953, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 100, 'trial_id': '7b049_00000', 'experiment_id': 'b4a87c26a7604a43baf895755d4f16b3', 'date': '2022-06-22_17-26-56', 'timestamp': 1655918816, 'time_total_s': 10.972588300704956, 'pid': 1491578, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}, 'time_since_restore': 10.972588300704956, 'timesteps_since_restore': 0, 'iterations_since_restore': 100, 'warmup_time': 0.008721590042114258, 'experiment_tag': '0'}


Finally, we run batch inference with the trained model:

In [17]:
# Label and explain the test rows using the checkpoint from the training run above.
predict_lightgbm(result=result)

Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 50.96it/s]
[2m[36m(pid=1491998)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491998)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491998)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1491998)[0m   from pandas import MultiIndex, Int64Index
Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:02<00:00,  2.05s/it]
Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 75.07it/s]


PREDICTED LABELS
{'predictions': 1}
{'predictions': 1}
{'predictions': 0}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 0}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 0}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 0}


[2m[36m(pid=1492031)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492031)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492031)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492031)[0m   from pandas import MultiIndex, Int64Index
[2m[36m(pid=1492033)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492033)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492033)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492033)[0m   from pandas import MultiIndex, Int64Index
Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:02<00:00,  2.09s/it]


SHAP VALUES
{'predictions_0': 0.006121974664714535, 'predictions_1': 0.8940294162424869, 'predictions_2': -0.013623909529011522, 'predictions_3': -0.26580572803883, 'predictions_4': 0.2897686828261492, 'predictions_5': -0.03784232120648852, 'predictions_6': 0.021865334852359534, 'predictions_7': 1.1753326094382734, 'predictions_8': -0.02525466292349231, 'predictions_9': 0.0733463992354119, 'predictions_10': 0.09191922035401615, 'predictions_11': -0.0035196096494634313, 'predictions_12': 0.20211476104388482, 'predictions_13': 0.7813488658944929, 'predictions_14': 0.10000464816891827, 'predictions_15': 0.11543593649642907, 'predictions_16': -0.009732477634862284, 'predictions_17': 0.19117650484758314, 'predictions_18': -0.17600075102817322, 'predictions_19': 0.5829434737180024, 'predictions_20': 1.4220773445509465, 'predictions_21': 0.6086211783805069, 'predictions_22': 2.0031654232526925, 'predictions_23': 0.3090376110779834, 'predictions_24': -0.21156467772251453, 'predictions_25': 0.1

[2m[36m(pid=1492090)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492090)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492090)[0m   _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
[2m[36m(pid=1492090)[0m   from pandas import MultiIndex, Int64Index
