[AIR] Refactor _get_unique_value_indices (#24144)
Refactors _get_unique_value_indices (used in Encoder preprocessors) for much improved performance with multiple columns: instead of launching a separate groupby().count() pass over the dataset for each column, it now collects the unique values of all columns in a single map_batches pass. Also uses the same, more robust intermediary format in _get_most_frequent_values (used in Imputers). The existing unit tests pass, and no functionality has changed.
parent ba14f0a41b
commit e62d3fac74
2 changed files with 38 additions and 34 deletions
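For orientation before the diff: the fitted stats returned by _get_unique_value_indices map each column to a value-to-index dict, keyed as unique_values(<column>), with indices assigned in sorted order. Below is a minimal sketch of how such stats could be applied to a pandas column; the transform step itself is not part of this diff, and the names stats and encode_column are illustrative.

import pandas as pd

# Shape of the stats dict returned by _get_unique_value_indices
# (hypothetical values; keys follow the f"unique_values({column})" pattern).
stats = {
    "unique_values(fruit)": {"apple": 0, "kiwi": 1, "pear": 2},
}

def encode_column(df: pd.DataFrame, column: str) -> pd.DataFrame:
    # Illustrative only: replace each value with its fitted index.
    mapping = stats[f"unique_values({column})"]
    df[column] = df[column].map(mapping)
    return df

df = pd.DataFrame({"fruit": ["pear", "apple", "kiwi"]})
print(encode_column(df, "fruit")["fruit"].tolist())  # [2, 0, 1]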
@@ -1,4 +1,4 @@
-from typing import List, Dict, Set
+from typing import List, Dict
 
 import pandas as pd
 
@@ -118,37 +118,40 @@ class LabelEncoder(Preprocessor):
 
 
 def _get_unique_value_indices(
-    dataset: Dataset, *columns: str
+    dataset: Dataset,
+    *columns: str,
+    drop_na_values: bool = False,
 ) -> Dict[str, Dict[str, int]]:
-    results = {}
-    for column in columns:
-        values = _get_unique_values(dataset, column)
-        if any(pd.isnull(v) for v in values):
-            raise ValueError(
-                f"Unable to fit column '{column}' because it contains null values. "
-                f"Consider imputing missing values first."
-            )
-        value_to_index = _sorted_value_indices(values)
-        results[f"unique_values({column})"] = value_to_index
-    return results
-
-
-def _get_unique_values(dataset: Dataset, column: str) -> Set[str]:
-    agg_ds = dataset.groupby(column).count()
-    # TODO: Support an upper limit by using `agg_ds.take(N)` instead.
-    return {row[column] for row in agg_ds.iter_rows()}
-
-
-def _sorted_value_indices(values: Set) -> Dict[str, int]:
-    """Converts values to a Dict mapping to unique indexes.
-
-    Values will be sorted.
-
-    Example:
-    >>> _sorted_value_indices({"b", "a", "c", "a"})
-    {"a": 0, "b": 1, "c": 2}
-    """
-    return {value: i for i, value in enumerate(sorted(values))}
+    """If drop_na_values is True, will silently drop NA values."""
+    columns = list(columns)
+
+    def get_pd_unique_values(df: pd.DataFrame):
+        return [{col: set(df[col].unique()) for col in columns}]
+
+    uniques = dataset.map_batches(get_pd_unique_values, batch_format="pandas")
+    final_uniques = {col: set() for col in columns}
+    for batch in uniques.iter_batches():
+        for col_uniques in batch:
+            for col, uniques in col_uniques.items():
+                final_uniques[col].update(uniques)
+
+    for col, uniques in final_uniques.items():
+        if drop_na_values:
+            final_uniques[col] = {v for v in uniques if not pd.isnull(v)}
+        else:
+            if any(pd.isnull(v) for v in uniques):
+                raise ValueError(
+                    f"Unable to fit column '{col}' because it contains null values. "
+                    f"Consider imputing missing values first."
+                )
+
+    unique_values_with_indices = {
+        f"unique_values({column})": {
+            k: j for j, k in enumerate(sorted(final_uniques[column]))
+        }
+        for column in columns
+    }
+    return unique_values_with_indices
 
 
 def _validate_df(df: pd.DataFrame, *columns: str) -> None:
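The core of the refactor, as a self-contained sketch: plain pandas DataFrames stand in for the batches that Ray's map_batches and iter_batches would produce, and the batches data is made up for illustration. A single pass computes the unique sets for every column at once, which is where the multi-column speedup comes from.

from typing import Dict, List, Set

import pandas as pd

columns = ["fruit", "size"]
batches: List[pd.DataFrame] = [
    pd.DataFrame({"fruit": ["apple", "pear"], "size": ["s", "m"]}),
    pd.DataFrame({"fruit": ["pear", "kiwi"], "size": ["m", "l"]}),
]

# Per-batch step (what get_pd_unique_values does): one dict of
# column -> set of unique values, covering all columns in a single pass.
per_batch = [{col: set(df[col].unique()) for col in columns} for df in batches]

# Driver-side merge: union the per-batch sets per column.
final_uniques: Dict[str, Set] = {col: set() for col in columns}
for col_uniques in per_batch:
    for col, uniques in col_uniques.items():
        final_uniques[col].update(uniques)

# Same index assignment as the refactored code: sort, then enumerate.
print({
    f"unique_values({col})": {v: i for i, v in enumerate(sorted(final_uniques[col]))}
    for col in columns
})
# {'unique_values(fruit)': {'apple': 0, 'kiwi': 1, 'pear': 2},
#  'unique_values(size)': {'l': 0, 'm': 1, 's': 2}}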
@@ -88,16 +88,17 @@ def _get_most_frequent_values(
 ) -> Dict[str, Union[str, Number]]:
     columns = list(columns)
 
-    def get_pd_value_counts(df: pd.DataFrame) -> List[Counter]:
-        return [Counter(df[col].value_counts().to_dict()) for col in columns]
+    def get_pd_value_counts(df: pd.DataFrame) -> List[Dict[str, Counter]]:
+        return [{col: Counter(df[col].value_counts().to_dict()) for col in columns}]
 
     value_counts = dataset.map_batches(get_pd_value_counts, batch_format="pandas")
-    final_counters = [Counter() for _ in columns]
+    final_counters = {col: Counter() for col in columns}
     for batch in value_counts.iter_batches():
-        for i, col_value_counts in enumerate(batch):
-            final_counters[i] += col_value_counts
+        for col_value_counts in batch:
+            for col, value_counts in col_value_counts.items():
+                final_counters[col] += value_counts
 
     return {
-        f"most_frequent({column})": final_counters[i].most_common(1)[0][0]
-        for i, column in enumerate(columns)
+        f"most_frequent({column})": final_counters[column].most_common(1)[0][0]
+        for column in columns
     }
 
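The same merge pattern on the Imputer side, sketched with made-up data standing in for Ray batches: Counters add element-wise, so per-batch value counts combine with a simple +=.

from collections import Counter

import pandas as pd

columns = ["fruit"]
batches = [
    pd.DataFrame({"fruit": ["apple", "apple", "pear"]}),
    pd.DataFrame({"fruit": ["pear", "pear", "kiwi"]}),
]

# Per-batch step (what get_pd_value_counts does): column -> Counter of counts.
per_batch = [
    {col: Counter(df[col].value_counts().to_dict()) for col in columns}
    for df in batches
]

# Driver-side merge: element-wise Counter addition per column.
final_counters = {col: Counter() for col in columns}
for col_value_counts in per_batch:
    for col, value_counts in col_value_counts.items():
        final_counters[col] += value_counts

print({f"most_frequent({col})": final_counters[col].most_common(1)[0][0] for col in columns})
# {'most_frequent(fruit)': 'pear'}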