[AIR] Change Categorizer signature (#26980)

This commit is contained in:
Balaji Veeramani 2022-07-26 00:35:36 -07:00 committed by GitHub
parent b2b11316cd
commit 262ce1acef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 23 deletions

View file

@ -1,5 +1,5 @@
from functools import partial
from typing import List, Dict, Optional, Union
from typing import List, Dict, Optional
from collections import Counter, OrderedDict
import pandas as pd
@ -320,27 +320,27 @@ class Categorizer(Preprocessor):
of data leakage when using this preprocessor.
Args:
columns: The columns whose data type to change. Can be
either a list of columns, in which case the categories
will be inferred automatically from the data, or
a dict of `column:pd.CategoricalDtype or None` -
if specified, the dtype will be applied, and if not,
it will be automatically inferred.
columns: The columns to change to `pd.CategoricalDtype`.
dtypes: An optional dictionary that maps columns to `pd.CategoricalDtype`
objects. If you don't include a column in `dtypes`, then the categories
will be inferred.
"""
def __init__(
self, columns: Union[List[str], Dict[str, Optional[pd.CategoricalDtype]]]
self,
columns: List[str],
dtypes: Optional[Dict[str, pd.CategoricalDtype]] = None,
):
if not dtypes:
dtypes = {}
self.columns = columns
self.dtypes = dtypes
def _fit(self, dataset: Dataset) -> Preprocessor:
columns_to_get = (
self.columns
if isinstance(self.columns, list)
else [
column for column, cat_type in self.columns.items() if cat_type is None
]
)
columns_to_get = [
column for column in self.columns if column not in self.dtypes
]
if columns_to_get:
unique_indices = _get_unique_value_indices(
dataset, columns_to_get, drop_na_values=True, key_format="{0}"
@ -351,8 +351,7 @@ class Categorizer(Preprocessor):
}
else:
unique_indices = {}
if isinstance(self.columns, dict):
unique_indices = {**self.columns, **unique_indices}
unique_indices = {**self.dtypes, **unique_indices}
self.stats_: Dict[str, pd.CategoricalDtype] = unique_indices
return self
@ -362,7 +361,9 @@ class Categorizer(Preprocessor):
def __repr__(self):
stats = getattr(self, "stats_", None)
return f"<Categorizer columns={self.columns} stats={stats}>"
return (
f"<Categorizer columns={self.columns} dtypes={self.dtypes} stats={stats}>"
)
def _get_unique_value_indices(

View file

@ -828,23 +828,22 @@ def test_categorizer(predefined_dtypes):
in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c})
ds = ray.data.from_pandas(in_df)
columns = ["B", "C"]
if predefined_dtypes:
expected_dtypes = {
"B": pd.CategoricalDtype(["cold", "hot", "warm"], ordered=True),
"C": pd.CategoricalDtype([1, 5, 10]),
}
columns = {
"B": pd.CategoricalDtype(["cold", "hot", "warm"], ordered=True),
"C": None,
}
dtypes = {"B": pd.CategoricalDtype(["cold", "hot", "warm"], ordered=True)}
else:
expected_dtypes = {
"B": pd.CategoricalDtype(["cold", "hot", "warm"]),
"C": pd.CategoricalDtype([1, 5, 10]),
}
columns = ["B", "C"]
dtypes = None
encoder = Categorizer(columns)
encoder = Categorizer(columns, dtypes)
# Transform with unfitted preprocessor.
with pytest.raises(PreprocessorNotFittedException):