[air preprocessor] Add limit to OHE. (#24893)

This commit is contained in:
xwjiang2010 2022-05-23 22:37:15 -07:00 committed by GitHub
parent da5cf93d97
commit 8703d5e9d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 24 deletions

View file

@ -1,5 +1,6 @@
from typing import List, Dict, Optional, Union
from collections import Counter

import pandas as pd

from ray.data import Dataset
@ -50,19 +51,50 @@ class OneHotEncoder(Preprocessor):
for each of the values from the fitted dataset. The value of a column will
be set to 1 if the value matches, otherwise 0.

Transforming values not included in the fitted dataset or not among
the top popular values (see ``limit``) will result in all of the encoded column
values being 0.
Example:
.. code-block:: python
ohe = OneHotEncoder(
columns=[
"trip_start_hour",
"trip_start_day",
"trip_start_month",
"dropoff_census_tract",
"pickup_community_area",
"dropoff_community_area",
"payment_type",
"company",
],
limit={
"dropoff_census_tract": 25,
"pickup_community_area": 20,
"dropoff_community_area": 20,
"payment_type": 2,
"company": 7,
},
)
Args:
    columns: The columns that will individually be encoded.
    limit: If set, only the top "limit" number of most popular values become
        categorical variables. The less frequent ones will result in all
        the encoded column values being 0. This is a dict of column to
        its corresponding limit. The column in this dictionary has to be
        in ``columns``.
"""
def __init__(self, columns: List[str]): def __init__(self, columns: List[str], limit: Optional[Dict[str, int]] = None):
# TODO: add `drop` parameter. # TODO: add `drop` parameter.
self.columns = columns self.columns = columns
self.limit = limit
def _fit(self, dataset: Dataset) -> Preprocessor: def _fit(self, dataset: Dataset) -> Preprocessor:
self.stats_ = _get_unique_value_indices(dataset, self.columns) self.stats_ = _get_unique_value_indices(dataset, self.columns, limit=self.limit)
return self return self
def _transform_pandas(self, df: pd.DataFrame): def _transform_pandas(self, df: pd.DataFrame):
@ -177,35 +209,59 @@ def _get_unique_value_indices(
columns: List[str], columns: List[str],
drop_na_values: bool = False, drop_na_values: bool = False,
key_format: str = "unique_values({0})", key_format: str = "unique_values({0})",
limit: Optional[Dict[str, int]] = None,
) -> Dict[str, Dict[str, int]]: ) -> Dict[str, Dict[str, int]]:
"""If drop_na_values is True, will silently drop NA values.""" """If drop_na_values is True, will silently drop NA values."""
limit = limit or {}
for column in limit:
if column not in columns:
raise ValueError(
f"You set limit for {column}, which is not present in {columns}."
)
def get_pd_unique_values(df: pd.DataFrame) -> List[Dict[str, set]]: def get_pd_value_counts(df: pd.DataFrame) -> List[Dict[str, Counter]]:
return [{col: set(df[col].unique()) for col in columns}] result = [
{
col: Counter(df[col].value_counts(dropna=False).to_dict())
for col in columns
}
]
return result
uniques = dataset.map_batches(get_pd_unique_values, batch_format="pandas") value_counts = dataset.map_batches(get_pd_value_counts, batch_format="pandas")
final_uniques = {col: set() for col in columns} final_counters = {col: Counter() for col in columns}
for batch in uniques.iter_batches(): for batch in value_counts.iter_batches():
for col_uniques in batch: for col_value_counts in batch:
for col, uniques in col_uniques.items(): for col, value_counts in col_value_counts.items():
final_uniques[col].update(uniques) final_counters[col] += value_counts
for col, uniques in final_uniques.items(): # Inspect if there is any NA values.
for col in columns:
if drop_na_values: if drop_na_values:
final_uniques[col] = {v for v in uniques if not pd.isnull(v)} counter = final_counters[col]
counter_dict = dict(counter)
sanitized_dict = {k: v for k, v in counter_dict.items() if not pd.isnull(k)}
final_counters[col] = Counter(sanitized_dict)
else: else:
if any(pd.isnull(v) for v in uniques): if any(pd.isnull(k) for k in final_counters[col]):
raise ValueError( raise ValueError(
f"Unable to fit column '{col}' because it contains null values. " f"Unable to fit column '{col}' because it contains null"
f"Consider imputing missing values first." f" values. Consider imputing missing values first."
) )
unique_values_with_indices = { unique_values_with_indices = dict()
key_format.format(column): { for column in columns:
k: j for j, k in enumerate(sorted(final_uniques[column])) if column in limit:
} # Output sorted by freq.
for column in columns unique_values_with_indices[key_format.format(column)] = {
} k[0]: j
for j, k in enumerate(final_counters[column].most_common(limit[column]))
}
else:
# Output sorted by column name.
unique_values_with_indices[key_format.format(column)] = {
k: j for j, k in enumerate(sorted(dict(final_counters[column]).keys()))
}
return unique_values_with_indices return unique_values_with_indices

View file

@ -470,6 +470,20 @@ def test_one_hot_encoder():
null_encoder.transform_batch(nonnull_df) null_encoder.transform_batch(nonnull_df)
def test_one_hot_encoder_with_limit():
"""Tests basic OneHotEncoder functionality with limit."""
col_a = ["red", "green", "blue", "red"]
col_b = ["warm", "cold", "hot", "cold"]
col_c = [1, 10, 5, 10]
in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c})
ds = ray.data.from_pandas(in_df)
encoder = OneHotEncoder(["B", "C"], limit={"B": 2})
ds_out = encoder.fit_transform(ds)
assert len(ds_out.to_pandas().columns) == 1 + 2 + 3
def test_label_encoder(): def test_label_encoder():
"""Tests basic LabelEncoder functionality.""" """Tests basic LabelEncoder functionality."""
col_a = ["red", "green", "blue", "red"] col_a = ["red", "green", "blue", "red"]

View file

@ -35,7 +35,7 @@ gym==0.19.0; python_version < '3.7'
lz4
scikit-image
pandas>=1.0.5; python_version < '3.7'
pandas>=1.3.0; python_version >= '3.7'
scipy==1.4.1
tabulate
tensorboardX >= 1.9