[air preprocessor] Add limit to OHE. (#24893)
parent da5cf93d97
commit 8703d5e9d0
3 changed files with 94 additions and 24 deletions
@@ -1,5 +1,6 @@
 from typing import List, Dict, Optional, Union
 
+from collections import Counter
 import pandas as pd
 
 from ray.data import Dataset
@@ -50,19 +51,50 @@ class OneHotEncoder(Preprocessor):
     for each of the values from the fitted dataset. The value of a column will
     be set to 1 if the value matches, otherwise 0.
 
-    Transforming values not included in the fitted dataset will result in all
-    of the encoded column values being 0.
+    Transforming values not included in the fitted dataset or not among
+    the top popular values (see ``limit``) will result in all of the encoded column
+    values being 0.
 
+    Example:
+
+    .. code-block:: python
+
+        ohe = OneHotEncoder(
+            columns=[
+                "trip_start_hour",
+                "trip_start_day",
+                "trip_start_month",
+                "dropoff_census_tract",
+                "pickup_community_area",
+                "dropoff_community_area",
+                "payment_type",
+                "company",
+            ],
+            limit={
+                "dropoff_census_tract": 25,
+                "pickup_community_area": 20,
+                "dropoff_community_area": 20,
+                "payment_type": 2,
+                "company": 7,
+            },
+        )
+
     Args:
         columns: The columns that will individually be encoded.
+        limit: If set, only the top "limit" number of most popular values become
+            categorical variables. The less frequent ones will result in all
+            the encoded column values being 0. This is a dict of column to
+            its corresponding limit. The column in this dictionary has to be
+            in ``columns``.
     """
 
-    def __init__(self, columns: List[str]):
+    def __init__(self, columns: List[str], limit: Optional[Dict[str, int]] = None):
         # TODO: add `drop` parameter.
         self.columns = columns
+        self.limit = limit
 
     def _fit(self, dataset: Dataset) -> Preprocessor:
-        self.stats_ = _get_unique_value_indices(dataset, self.columns)
+        self.stats_ = _get_unique_value_indices(dataset, self.columns, limit=self.limit)
         return self
 
     def _transform_pandas(self, df: pd.DataFrame):
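For context, here is a minimal sketch of how the new ``limit`` argument plays out end to end. It is not part of the commit: the import path assumes the ``ray.ml.preprocessors`` package this preprocessor lived in at the time of this PR, and the example data and names are illustrative.

    import pandas as pd
    import ray
    from ray.ml.preprocessors import OneHotEncoder  # assumed import path

    df = pd.DataFrame({"fruit": ["apple", "apple", "banana", "cherry"]})
    ds = ray.data.from_pandas(df)

    # Keep only the single most frequent category of "fruit".
    ohe = OneHotEncoder(columns=["fruit"], limit={"fruit": 1})
    out = ohe.fit_transform(ds).to_pandas()

    # Only the top value ("apple") survives as an encoded column; rows
    # holding the dropped values ("banana", "cherry") encode to all zeros,
    # per the docstring above.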
@@ -177,35 +209,59 @@ def _get_unique_value_indices(
     columns: List[str],
     drop_na_values: bool = False,
     key_format: str = "unique_values({0})",
+    limit: Optional[Dict[str, int]] = None,
 ) -> Dict[str, Dict[str, int]]:
     """If drop_na_values is True, will silently drop NA values."""
+    limit = limit or {}
+    for column in limit:
+        if column not in columns:
+            raise ValueError(
+                f"You set limit for {column}, which is not present in {columns}."
+            )
 
-    def get_pd_unique_values(df: pd.DataFrame) -> List[Dict[str, set]]:
-        return [{col: set(df[col].unique()) for col in columns}]
+    def get_pd_value_counts(df: pd.DataFrame) -> List[Dict[str, Counter]]:
+        result = [
+            {
+                col: Counter(df[col].value_counts(dropna=False).to_dict())
+                for col in columns
+            }
+        ]
+        return result
 
-    uniques = dataset.map_batches(get_pd_unique_values, batch_format="pandas")
-    final_uniques = {col: set() for col in columns}
-    for batch in uniques.iter_batches():
-        for col_uniques in batch:
-            for col, uniques in col_uniques.items():
-                final_uniques[col].update(uniques)
+    value_counts = dataset.map_batches(get_pd_value_counts, batch_format="pandas")
+    final_counters = {col: Counter() for col in columns}
+    for batch in value_counts.iter_batches():
+        for col_value_counts in batch:
+            for col, value_counts in col_value_counts.items():
+                final_counters[col] += value_counts
 
-    for col, uniques in final_uniques.items():
+    # Inspect if there is any NA values.
+    for col in columns:
         if drop_na_values:
-            final_uniques[col] = {v for v in uniques if not pd.isnull(v)}
+            counter = final_counters[col]
+            counter_dict = dict(counter)
+            sanitized_dict = {k: v for k, v in counter_dict.items() if not pd.isnull(k)}
+            final_counters[col] = Counter(sanitized_dict)
         else:
-            if any(pd.isnull(v) for v in uniques):
+            if any(pd.isnull(k) for k in final_counters[col]):
                 raise ValueError(
-                    f"Unable to fit column '{col}' because it contains null values. "
-                    f"Consider imputing missing values first."
+                    f"Unable to fit column '{col}' because it contains null"
+                    f" values. Consider imputing missing values first."
                 )
 
-    unique_values_with_indices = {
-        key_format.format(column): {
-            k: j for j, k in enumerate(sorted(final_uniques[column]))
-        }
-        for column in columns
-    }
+    unique_values_with_indices = dict()
+    for column in columns:
+        if column in limit:
+            # Output sorted by freq.
+            unique_values_with_indices[key_format.format(column)] = {
+                k[0]: j
+                for j, k in enumerate(final_counters[column].most_common(limit[column]))
+            }
+        else:
+            # Output sorted by column name.
+            unique_values_with_indices[key_format.format(column)] = {
+                k: j for j, k in enumerate(sorted(dict(final_counters[column]).keys()))
+            }
     return unique_values_with_indices
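The crux of the new fitter is swapping per-batch sets of unique values for ``collections.Counter`` merges, so each column's categories can be capped with ``most_common``. Below is a self-contained sketch of that selection step; ``value_index_map`` is a hypothetical helper written for illustration, not code from the PR.

    from collections import Counter
    from typing import Dict, Hashable, Optional

    import pandas as pd

    def value_index_map(series: pd.Series, limit: Optional[int] = None) -> Dict[Hashable, int]:
        """Map each category to an index, keeping only the top `limit` if set."""
        counts = Counter(series.value_counts(dropna=False).to_dict())
        if limit is not None:
            # Most frequent first, capped at `limit` (mirrors most_common above).
            return {value: i for i, (value, _) in enumerate(counts.most_common(limit))}
        # No limit: deterministic ordering by sorting the values themselves.
        return {value: i for i, value in enumerate(sorted(counts))}

    s = pd.Series(["cold", "cold", "cold", "warm", "warm", "hot"])
    assert value_index_map(s, limit=2) == {"cold": 0, "warm": 1}
    assert value_index_map(s) == {"cold": 0, "hot": 1, "warm": 2}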
@@ -470,6 +470,20 @@ def test_one_hot_encoder():
     null_encoder.transform_batch(nonnull_df)
 
 
+def test_one_hot_encoder_with_limit():
+    """Tests basic OneHotEncoder functionality with limit."""
+    col_a = ["red", "green", "blue", "red"]
+    col_b = ["warm", "cold", "hot", "cold"]
+    col_c = [1, 10, 5, 10]
+    in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c})
+    ds = ray.data.from_pandas(in_df)
+
+    encoder = OneHotEncoder(["B", "C"], limit={"B": 2})
+
+    ds_out = encoder.fit_transform(ds)
+    assert len(ds_out.to_pandas().columns) == 1 + 2 + 3
+
+
 def test_label_encoder():
     """Tests basic LabelEncoder functionality."""
     col_a = ["red", "green", "blue", "red"]
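The ``1 + 2 + 3`` in the new test's assertion is worth unpacking; the annotation below is added here and is not part of the commit.

    # A: not listed in `columns`, so it passes through as one plain column -> 1
    # B: counts {"cold": 2, "warm": 1, "hot": 1}; limit={"B": 2} keeps the
    #    top two by frequency ("cold" plus one of the tied singletons)     -> 2
    # C: three distinct values (1, 10, 5) and no limit entry               -> 3
    assert 1 + 2 + 3 == 6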
@@ -35,7 +35,7 @@ gym==0.19.0; python_version < '3.7'
 lz4
 scikit-image
 pandas>=1.0.5; python_version < '3.7'
-pandas>=1.2.0; python_version >= '3.7'
+pandas>=1.3.0; python_version >= '3.7'
 scipy==1.4.1
 tabulate
 tensorboardX >= 1.9