diff --git a/python/ray/data/preprocessors/encoder.py b/python/ray/data/preprocessors/encoder.py index 05955cdca..20bbb7e11 100644 --- a/python/ray/data/preprocessors/encoder.py +++ b/python/ray/data/preprocessors/encoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import List, Dict, Optional, Union +from typing import List, Dict, Optional from collections import Counter, OrderedDict import pandas as pd @@ -320,27 +320,27 @@ class Categorizer(Preprocessor): of data leakage when using this preprocessor. Args: - columns: The columns whose data type to change. Can be - either a list of columns, in which case the categories - will be inferred automatically from the data, or - a dict of `column:pd.CategoricalDtype or None` - - if specified, the dtype will be applied, and if not, - it will be automatically inferred. + columns: The columns to change to `pd.CategoricalDtype`. + dtypes: An optional dictionary that maps columns to `pd.CategoricalDtype` + objects. If you don't include a column in `dtypes`, then the categories + will be inferred. 
""" def __init__( - self, columns: Union[List[str], Dict[str, Optional[pd.CategoricalDtype]]] + self, + columns: List[str], + dtypes: Optional[Dict[str, pd.CategoricalDtype]] = None, ): + if not dtypes: + dtypes = {} + self.columns = columns + self.dtypes = dtypes def _fit(self, dataset: Dataset) -> Preprocessor: - columns_to_get = ( - self.columns - if isinstance(self.columns, list) - else [ - column for column, cat_type in self.columns.items() if cat_type is None - ] - ) + columns_to_get = [ + column for column in self.columns if column not in self.dtypes + ] if columns_to_get: unique_indices = _get_unique_value_indices( dataset, columns_to_get, drop_na_values=True, key_format="{0}" @@ -351,8 +351,7 @@ class Categorizer(Preprocessor): } else: unique_indices = {} - if isinstance(self.columns, dict): - unique_indices = {**self.columns, **unique_indices} + unique_indices = {**self.dtypes, **unique_indices} self.stats_: Dict[str, pd.CategoricalDtype] = unique_indices return self @@ -362,7 +361,9 @@ class Categorizer(Preprocessor): def __repr__(self): stats = getattr(self, "stats_", None) - return f"<Categorizer columns={self.columns} stats={stats}>" + return ( + f"<Categorizer columns={self.columns} dtypes={self.dtypes} stats={stats}>" + ) def _get_unique_value_indices( diff --git a/python/ray/data/tests/test_preprocessors.py b/python/ray/data/tests/test_preprocessors.py index f7e38beba..9a356429e 100644 --- a/python/ray/data/tests/test_preprocessors.py +++ b/python/ray/data/tests/test_preprocessors.py @@ -828,23 +828,22 @@ def test_categorizer(predefined_dtypes): in_df = pd.DataFrame.from_dict({"A": col_a, "B": col_b, "C": col_c}) ds = ray.data.from_pandas(in_df) + columns = ["B", "C"] if predefined_dtypes: expected_dtypes = { "B": pd.CategoricalDtype(["cold", "hot", "warm"], ordered=True), "C": pd.CategoricalDtype([1, 5, 10]), } - columns = { - "B": pd.CategoricalDtype(["cold", "hot", "warm"], ordered=True), - "C": None, - } + dtypes = {"B": pd.CategoricalDtype(["cold", "hot", "warm"], ordered=True)} else: expected_dtypes = { "B": pd.CategoricalDtype(["cold", "hot", "warm"]), "C": 
pd.CategoricalDtype([1, 5, 10]), } columns = ["B", "C"] + dtypes = None - encoder = Categorizer(columns) + encoder = Categorizer(columns, dtypes) # Transform with unfitted preprocessor. with pytest.raises(PreprocessorNotFittedException):