mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[air preprocessor] Add limit to OHE. (#24893)
This commit is contained in:
parent
da5cf93d97
commit
8703d5e9d0
3 changed files with 94 additions and 24 deletions
|
@ -1,5 +1,6 @@
|
|||
from typing import List, Dict, Optional, Union
|
||||
|
||||
from collections import Counter
|
||||
import pandas as pd
|
||||
|
||||
from ray.data import Dataset
|
||||
|
@ -50,19 +51,50 @@ class OneHotEncoder(Preprocessor):
|
|||
for each of the values from the fitted dataset. The value of a column will
|
||||
be set to 1 if the value matches, otherwise 0.
|
||||
|
||||
Transforming values not included in the fitted dataset will result in all
|
||||
of the encoded column values being 0.
|
||||
Transforming values not included in the fitted dataset or not among
|
||||
the most frequent values (see ``limit``) will result in all of the encoded column
|
||||
values being 0.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
ohe = OneHotEncoder(
|
||||
columns=[
|
||||
"trip_start_hour",
|
||||
"trip_start_day",
|
||||
"trip_start_month",
|
||||
"dropoff_census_tract",
|
||||
"pickup_community_area",
|
||||
"dropoff_community_area",
|
||||
"payment_type",
|
||||
"company",
|
||||
],
|
||||
limit={
|
||||
"dropoff_census_tract": 25,
|
||||
"pickup_community_area": 20,
|
||||
"dropoff_community_area": 20,
|
||||
"payment_type": 2,
|
||||
"company": 7,
|
||||
},
|
||||
)
|
||||
|
||||
Args:
|
||||
columns: The columns that will individually be encoded.
|
||||
limit: If set, only the top "limit" number of most popular values become
|
||||
categorical variables. The less frequent ones will result in all
|
||||
the encoded column values being 0. This is a dict of column to
|
||||
its corresponding limit. The column in this dictionary has to be
|
||||
in ``columns``.
|
||||
"""
|
||||
|
||||
def __init__(self, columns: List[str]):
|
||||
def __init__(self, columns: List[str], limit: Optional[Dict[str, int]] = None):
|
||||
# TODO: add `drop` parameter.
|
||||
self.columns = columns
|
||||
self.limit = limit
|
||||
|
||||
def _fit(self, dataset: Dataset) -> Preprocessor:
|
||||
self.stats_ = _get_unique_value_indices(dataset, self.columns)
|
||||
self.stats_ = _get_unique_value_indices(dataset, self.columns, limit=self.limit)
|
||||
return self
|
||||
|
||||
def _transform_pandas(self, df: pd.DataFrame):
|
||||
|
@ -177,35 +209,59 @@ def _get_unique_value_indices(
|
|||
columns: List[str],
|
||||
drop_na_values: bool = False,
|
||||
key_format: str = "unique_values({0})",
|
||||
limit: Optional[Dict[str, int]] = None,
|
||||
) -> Dict[str, Dict[str, int]]:
|
||||
"""If drop_na_values is True, will silently drop NA values."""
|
||||
limit = limit or {}
|
||||
for column in limit:
|
||||
if column not in columns:
|
||||
raise ValueError(
|
||||
f"You set limit for {column}, which is not present in {columns}."
|
||||
)
|
||||
|
||||
def get_pd_unique_values(df: pd.DataFrame) -> List[Dict[str, set]]:
|
||||
return [{col: set(df[col].unique()) for col in columns}]
|
||||
def get_pd_value_counts(df: pd.DataFrame) -> List[Dict[str, Counter]]:
|
||||
result = [
|
||||
{
|
||||
col: Counter(df[col].value_counts(dropna=False).to_dict())
|
||||
for col in columns
|
||||
}
|
||||
]
|
||||
return result
|
||||
|
||||
uniques = dataset.map_batches(get_pd_unique_values, batch_format="pandas")
|
||||
final_uniques = {col: set() for col in columns}
|
||||
for batch in uniques.iter_batches():
|
||||
for col_uniques in batch:
|
||||
for col, uniques in col_uniques.items():
|
||||
final_uniques[col].update(uniques)
|
||||
value_counts = dataset.map_batches(get_pd_value_counts, batch_format="pandas")
|
||||
final_counters = {col: Counter() for col in columns}
|
||||
for batch in value_counts.iter_batches():
|
||||
for col_value_counts in batch:
|
||||
for col, value_counts in col_value_counts.items():
|
||||
final_counters[col] += value_counts
|
||||
|
||||
for col, uniques in final_uniques.items():
|
||||
# Check whether there are any NA values.
|
||||
for col in columns:
|
||||
if drop_na_values:
|
||||
final_uniques[col] = {v for v in uniques if not pd.isnull(v)}
|
||||
counter = final_counters[col]
|
||||
counter_dict = dict(counter)
|
||||
sanitized_dict = {k: v for k, v in counter_dict.items() if not pd.isnull(k)}
|
||||
final_counters[col] = Counter(sanitized_dict)
|
||||
else:
|
||||
if any(pd.isnull(v) for v in uniques):
|
||||
if any(pd.isnull(k) for k in final_counters[col]):
|
||||
raise ValueError(
|
||||
f"Unable to fit column '{col}' because it contains null values. "
|
||||
f"Consider imputing missing values first."
|
||||
f"Unable to fit column '{col}' because it contains null"
|
||||
f" values. Consider imputing missing values first."
|
||||
)
|
||||
|
||||
unique_values_with_indices = {
|
||||
key_format.format(column): {
|
||||
k: j for j, k in enumerate(sorted(final_uniques[column]))
|
||||
}
|
||||
for column in columns
|
||||
}
|
||||
unique_values_with_indices = dict()
|
||||
for column in columns:
|
||||
if column in limit:
|
||||
# Output sorted by freq.
|
||||
unique_values_with_indices[key_format.format(column)] = {
|
||||
k[0]: j
|
||||
for j, k in enumerate(final_counters[column].most_common(limit[column]))
|
||||
}
|
||||
else:
|
||||
# Output sorted by value (ascending), not by frequency.
|
||||
unique_values_with_indices[key_format.format(column)] = {
|
||||
k: j for j, k in enumerate(sorted(dict(final_counters[column]).keys()))
|
||||
}
|
||||
return unique_values_with_indices
|
||||
|
||||
|
||||
|
|
|
@ -470,6 +470,20 @@ def test_one_hot_encoder():
|
|||
null_encoder.transform_batch(nonnull_df)
|
||||
|
||||
|
||||
def test_one_hot_encoder_with_limit():
    """Checks the output width of OneHotEncoder when a per-column limit is set."""
    frame = pd.DataFrame.from_dict(
        {
            "A": ["red", "green", "blue", "red"],
            "B": ["warm", "cold", "hot", "cold"],
            "C": [1, 10, 5, 10],
        }
    )
    encoder = OneHotEncoder(["B", "C"], limit={"B": 2})

    encoded = encoder.fit_transform(ray.data.from_pandas(frame)).to_pandas()

    # Expected width, read off the inputs above: column "A" is not encoded
    # (1 column), "B" is limited to 2 values, and "C" has 3 distinct values.
    assert len(encoded.columns) == 1 + 2 + 3
|
||||
|
||||
|
||||
def test_label_encoder():
|
||||
"""Tests basic LabelEncoder functionality."""
|
||||
col_a = ["red", "green", "blue", "red"]
|
||||
|
|
|
@ -35,7 +35,7 @@ gym==0.19.0; python_version < '3.7'
|
|||
lz4
|
||||
scikit-image
|
||||
pandas>=1.0.5; python_version < '3.7'
|
||||
pandas>=1.2.0; python_version >= '3.7'
|
||||
pandas>=1.3.0; python_version >= '3.7'
|
||||
scipy==1.4.1
|
||||
tabulate
|
||||
tensorboardX >= 1.9
|
||||
|
|
Loading…
Add table
Reference in a new issue