[AIR] Refactor _get_unique_value_indices (#24144)

Refactors _get_unique_value_indices (used in the Encoder preprocessors) for much improved performance when fitting multiple columns: instead of running one groupby().count() aggregation per column, the unique values for all columns are now collected in a single map_batches pass and merged on the driver. The same, more robust intermediate format (one {column: per-batch result} record per batch) is also adopted in _get_most_frequent_values (Imputers).
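For illustration, here is a minimal sketch of the new aggregation pattern, mirroring the code in this diff. The column names and toy data are made up for the example, and it targets the Ray Datasets API as of this change:

```python
# Illustrative sketch: collect unique values for all columns in one pass.
from typing import Dict, List, Set

import pandas as pd
import ray

ds = ray.data.from_pandas(
    pd.DataFrame({"color": ["red", "blue", "red"], "size": ["S", "M", "S"]})
)
columns = ["color", "size"]

def get_pd_unique_values(df: pd.DataFrame) -> List[Dict[str, Set]]:
    # One record per batch: {column -> set of unique values seen in this batch}.
    return [{col: set(df[col].unique()) for col in columns}]

uniques = ds.map_batches(get_pd_unique_values, batch_format="pandas")

# Merge the per-batch sets on the driver.
final_uniques: Dict[str, Set] = {col: set() for col in columns}
for batch in uniques.iter_batches():
    for record in batch:
        for col, vals in record.items():
            final_uniques[col].update(vals)

print({col: sorted(vals) for col, vals in final_uniques.items()})
# {'color': ['blue', 'red'], 'size': ['M', 'S']}
```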

The existing unit tests pass, and no functionality has been changed.
Antoni Baum 2022-04-28 22:39:04 +02:00 committed by GitHub
parent ba14f0a41b
commit e62d3fac74
2 changed files with 38 additions and 34 deletions


@@ -1,4 +1,4 @@
-from typing import List, Dict, Set
+from typing import List, Dict
 import pandas as pd
@@ -118,37 +118,40 @@ class LabelEncoder(Preprocessor):
 def _get_unique_value_indices(
-    dataset: Dataset, *columns: str
+    dataset: Dataset,
+    *columns: str,
+    drop_na_values: bool = False,
 ) -> Dict[str, Dict[str, int]]:
-    results = {}
-    for column in columns:
-        values = _get_unique_values(dataset, column)
-        if any(pd.isnull(v) for v in values):
-            raise ValueError(
-                f"Unable to fit column '{column}' because it contains null values. "
-                f"Consider imputing missing values first."
-            )
-        value_to_index = _sorted_value_indices(values)
-        results[f"unique_values({column})"] = value_to_index
-    return results
-
-
-def _get_unique_values(dataset: Dataset, column: str) -> Set[str]:
-    agg_ds = dataset.groupby(column).count()
-    # TODO: Support an upper limit by using `agg_ds.take(N)` instead.
-    return {row[column] for row in agg_ds.iter_rows()}
-
-
-def _sorted_value_indices(values: Set) -> Dict[str, int]:
-    """Converts values to a Dict mapping to unique indexes.
-    Values will be sorted.
-    Example:
-    >>> _sorted_value_indices({"b", "a", "c", "a"})
-    {"a": 0, "b": 1, "c": 2}
-    """
-    return {value: i for i, value in enumerate(sorted(values))}
+    """If drop_na_values is True, will silently drop NA values."""
+    columns = list(columns)
+
+    def get_pd_unique_values(df: pd.DataFrame):
+        return [{col: set(df[col].unique()) for col in columns}]
+
+    uniques = dataset.map_batches(get_pd_unique_values, batch_format="pandas")
+    final_uniques = {col: set() for col in columns}
+    for batch in uniques.iter_batches():
+        for col_uniques in batch:
+            for col, uniques in col_uniques.items():
+                final_uniques[col].update(uniques)
+
+    for col, uniques in final_uniques.items():
+        if drop_na_values:
+            final_uniques[col] = {v for v in uniques if not pd.isnull(v)}
+        else:
+            if any(pd.isnull(v) for v in uniques):
+                raise ValueError(
+                    f"Unable to fit column '{col}' because it contains null values. "
+                    f"Consider imputing missing values first."
+                )
+
+    unique_values_with_indices = {
+        f"unique_values({column})": {
+            k: j for j, k in enumerate(sorted(final_uniques[column]))
+        }
+        for column in columns
+    }
+    return unique_values_with_indices
 
 
 def _validate_df(df: pd.DataFrame, *columns: str) -> None:


@@ -88,16 +88,17 @@ def _get_most_frequent_values(
 ) -> Dict[str, Union[str, Number]]:
     columns = list(columns)
 
-    def get_pd_value_counts(df: pd.DataFrame) -> List[Counter]:
-        return [Counter(df[col].value_counts().to_dict()) for col in columns]
+    def get_pd_value_counts(df: pd.DataFrame) -> List[Dict[str, Counter]]:
+        return [{col: Counter(df[col].value_counts().to_dict()) for col in columns}]
 
     value_counts = dataset.map_batches(get_pd_value_counts, batch_format="pandas")
-    final_counters = [Counter() for _ in columns]
+    final_counters = {col: Counter() for col in columns}
     for batch in value_counts.iter_batches():
-        for i, col_value_counts in enumerate(batch):
-            final_counters[i] += col_value_counts
+        for col_value_counts in batch:
+            for col, value_counts in col_value_counts.items():
+                final_counters[col] += value_counts
 
     return {
-        f"most_frequent({column})": final_counters[i].most_common(1)[0][0]
-        for i, column in enumerate(columns)
+        f"most_frequent({column})": final_counters[column].most_common(1)[0][0]
+        for column in columns
     }
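
For reference, a toy sketch of the imputer-side merge in isolation: per-batch value counts keyed by column are summed into one Counter per column, then the most frequent value is taken. The data below is illustrative, not from this PR:

```python
# Illustrative only: merge per-column Counters and pick the mode per column.
from collections import Counter

columns = ["color", "size"]
# Two per-batch records of the intermediate format {column -> Counter}.
batches = [
    {"color": Counter({"red": 2, "blue": 1}), "size": Counter({"S": 3})},
    {"color": Counter({"red": 1}), "size": Counter({"M": 2, "S": 1})},
]

final_counters = {col: Counter() for col in columns}
for record in batches:
    for col, counts in record.items():
        final_counters[col] += counts

print({f"most_frequent({col})": final_counters[col].most_common(1)[0][0] for col in columns})
# {'most_frequent(color)': 'red', 'most_frequent(size)': 'S'}
```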