Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[Datasets] Support ignoring NaNs in aggregations. (#20787)
Adds support for ignoring NaNs in aggregations. NaNs are now ignored by default, and the user can pass e.g. `ds.mean("A", ignore_nulls=False)` to have a NaN propagate to the output instead. Specifically, the null-handling semantics are:

1. Mix of values and nulls, `ignore_nulls=True`: ignore the nulls, return the aggregation of the values.
2. Mix of values and nulls, `ignore_nulls=False`: return `None`.
3. All nulls: return `None`.
4. Empty dataset: return `None`.

The all-null and empty-dataset handling matches the semantics of NumPy and Pandas.
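A minimal sketch of the new behavior from the user's side (illustrative only; a local `ray.init()` is assumed, and the `from_items`/`range_arrow` constructors mirror the ones used in the tests in this commit):

```python
import ray

ray.init()

ds = ray.data.from_items([{"A": 1.0}, {"A": 2.0}, {"A": None}])

# 1. Mix of values and nulls, ignore_nulls=True (the default): nulls are skipped.
assert ds.sum("A") == 3.0
assert ds.mean("A") == 1.5

# 2. Mix of values and nulls, ignore_nulls=False: the null propagates.
assert ds.sum("A", ignore_nulls=False) is None

# 3. All nulls: None.
assert ray.data.from_items([{"A": None}, {"A": None}]).sum("A") is None

# 4. Empty dataset: None (previously sum() returned 0 and min/max/mean raised).
assert ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).sum("value") is None

# Grouped aggregations accept the same flag.
grouped = ray.data.from_items(
    [{"A": 0, "B": 1.0}, {"A": 0, "B": None}, {"A": 1, "B": 2.0}]
).groupby("A")
grouped.sum("B")                      # sum(B) = 1.0 for A=0, 2.0 for A=1
grouped.sum("B", ignore_nulls=False)  # sum(B) = None for A=0, 2.0 for A=1
```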
This commit is contained in: parent f0d8b6d701, commit f264cf800a.
5 changed files with 892 additions and 111 deletions
@ -3,6 +3,12 @@ from typing import Callable, Optional, List, TYPE_CHECKING
|
|||
|
||||
from ray.util.annotations import PublicAPI
|
||||
from ray.data.block import T, U, KeyType, AggType, KeyFn, _validate_key_fn
|
||||
from ray.data.impl.null_aggregate import (
|
||||
_null_wrap_init,
|
||||
_null_wrap_accumulate,
|
||||
_null_wrap_merge,
|
||||
_null_wrap_finalize,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.data import Dataset
|
||||
|
@ -75,13 +81,15 @@ class Count(AggregateFn):
|
|||
class Sum(_AggregateOnKeyBase):
|
||||
"""Defines sum aggregation."""
|
||||
|
||||
def __init__(self, on: Optional[KeyFn] = None):
|
||||
def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
|
||||
self._set_key_fn(on)
|
||||
on_fn = _to_on_fn(on)
|
||||
|
||||
super().__init__(
|
||||
init=lambda k: 0,
|
||||
accumulate=lambda a, r: a + on_fn(r),
|
||||
merge=lambda a1, a2: a1 + a2,
|
||||
init=_null_wrap_init(lambda k: 0),
|
||||
accumulate=_null_wrap_accumulate(ignore_nulls, on_fn, lambda a, r: a + r),
|
||||
merge=_null_wrap_merge(ignore_nulls, lambda a1, a2: a1 + a2),
|
||||
finalize=_null_wrap_finalize(lambda a: a),
|
||||
name=(f"sum({str(on)})"),
|
||||
)
|
||||
|
||||
|
@ -90,13 +98,15 @@ class Sum(_AggregateOnKeyBase):
|
|||
class Min(_AggregateOnKeyBase):
|
||||
"""Defines min aggregation."""
|
||||
|
||||
def __init__(self, on: Optional[KeyFn] = None):
|
||||
def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
|
||||
self._set_key_fn(on)
|
||||
on_fn = _to_on_fn(on)
|
||||
|
||||
super().__init__(
|
||||
init=lambda k: None,
|
||||
accumulate=(lambda a, r: (on_fn(r) if a is None else min(a, on_fn(r)))),
|
||||
merge=lambda a1, a2: min(a1, a2),
|
||||
init=_null_wrap_init(lambda k: float("inf")),
|
||||
accumulate=_null_wrap_accumulate(ignore_nulls, on_fn, min),
|
||||
merge=_null_wrap_merge(ignore_nulls, min),
|
||||
finalize=_null_wrap_finalize(lambda a: a),
|
||||
name=(f"min({str(on)})"),
|
||||
)
|
||||
|
||||
|
@ -105,13 +115,15 @@ class Min(_AggregateOnKeyBase):
|
|||
class Max(_AggregateOnKeyBase):
|
||||
"""Defines max aggregation."""
|
||||
|
||||
def __init__(self, on: Optional[KeyFn] = None):
|
||||
def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
|
||||
self._set_key_fn(on)
|
||||
on_fn = _to_on_fn(on)
|
||||
|
||||
super().__init__(
|
||||
init=lambda k: None,
|
||||
accumulate=(lambda a, r: (on_fn(r) if a is None else max(a, on_fn(r)))),
|
||||
merge=lambda a1, a2: max(a1, a2),
|
||||
init=_null_wrap_init(lambda k: float("-inf")),
|
||||
accumulate=_null_wrap_accumulate(ignore_nulls, on_fn, max),
|
||||
merge=_null_wrap_merge(ignore_nulls, max),
|
||||
finalize=_null_wrap_finalize(lambda a: a),
|
||||
name=(f"max({str(on)})"),
|
||||
)
|
||||
|
||||
|
@ -120,14 +132,19 @@ class Max(_AggregateOnKeyBase):
|
|||
class Mean(_AggregateOnKeyBase):
|
||||
"""Defines mean aggregation."""
|
||||
|
||||
def __init__(self, on: Optional[KeyFn] = None):
|
||||
def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
|
||||
self._set_key_fn(on)
|
||||
on_fn = _to_on_fn(on)
|
||||
|
||||
super().__init__(
|
||||
init=lambda k: [0, 0],
|
||||
accumulate=lambda a, r: [a[0] + on_fn(r), a[1] + 1],
|
||||
merge=lambda a1, a2: [a1[0] + a2[0], a1[1] + a2[1]],
|
||||
finalize=lambda a: a[0] / a[1],
|
||||
init=_null_wrap_init(lambda k: [0, 0]),
|
||||
accumulate=_null_wrap_accumulate(
|
||||
ignore_nulls, on_fn, lambda a, r: [a[0] + r, a[1] + 1]
|
||||
),
|
||||
merge=_null_wrap_merge(
|
||||
ignore_nulls, lambda a1, a2: [a1[0] + a2[0], a1[1] + a2[1]]
|
||||
),
|
||||
finalize=_null_wrap_finalize(lambda a: a[0] / a[1]),
|
||||
name=(f"mean({str(on)})"),
|
||||
)
|
||||
|
||||
|
@ -145,7 +162,12 @@ class Std(_AggregateOnKeyBase):
|
|||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
|
||||
"""
|
||||
|
||||
def __init__(self, on: Optional[KeyFn] = None, ddof: int = 1):
|
||||
def __init__(
|
||||
self,
|
||||
on: Optional[KeyFn] = None,
|
||||
ddof: int = 1,
|
||||
ignore_nulls: bool = True,
|
||||
):
|
||||
self._set_key_fn(on)
|
||||
on_fn = _to_on_fn(on)
|
||||
|
||||
|
@ -153,14 +175,11 @@ class Std(_AggregateOnKeyBase):
|
|||
# Accumulates the current count, the current mean, and the sum of
|
||||
# squared differences from the current mean (M2).
|
||||
M2, mean, count = a
|
||||
# Select the data on which we want to calculate the standard
|
||||
# deviation.
|
||||
val = on_fn(r)
|
||||
|
||||
count += 1
|
||||
delta = val - mean
|
||||
delta = r - mean
|
||||
mean += delta / count
|
||||
delta2 = val - mean
|
||||
delta2 = r - mean
|
||||
M2 += delta * delta2
|
||||
return [M2, mean, count]
|
||||
|
||||
|
@ -190,10 +209,10 @@ class Std(_AggregateOnKeyBase):
|
|||
return math.sqrt(M2 / (count - ddof))
|
||||
|
||||
super().__init__(
|
||||
init=lambda k: [0, 0, 0],
|
||||
accumulate=accumulate,
|
||||
merge=merge,
|
||||
finalize=finalize,
|
||||
init=_null_wrap_init(lambda k: [0, 0, 0]),
|
||||
accumulate=_null_wrap_accumulate(ignore_nulls, on_fn, accumulate),
|
||||
merge=_null_wrap_merge(ignore_nulls, merge),
|
||||
finalize=_null_wrap_finalize(finalize),
|
||||
name=(f"std({str(on)})"),
|
||||
)
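As an aside, the `Std` accumulator above is Welford's online algorithm. Below is a minimal standalone sketch of the per-row update and the ddof-adjusted finalization (plain Python with hypothetical `welford_*` names; the block-merge step used for parallel variance is omitted):

```python
import math


def welford_accumulate(acc, val):
    # acc = [M2, mean, count]: running sum of squared deviations, mean, and count.
    M2, mean, count = acc
    count += 1
    delta = val - mean
    mean += delta / count
    delta2 = val - mean
    M2 += delta * delta2
    return [M2, mean, count]


def welford_finalize(acc, ddof=1):
    M2, mean, count = acc
    if count - ddof <= 0:
        return None  # not enough data points for the requested ddof
    return math.sqrt(M2 / (count - ddof))


acc = [0.0, 0.0, 0]
for v in [1.0, 2.0, 4.0, 7.0]:
    acc = welford_accumulate(acc, v)
print(welford_finalize(acc))  # ~2.6458, matches statistics.stdev([1, 2, 4, 7])
```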
|
||||
|
||||
|
|
|
@ -562,7 +562,7 @@ class Dataset(Generic[T]):
|
|||
return Dataset(new_blocks, self._epoch, stats.build_multistage(stage_info))
|
||||
|
||||
def split(
|
||||
self, n: int, *, equal: bool = False, locality_hints: List[Any] = None
|
||||
self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None
|
||||
) -> List["Dataset[T]"]:
|
||||
"""Split the dataset into ``n`` disjoint pieces.
|
||||
|
||||
|
@ -975,7 +975,7 @@ class Dataset(Generic[T]):
|
|||
LazyBlockList(calls, metadata, block_partitions), max_epoch, dataset_stats
|
||||
)
|
||||
|
||||
def groupby(self, key: KeyFn) -> "GroupedDataset[T]":
|
||||
def groupby(self, key: Optional[KeyFn]) -> "GroupedDataset[T]":
|
||||
"""Group the dataset by the key function or column name.
|
||||
|
||||
This is a lazy operation.
|
||||
|
@ -1034,7 +1034,9 @@ class Dataset(Generic[T]):
|
|||
ret = self.groupby(None).aggregate(*aggs).take(1)
|
||||
return ret[0] if len(ret) > 0 else None
|
||||
|
||||
def sum(self, on: Union[KeyFn, List[KeyFn]] = None) -> U:
|
||||
def sum(
|
||||
self, on: Optional[Union[KeyFn, List[KeyFn]]] = None, ignore_nulls: bool = True
|
||||
) -> U:
|
||||
"""Compute sum over entire dataset.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -1057,6 +1059,11 @@ class Dataset(Generic[T]):
|
|||
- For an Arrow dataset: it can be a column name or a list
|
||||
thereof, and the default is to return an ``ArrowRow``
|
||||
containing the column-wise sum of all columns.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the sum; if ``False``,
|
||||
if a null value is encountered, the output will be None.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The sum result.
|
||||
|
@ -1079,15 +1086,15 @@ class Dataset(Generic[T]):
|
|||
- ``on=["col_1", ..., "col_n"]``: an n-column ``ArrowRow``
|
||||
containing the column-wise sum of the provided columns.
|
||||
|
||||
If the dataset is empty, then the output is 0.
|
||||
If the dataset is empty, all values are null, or any value is null
|
||||
AND ``ignore_nulls`` is ``False``, then the output will be None.
|
||||
"""
|
||||
ret = self._aggregate_on(Sum, on)
|
||||
if ret is None:
|
||||
return 0
|
||||
else:
|
||||
return self._aggregate_result(ret)
|
||||
ret = self._aggregate_on(Sum, on, ignore_nulls)
|
||||
return self._aggregate_result(ret)
|
||||
|
||||
def min(self, on: Union[KeyFn, List[KeyFn]] = None) -> U:
|
||||
def min(
|
||||
self, on: Optional[Union[KeyFn, List[KeyFn]]] = None, ignore_nulls: bool = True
|
||||
) -> U:
|
||||
"""Compute minimum over entire dataset.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -1110,6 +1117,11 @@ class Dataset(Generic[T]):
|
|||
- For an Arrow dataset: it can be a column name or a list
|
||||
thereof, and the default is to return an ``ArrowRow``
|
||||
containing the column-wise min of all columns.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the min; if ``False``,
|
||||
if a null value is encountered, the output will be None.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The min result.
|
||||
|
@ -1132,15 +1144,15 @@ class Dataset(Generic[T]):
|
|||
- ``on=["col_1", ..., "col_n"]``: an n-column ``ArrowRow``
|
||||
containing the column-wise min of the provided columns.
|
||||
|
||||
If the dataset is empty, then a ``ValueError`` is raised.
|
||||
If the dataset is empty, all values are null, or any value is null
|
||||
AND ``ignore_nulls`` is ``False``, then the output will be None.
|
||||
"""
|
||||
ret = self._aggregate_on(Min, on)
|
||||
if ret is None:
|
||||
raise ValueError("Cannot compute min on an empty dataset")
|
||||
else:
|
||||
return self._aggregate_result(ret)
|
||||
ret = self._aggregate_on(Min, on, ignore_nulls)
|
||||
return self._aggregate_result(ret)
|
||||
|
||||
def max(self, on: Union[KeyFn, List[KeyFn]] = None) -> U:
|
||||
def max(
|
||||
self, on: Optional[Union[KeyFn, List[KeyFn]]] = None, ignore_nulls: bool = True
|
||||
) -> U:
|
||||
"""Compute maximum over entire dataset.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -1163,6 +1175,11 @@ class Dataset(Generic[T]):
|
|||
- For an Arrow dataset: it can be a column name or a list
|
||||
thereof, and the default is to return an ``ArrowRow``
|
||||
containing the column-wise max of all columns.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the max; if ``False``,
|
||||
if a null value is encountered, the output will be None.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The max result.
|
||||
|
@ -1185,15 +1202,15 @@ class Dataset(Generic[T]):
|
|||
- ``on=["col_1", ..., "col_n"]``: an n-column ``ArrowRow``
|
||||
containing the column-wise max of the provided columns.
|
||||
|
||||
If the dataset is empty, then a ``ValueError`` is raised.
|
||||
If the dataset is empty, all values are null, or any value is null
|
||||
AND ``ignore_nulls`` is ``False``, then the output will be None.
|
||||
"""
|
||||
ret = self._aggregate_on(Max, on)
|
||||
if ret is None:
|
||||
raise ValueError("Cannot compute max on an empty dataset")
|
||||
else:
|
||||
return self._aggregate_result(ret)
|
||||
ret = self._aggregate_on(Max, on, ignore_nulls)
|
||||
return self._aggregate_result(ret)
|
||||
|
||||
def mean(self, on: Union[KeyFn, List[KeyFn]] = None) -> U:
|
||||
def mean(
|
||||
self, on: Optional[Union[KeyFn, List[KeyFn]]] = None, ignore_nulls: bool = True
|
||||
) -> U:
|
||||
"""Compute mean over entire dataset.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -1216,6 +1233,11 @@ class Dataset(Generic[T]):
|
|||
- For an Arrow dataset: it can be a column name or a list
|
||||
thereof, and the default is to return an ``ArrowRow``
|
||||
containing the column-wise mean of all columns.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the mean; if ``False``,
|
||||
if a null value is encountered, the output will be None.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The mean result.
|
||||
|
@ -1238,15 +1260,18 @@ class Dataset(Generic[T]):
|
|||
- ``on=["col_1", ..., "col_n"]``: an n-column ``ArrowRow``
|
||||
containing the column-wise mean of the provided columns.
|
||||
|
||||
If the dataset is empty, then a ``ValueError`` is raised.
|
||||
If the dataset is empty, all values are null, or any value is null
|
||||
AND ``ignore_nulls`` is ``False``, then the output will be None.
|
||||
"""
|
||||
ret = self._aggregate_on(Mean, on)
|
||||
if ret is None:
|
||||
raise ValueError("Cannot compute mean on an empty dataset")
|
||||
else:
|
||||
return self._aggregate_result(ret)
|
||||
ret = self._aggregate_on(Mean, on, ignore_nulls)
|
||||
return self._aggregate_result(ret)
|
||||
|
||||
def std(self, on: Union[KeyFn, List[KeyFn]] = None, ddof: int = 1) -> U:
|
||||
def std(
|
||||
self,
|
||||
on: Optional[Union[KeyFn, List[KeyFn]]] = None,
|
||||
ddof: int = 1,
|
||||
ignore_nulls: bool = True,
|
||||
) -> U:
|
||||
"""Compute standard deviation over entire dataset.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -1279,6 +1304,11 @@ class Dataset(Generic[T]):
|
|||
containing the column-wise std of all columns.
|
||||
ddof: Delta Degrees of Freedom. The divisor used in calculations
|
||||
is ``N - ddof``, where ``N`` represents the number of elements.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the std; if ``False``,
|
||||
if a null value is encountered, the output will be None.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The standard deviation result.
|
||||
|
@ -1301,15 +1331,15 @@ class Dataset(Generic[T]):
|
|||
- ``on=["col_1", ..., "col_n"]``: an n-column ``ArrowRow``
|
||||
containing the column-wise std of the provided columns.
|
||||
|
||||
If the dataset is empty, then a ``ValueError`` is raised.
|
||||
If the dataset is empty, all values are null, or any value is null
|
||||
AND ``ignore_nulls`` is ``False``, then the output will be None.
|
||||
"""
|
||||
ret = self._aggregate_on(Std, on, ddof=ddof)
|
||||
if ret is None:
|
||||
raise ValueError("Cannot compute std on an empty dataset")
|
||||
else:
|
||||
return self._aggregate_result(ret)
|
||||
ret = self._aggregate_on(Std, on, ignore_nulls, ddof=ddof)
|
||||
return self._aggregate_result(ret)
|
||||
|
||||
def sort(self, key: KeyFn = None, descending: bool = False) -> "Dataset[T]":
|
||||
def sort(
|
||||
self, key: Optional[KeyFn] = None, descending: bool = False
|
||||
) -> "Dataset[T]":
|
||||
"""Sort the dataset by the specified key column or key function.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -1864,7 +1894,7 @@ class Dataset(Generic[T]):
|
|||
self,
|
||||
*,
|
||||
prefetch_blocks: int = 0,
|
||||
batch_size: int = None,
|
||||
batch_size: Optional[int] = None,
|
||||
batch_format: str = "native",
|
||||
drop_last: bool = False,
|
||||
) -> Iterator[BatchType]:
|
||||
|
@ -1953,12 +1983,12 @@ class Dataset(Generic[T]):
|
|||
self,
|
||||
*,
|
||||
label_column: Optional[str] = None,
|
||||
feature_columns: Union[
|
||||
None, List[str], List[List[str]], Dict[str, List[str]]
|
||||
feature_columns: Optional[
|
||||
Union[List[str], List[List[str]], Dict[str, List[str]]]
|
||||
] = None,
|
||||
label_column_dtype: Optional["torch.dtype"] = None,
|
||||
feature_column_dtypes: Union[
|
||||
None, "torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]
|
||||
feature_column_dtypes: Optional[
|
||||
Union["torch.dtype", List["torch.dtype"], Dict[str, "torch.dtype"]]
|
||||
] = None,
|
||||
batch_size: int = 1,
|
||||
prefetch_blocks: int = 0,
|
||||
|
@ -2403,7 +2433,7 @@ Dict[str, List[str]]]): The names of the columns
|
|||
block_to_arrow = cached_remote_fn(_block_to_arrow)
|
||||
return [block_to_arrow.remote(block) for block in blocks]
|
||||
|
||||
def repeat(self, times: int = None) -> "DatasetPipeline[T]":
|
||||
def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":
|
||||
"""Convert this into a DatasetPipeline by looping over this dataset.
|
||||
|
||||
Transformations prior to the call to ``repeat()`` are evaluated once.
|
||||
|
@ -2665,7 +2695,7 @@ Dict[str, List[str]]]): The names of the columns
|
|||
return "simple"
|
||||
|
||||
def _aggregate_on(
|
||||
self, agg_cls: type, on: Union[KeyFn, List[KeyFn]], *args, **kwargs
|
||||
self, agg_cls: type, on: Optional[Union[KeyFn, List[KeyFn]]], *args, **kwargs
|
||||
):
|
||||
"""Helper for aggregating on a particular subset of the dataset.
|
||||
|
||||
|
@ -2680,9 +2710,10 @@ Dict[str, List[str]]]): The names of the columns
|
|||
def _build_multicolumn_aggs(
|
||||
self,
|
||||
agg_cls: type,
|
||||
on: Union[KeyFn, List[KeyFn]],
|
||||
skip_cols: Optional[List[str]] = None,
|
||||
on: Optional[Union[KeyFn, List[KeyFn]]],
|
||||
ignore_nulls: bool,
|
||||
*args,
|
||||
skip_cols: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Build set of aggregations for applying a single aggregation to
|
||||
|
@ -2706,10 +2737,10 @@ Dict[str, List[str]]]): The names of the columns
|
|||
|
||||
if not isinstance(on, list):
|
||||
on = [on]
|
||||
return [agg_cls(on_, *args, **kwargs) for on_ in on]
|
||||
return [agg_cls(on_, *args, ignore_nulls=ignore_nulls, **kwargs) for on_ in on]
|
||||
|
||||
def _aggregate_result(self, result: Union[Tuple, TableRow]) -> U:
|
||||
if len(result) == 1:
|
||||
if result is not None and len(result) == 1:
|
||||
if isinstance(result, tuple):
|
||||
return result[0]
|
||||
else:
|
||||
|
|
|
@ -119,7 +119,12 @@ class GroupedDataset(Generic[T]):
|
|||
)
|
||||
|
||||
def _aggregate_on(
|
||||
self, agg_cls: type, on: Union[KeyFn, List[KeyFn]], *args, **kwargs
|
||||
self,
|
||||
agg_cls: type,
|
||||
on: Union[KeyFn, List[KeyFn]],
|
||||
ignore_nulls: bool,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
"""Helper for aggregating on a particular subset of the dataset.
|
||||
|
||||
|
@ -129,7 +134,7 @@ class GroupedDataset(Generic[T]):
|
|||
aggregation on the entire row for a simple Dataset.
|
||||
"""
|
||||
aggs = self._dataset._build_multicolumn_aggs(
|
||||
agg_cls, on, *args, skip_cols=self._key, **kwargs
|
||||
agg_cls, on, ignore_nulls, *args, skip_cols=self._key, **kwargs
|
||||
)
|
||||
return self.aggregate(*aggs)
|
||||
|
||||
|
@ -152,7 +157,9 @@ class GroupedDataset(Generic[T]):
|
|||
"""
|
||||
return self.aggregate(Count())
|
||||
|
||||
def sum(self, on: Union[KeyFn, List[KeyFn]] = None) -> Dataset[U]:
|
||||
def sum(
|
||||
self, on: Union[KeyFn, List[KeyFn]] = None, ignore_nulls: bool = True
|
||||
) -> Dataset[U]:
|
||||
"""Compute grouped sum aggregation.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -179,6 +186,11 @@ class GroupedDataset(Generic[T]):
|
|||
- For an Arrow dataset: it can be a column name or a list
|
||||
thereof, and the default is to do a column-wise sum of all
|
||||
columns.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the sum; if ``False``,
|
||||
if a null value is encountered, the output will be null.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The sum result.
|
||||
|
@ -203,9 +215,11 @@ class GroupedDataset(Generic[T]):
|
|||
|
||||
If groupby key is ``None`` then the key part of return is omitted.
|
||||
"""
|
||||
return self._aggregate_on(Sum, on)
|
||||
return self._aggregate_on(Sum, on, ignore_nulls)
|
||||
|
||||
def min(self, on: Union[KeyFn, List[KeyFn]] = None) -> Dataset[U]:
|
||||
def min(
|
||||
self, on: Union[KeyFn, List[KeyFn]] = None, ignore_nulls: bool = True
|
||||
) -> Dataset[U]:
|
||||
"""Compute grouped min aggregation.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -232,6 +246,11 @@ class GroupedDataset(Generic[T]):
|
|||
- For an Arrow dataset: it can be a column name or a list
|
||||
thereof, and the default is to do a column-wise min of all
|
||||
columns.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the min; if ``False``,
|
||||
if a null value is encountered, the output will be null.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The min result.
|
||||
|
@ -256,9 +275,11 @@ class GroupedDataset(Generic[T]):
|
|||
|
||||
If groupby key is ``None`` then the key part of return is omitted.
|
||||
"""
|
||||
return self._aggregate_on(Min, on)
|
||||
return self._aggregate_on(Min, on, ignore_nulls)
|
||||
|
||||
def max(self, on: Union[KeyFn, List[KeyFn]] = None) -> Dataset[U]:
|
||||
def max(
|
||||
self, on: Union[KeyFn, List[KeyFn]] = None, ignore_nulls: bool = True
|
||||
) -> Dataset[U]:
|
||||
"""Compute grouped max aggregation.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -285,6 +306,11 @@ class GroupedDataset(Generic[T]):
|
|||
- For an Arrow dataset: it can be a column name or a list
|
||||
thereof, and the default is to do a column-wise max of all
|
||||
columns.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the max; if ``False``,
|
||||
if a null value is encountered, the output will be null.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The max result.
|
||||
|
@ -309,9 +335,11 @@ class GroupedDataset(Generic[T]):
|
|||
|
||||
If groupby key is ``None`` then the key part of return is omitted.
|
||||
"""
|
||||
return self._aggregate_on(Max, on)
|
||||
return self._aggregate_on(Max, on, ignore_nulls)
|
||||
|
||||
def mean(self, on: Union[KeyFn, List[KeyFn]] = None) -> Dataset[U]:
|
||||
def mean(
|
||||
self, on: Union[KeyFn, List[KeyFn]] = None, ignore_nulls: bool = True
|
||||
) -> Dataset[U]:
|
||||
"""Compute grouped mean aggregation.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -338,6 +366,11 @@ class GroupedDataset(Generic[T]):
|
|||
- For an Arrow dataset: it can be a column name or a list
|
||||
thereof, and the default is to do a column-wise mean of all
|
||||
columns.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the mean; if ``False``,
|
||||
if a null value is encountered, the output will be null.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The mean result.
|
||||
|
@ -363,9 +396,14 @@ class GroupedDataset(Generic[T]):
|
|||
|
||||
If groupby key is ``None`` then the key part of return is omitted.
|
||||
"""
|
||||
return self._aggregate_on(Mean, on)
|
||||
return self._aggregate_on(Mean, on, ignore_nulls)
|
||||
|
||||
def std(self, on: Union[KeyFn, List[KeyFn]] = None, ddof: int = 1) -> Dataset[U]:
|
||||
def std(
|
||||
self,
|
||||
on: Union[KeyFn, List[KeyFn]] = None,
|
||||
ddof: int = 1,
|
||||
ignore_nulls: bool = True,
|
||||
) -> Dataset[U]:
|
||||
"""Compute grouped standard deviation aggregation.
|
||||
|
||||
This is a blocking operation.
|
||||
|
@ -402,6 +440,11 @@ class GroupedDataset(Generic[T]):
|
|||
columns.
|
||||
ddof: Delta Degrees of Freedom. The divisor used in calculations
|
||||
is ``N - ddof``, where ``N`` represents the number of elements.
|
||||
ignore_nulls: Whether to ignore null values. If ``True``, null
|
||||
values will be ignored when computing the std; if ``False``,
|
||||
if a null value is encountered, the output will be null.
|
||||
We consider np.nan, None, and pd.NaT to be null values.
|
||||
Default is ``True``.
|
||||
|
||||
Returns:
|
||||
The standard deviation result.
|
||||
|
@ -426,7 +469,7 @@ class GroupedDataset(Generic[T]):
|
|||
|
||||
If groupby key is ``None`` then the key part of return is omitted.
|
||||
"""
|
||||
return self._aggregate_on(Std, on, ddof=ddof)
|
||||
return self._aggregate_on(Std, on, ignore_nulls, ddof=ddof)
|
||||
|
||||
|
||||
def _partition_and_combine_block(
|
||||
|
|
235  python/ray/data/impl/null_aggregate.py  Normal file
|
@ -0,0 +1,235 @@
|
|||
from typing import Tuple, Callable, Any, Union
|
||||
from types import ModuleType
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.data.block import T, U, KeyType, AggType
|
||||
|
||||
|
||||
# This module contains aggregation helpers for handling nulls.
|
||||
# The null handling policy is:
|
||||
# 1. Mix of values and nulls - ignore_nulls=True: Ignore the nulls, return
|
||||
# aggregation of non-null values.
|
||||
# 2. Mix of values and nulls - ignore_nulls=False: Return None.
|
||||
# 3. All nulls: Return None.
|
||||
# 4. Empty dataset: Return None.
|
||||
#
|
||||
# This is accomplished by checking rows for null values and by propagating nulls
|
||||
# if found AND if we're not ignoring them. If not ignoring nulls, in order to delineate
|
||||
# between found null rows and an empty block accumulation when merging (the former of
|
||||
# which we want to propagate; the latter of which we do not), we attach a boolean flag
|
||||
# indicating whether or not an accumulation contains valid data to intermediate block
|
||||
# accumulations via _wrap_acc() and _unwrap_acc(). This allows us to properly merge
|
||||
# intermediate block accumulations under a streaming constraint.
|
||||
|
||||
|
||||
def _wrap_acc(a: AggType, has_data: bool) -> AggType:
|
||||
"""
|
||||
Wrap accumulation with a numeric boolean flag indicating whether or not
|
||||
this accumulation contains real data; if it doesn't, we consider it to be
|
||||
empty.
|
||||
|
||||
Args:
|
||||
a: The accumulation value.
|
||||
has_data: Whether the accumulation contains real data.
|
||||
|
||||
Returns:
|
||||
An AggType list with the last element being a numeric boolean flag indicating
|
||||
whether or not this accumulation contains real data. If the input a has length
|
||||
n, the returned AggType has length n + 1.
|
||||
"""
|
||||
if not isinstance(a, list):
|
||||
a = [a]
|
||||
return a + [1 if has_data else 0]
|
||||
|
||||
|
||||
def _unwrap_acc(a: AggType) -> Tuple[AggType, bool]:
|
||||
"""
|
||||
Unwrap the accumulation, which we assume has been wrapped (via _wrap_acc) with a
|
||||
numeric boolean flag indicating whether or not this accumulation contains real data.
|
||||
|
||||
Args:
|
||||
a: The wrapped accumulation value that we wish to unwrap.
|
||||
|
||||
Returns:
|
||||
A tuple containing the unwrapped accumulation value and a boolean indicating
|
||||
whether the accumulation contains real data.
|
||||
"""
|
||||
has_data = a[-1] == 1
|
||||
a = a[:-1]
|
||||
if len(a) == 1:
|
||||
a = a[0]
|
||||
return a, has_data
|
||||
|
||||
|
||||
def _null_wrap_init(init: Callable[[KeyType], AggType]) -> Callable[[KeyType], AggType]:
|
||||
"""
|
||||
Wraps an accumulation initializer with null handling.
|
||||
|
||||
The returned initializer function adds on a has_data field that the accumulator
|
||||
uses to track whether an aggregation is empty.
|
||||
|
||||
Args:
|
||||
init: The core init function to wrap.
|
||||
|
||||
Returns:
|
||||
A new accumulation initializer function that can handle nulls.
|
||||
"""
|
||||
|
||||
def _init(k: KeyType) -> AggType:
|
||||
a = init(k)
|
||||
# Initializing accumulation, so indicate that the accumulation doesn't represent
|
||||
# real data yet.
|
||||
return _wrap_acc(a, has_data=False)
|
||||
|
||||
return _init
|
||||
|
||||
|
||||
def _null_wrap_accumulate(
|
||||
ignore_nulls: bool,
|
||||
on_fn: Callable[[T], T],
|
||||
accum: Callable[[AggType, T], AggType],
|
||||
) -> Callable[[AggType, T], AggType]:
|
||||
"""
|
||||
Wrap accumulator function with null handling.
|
||||
|
||||
The returned accumulate function expects a to be either None or of the form:
|
||||
a = [acc_data_1, ..., acc_data_n, has_data].
|
||||
|
||||
This performs an accumulation subject to the following null rules:
|
||||
1. If r is null and ignore_nulls=False, return None.
|
||||
2. If r is null and ignore_nulls=True, return a.
|
||||
3. If r is non-null and a is None, return None.
|
||||
4. If r is non-null and a is non-None, return accum(a[:-1], r).
|
||||
|
||||
Args:
|
||||
ignore_nulls: Whether nulls should be ignored or cause a None result.
|
||||
on_fn: Function selecting a subset of the row to apply the aggregation.
|
||||
accum: The core accumulator function to wrap.
|
||||
|
||||
Returns:
|
||||
A new accumulator function that handles nulls.
|
||||
"""
|
||||
|
||||
def _accum(a: AggType, r: T) -> AggType:
|
||||
r = on_fn(r)
|
||||
if _is_null(r):
|
||||
if ignore_nulls:
|
||||
# Ignoring nulls, return the current accumulation, ignoring r.
|
||||
return a
|
||||
else:
|
||||
# Not ignoring nulls, so propagate the null.
|
||||
return None
|
||||
else:
|
||||
if a is None:
|
||||
# Accumulation is None so (1) a previous row must have been null, and
|
||||
# (2) we must be propagating nulls, so continue to propagate this null.
|
||||
return None
|
||||
else:
|
||||
# Row is non-null and accumulation is non-null, so we now apply the core
|
||||
# accumulation.
|
||||
a, _ = _unwrap_acc(a)
|
||||
a = accum(a, r)
|
||||
return _wrap_acc(a, has_data=True)
|
||||
|
||||
return _accum
|
||||
|
||||
|
||||
def _null_wrap_merge(
|
||||
ignore_nulls: bool,
|
||||
merge: Callable[[AggType, AggType], AggType],
|
||||
) -> Callable[[AggType, AggType], AggType]:
|
||||
"""
|
||||
Wrap merge function with null handling.
|
||||
|
||||
The returned merge function expects a1 and a2 to be either None or of the form:
|
||||
a = [acc_data_1, ..., acc_data_n, has_data].
|
||||
|
||||
This merges two accumulations subject to the following null rules:
|
||||
1. If a1 is empty and a2 is empty, return empty accumulation.
|
||||
2. If a1 (a2) is empty and a2 (a1) is None, return None.
|
||||
3. If a1 (a2) is empty and a2 (a1) is non-None, return a2 (a1).
|
||||
4. If a1 (a2) is None, return a2 (a1) if ignoring nulls, None otherwise.
|
||||
5. If a1 and a2 are both non-null, return merge(a1, a2).
|
||||
|
||||
Args:
|
||||
ignore_nulls: Whether nulls should be ignored or cause a None result.
|
||||
merge: The core merge function to wrap.
|
||||
|
||||
Returns:
|
||||
A new merge function that handles nulls.
|
||||
"""
|
||||
|
||||
def _merge(a1: AggType, a2: AggType) -> AggType:
|
||||
if a1 is None:
|
||||
# If we're ignoring nulls, propagate a2; otherwise, propagate None.
|
||||
return a2 if ignore_nulls else None
|
||||
unwrapped_a1, a1_has_data = _unwrap_acc(a1)
|
||||
if not a1_has_data:
|
||||
# If a1 is empty, propagate a2.
|
||||
# No matter whether a2 is a real value, empty, or None,
|
||||
# propagating each of these is correct if a1 is empty.
|
||||
return a2
|
||||
if a2 is None:
|
||||
# If we're ignoring nulls, propagate a1; otherwise, propagate None.
|
||||
return a1 if ignore_nulls else None
|
||||
unwrapped_a2, a2_has_data = _unwrap_acc(a2)
|
||||
if not a2_has_data:
|
||||
# If a2 is empty, propagate a1.
|
||||
return a1
|
||||
a = merge(unwrapped_a1, unwrapped_a2)
|
||||
return _wrap_acc(a, has_data=True)
|
||||
|
||||
return _merge
|
||||
|
||||
|
||||
def _null_wrap_finalize(
|
||||
finalize: Callable[[AggType], AggType]
|
||||
) -> Callable[[AggType], U]:
|
||||
"""
|
||||
Wrap finalizer with null handling.
|
||||
|
||||
If the accumulation is empty or None, the returned finalizer returns None.
|
||||
|
||||
Args:
|
||||
finalize: The core finalizing function to wrap.
|
||||
|
||||
Returns:
|
||||
A new finalizing function that handles nulls.
|
||||
"""
|
||||
|
||||
def _finalize(a: AggType) -> U:
|
||||
if a is None:
|
||||
return None
|
||||
a, has_data = _unwrap_acc(a)
|
||||
if not has_data:
|
||||
return None
|
||||
return finalize(a)
|
||||
|
||||
return _finalize
|
||||
|
||||
|
||||
LazyModule = Union[None, bool, ModuleType]
|
||||
_pandas: LazyModule = None
|
||||
|
||||
|
||||
def _lazy_import_pandas() -> LazyModule:
|
||||
global _pandas
|
||||
if _pandas is None:
|
||||
try:
|
||||
import pandas as _pandas
|
||||
except ModuleNotFoundError:
|
||||
# If module is not found, set _pandas to False so we won't
|
||||
# keep trying to import it on every _lazy_import_pandas() call.
|
||||
_pandas = False
|
||||
return _pandas
|
||||
|
||||
|
||||
def _is_null(r: Any):
|
||||
pd = _lazy_import_pandas()
|
||||
if pd:
|
||||
return pd.isnull(r)
|
||||
try:
|
||||
return np.isnan(r)
|
||||
except TypeError:
|
||||
return r is None
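To make the wrapping concrete, here is a minimal sketch that composes these helpers by hand into a null-aware sum, outside of any Dataset. It assumes the module path shown in the imports above (`ray.data.impl.null_aggregate`); the `_null_wrap_*` functions are private helpers, so this is only to illustrate the semantics, not a supported API, and the row values are hypothetical:

```python
from ray.data.impl.null_aggregate import (
    _null_wrap_init,
    _null_wrap_accumulate,
    _null_wrap_merge,
    _null_wrap_finalize,
)


def null_aware_sum(rows, ignore_nulls=True):
    init = _null_wrap_init(lambda k: 0)
    accumulate = _null_wrap_accumulate(ignore_nulls, lambda r: r, lambda a, r: a + r)
    merge = _null_wrap_merge(ignore_nulls, lambda a1, a2: a1 + a2)
    finalize = _null_wrap_finalize(lambda a: a)

    # Accumulations are wrapped as [value, has_data]; e.g. the empty init is [0, 0].
    a = init(None)
    for r in rows:
        a = accumulate(a, r)
    # Merging with a still-empty block accumulation leaves the result unchanged,
    # while a None accumulation (a found null, ignore_nulls=False) stays None.
    a = merge(a, init(None))
    return finalize(a)


print(null_aware_sum([1, 2, None, 4]))                      # 7 (null ignored)
print(null_aware_sum([1, 2, None, 4], ignore_nulls=False))  # None (null propagates)
print(null_aware_sum([None, None]))                         # None (all nulls)
print(null_aware_sum([]))                                   # None (empty)
```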
|
|
@ -2264,7 +2264,7 @@ def test_split(ray_start_regular_shared):
|
|||
assert [1] * 10 + [0] == [
|
||||
dataset._blocks.initial_num_blocks() for dataset in datasets
|
||||
]
|
||||
assert 190 == sum([dataset.sum() for dataset in datasets])
|
||||
assert 190 == sum([dataset.sum() or 0 for dataset in datasets])
|
||||
|
||||
|
||||
def test_split_hints(ray_start_regular_shared):
|
||||
|
@ -3271,12 +3271,65 @@ def test_groupby_arrow_sum(ray_start_regular_shared, num_parts):
|
|||
{"A": 1, "sum(B)": 1617},
|
||||
{"A": 2, "sum(B)": 1650},
|
||||
]
|
||||
|
||||
# Test built-in sum aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(
|
||||
[{"A": (x % 3), "B": x} for x in xs] + [{"A": 0, "B": None}]
|
||||
)
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.sum("B")
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "sum(B)": 1683},
|
||||
{"A": 1, "sum(B)": 1617},
|
||||
{"A": 2, "sum(B)": 1650},
|
||||
]
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.sum("B", ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "sum(B)": None},
|
||||
{"A": 1, "sum(B)": 1617},
|
||||
{"A": 2, "sum(B)": 1650},
|
||||
]
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([{"A": (x % 3), "B": None} for x in xs])
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
.sum("B")
|
||||
)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "sum(B)": None},
|
||||
{"A": 1, "sum(B)": None},
|
||||
{"A": 2, "sum(B)": None},
|
||||
]
|
||||
|
||||
# Test built-in global sum aggregation
|
||||
assert (
|
||||
ray.data.from_items([{"A": x} for x in xs]).repartition(num_parts).sum("A")
|
||||
== 4950
|
||||
)
|
||||
assert ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).sum("value") == 0
|
||||
|
||||
# Test empty dataset
|
||||
assert (
|
||||
ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).sum("value") is None
|
||||
)
|
||||
|
||||
# Test built-in global sum aggregation with nans
|
||||
nan_ds = ray.data.from_items([{"A": x} for x in xs] + [{"A": None}]).repartition(
|
||||
num_parts
|
||||
)
|
||||
assert nan_ds.sum("A") == 4950
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.sum("A", ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([{"A": None}] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.sum("A") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
@ -3299,12 +3352,64 @@ def test_groupby_arrow_min(ray_start_regular_shared, num_parts):
|
|||
{"A": 1, "min(B)": 1},
|
||||
{"A": 2, "min(B)": 2},
|
||||
]
|
||||
|
||||
# Test built-in min aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(
|
||||
[{"A": (x % 3), "B": x} for x in xs] + [{"A": 0, "B": None}]
|
||||
)
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.min("B")
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "min(B)": 0},
|
||||
{"A": 1, "min(B)": 1},
|
||||
{"A": 2, "min(B)": 2},
|
||||
]
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.min("B", ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "min(B)": None},
|
||||
{"A": 1, "min(B)": 1},
|
||||
{"A": 2, "min(B)": 2},
|
||||
]
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([{"A": (x % 3), "B": None} for x in xs])
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
.min("B")
|
||||
)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "min(B)": None},
|
||||
{"A": 1, "min(B)": None},
|
||||
{"A": 2, "min(B)": None},
|
||||
]
|
||||
|
||||
# Test built-in global min aggregation
|
||||
assert (
|
||||
ray.data.from_items([{"A": x} for x in xs]).repartition(num_parts).min("A") == 0
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).min("value")
|
||||
|
||||
# Test empty dataset
|
||||
assert (
|
||||
ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).min("value") is None
|
||||
)
|
||||
|
||||
# Test built-in global min aggregation with nans
|
||||
nan_ds = ray.data.from_items([{"A": x} for x in xs] + [{"A": None}]).repartition(
|
||||
num_parts
|
||||
)
|
||||
assert nan_ds.min("A") == 0
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.min("A", ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([{"A": None}] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.min("A") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
@ -3327,13 +3432,65 @@ def test_groupby_arrow_max(ray_start_regular_shared, num_parts):
|
|||
{"A": 1, "max(B)": 97},
|
||||
{"A": 2, "max(B)": 98},
|
||||
]
|
||||
|
||||
# Test built-in max aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(
|
||||
[{"A": (x % 3), "B": x} for x in xs] + [{"A": 0, "B": None}]
|
||||
)
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.max("B")
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "max(B)": 99},
|
||||
{"A": 1, "max(B)": 97},
|
||||
{"A": 2, "max(B)": 98},
|
||||
]
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.max("B", ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "max(B)": None},
|
||||
{"A": 1, "max(B)": 97},
|
||||
{"A": 2, "max(B)": 98},
|
||||
]
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([{"A": (x % 3), "B": None} for x in xs])
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
.max("B")
|
||||
)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "max(B)": None},
|
||||
{"A": 1, "max(B)": None},
|
||||
{"A": 2, "max(B)": None},
|
||||
]
|
||||
|
||||
# Test built-in global max aggregation
|
||||
assert (
|
||||
ray.data.from_items([{"A": x} for x in xs]).repartition(num_parts).max("A")
|
||||
== 99
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).max("value")
|
||||
|
||||
# Test empty dataset
|
||||
assert (
|
||||
ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).max("value") is None
|
||||
)
|
||||
|
||||
# Test built-in global max aggregation with nans
|
||||
nan_ds = ray.data.from_items([{"A": x} for x in xs] + [{"A": None}]).repartition(
|
||||
num_parts
|
||||
)
|
||||
assert nan_ds.max("A") == 99
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.max("A", ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([{"A": None}] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.max("A") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
@ -3356,13 +3513,65 @@ def test_groupby_arrow_mean(ray_start_regular_shared, num_parts):
|
|||
{"A": 1, "mean(B)": 49.0},
|
||||
{"A": 2, "mean(B)": 50.0},
|
||||
]
|
||||
|
||||
# Test built-in mean aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(
|
||||
[{"A": (x % 3), "B": x} for x in xs] + [{"A": 0, "B": None}]
|
||||
)
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.mean("B")
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "mean(B)": 49.5},
|
||||
{"A": 1, "mean(B)": 49.0},
|
||||
{"A": 2, "mean(B)": 50.0},
|
||||
]
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.mean("B", ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "mean(B)": None},
|
||||
{"A": 1, "mean(B)": 49.0},
|
||||
{"A": 2, "mean(B)": 50.0},
|
||||
]
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([{"A": (x % 3), "B": None} for x in xs])
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
.mean("B")
|
||||
)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert [row.as_pydict() for row in nan_agg_ds.sort("A").iter_rows()] == [
|
||||
{"A": 0, "mean(B)": None},
|
||||
{"A": 1, "mean(B)": None},
|
||||
{"A": 2, "mean(B)": None},
|
||||
]
|
||||
|
||||
# Test built-in global mean aggregation
|
||||
assert (
|
||||
ray.data.from_items([{"A": x} for x in xs]).repartition(num_parts).mean("A")
|
||||
== 49.5
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).mean("value")
|
||||
|
||||
# Test empty dataset
|
||||
assert (
|
||||
ray.data.range_arrow(10).filter(lambda r: r["value"] > 10).mean("value") is None
|
||||
)
|
||||
|
||||
# Test built-in global mean aggregation with nans
|
||||
nan_ds = ray.data.from_items([{"A": x} for x in xs] + [{"A": None}]).repartition(
|
||||
num_parts
|
||||
)
|
||||
assert nan_ds.mean("A") == 49.5
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.mean("A", ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([{"A": None}] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.mean("A") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
@ -3387,6 +3596,35 @@ def test_groupby_arrow_std(ray_start_regular_shared, num_parts):
|
|||
result = agg_ds.to_pandas()["std(B)"].to_numpy()
|
||||
expected = df.groupby("A")["B"].std(ddof=0).to_numpy()
|
||||
np.testing.assert_array_almost_equal(result, expected)
|
||||
|
||||
# Test built-in std aggregation with nans
|
||||
nan_df = pd.DataFrame({"A": [x % 3 for x in xs] + [0], "B": xs + [None]})
|
||||
nan_grouped_ds = ray.data.from_pandas(nan_df).repartition(num_parts).groupby("A")
|
||||
nan_agg_ds = nan_grouped_ds.std("B")
|
||||
assert nan_agg_ds.count() == 3
|
||||
result = nan_agg_ds.to_pandas()["std(B)"].to_numpy()
|
||||
expected = nan_df.groupby("A")["B"].std().to_numpy()
|
||||
np.testing.assert_array_almost_equal(result, expected)
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.std("B", ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
result = nan_agg_ds.to_pandas()["std(B)"].to_numpy()
|
||||
expected = nan_df.groupby("A")["B"].std()
|
||||
expected[0] = None
|
||||
np.testing.assert_array_almost_equal(result, expected)
|
||||
# Test all nans
|
||||
nan_df = pd.DataFrame({"A": [x % 3 for x in xs], "B": [None] * len(xs)})
|
||||
nan_agg_ds = (
|
||||
ray.data.from_pandas(nan_df)
|
||||
.repartition(num_parts)
|
||||
.groupby("A")
|
||||
.std("B", ignore_nulls=False)
|
||||
)
|
||||
assert nan_agg_ds.count() == 3
|
||||
result = nan_agg_ds.to_pandas()["std(B)"].to_numpy()
|
||||
expected = pd.Series([None] * 3)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
# Test built-in global std aggregation
|
||||
df = pd.DataFrame({"A": xs})
|
||||
assert math.isclose(
|
||||
|
@ -3397,11 +3635,22 @@ def test_groupby_arrow_std(ray_start_regular_shared, num_parts):
|
|||
ray.data.from_pandas(df).repartition(num_parts).std("A", ddof=0),
|
||||
df["A"].std(ddof=0),
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.from_pandas(pd.DataFrame({"A": []})).std("A")
|
||||
|
||||
# Test empty dataset
|
||||
assert ray.data.from_pandas(pd.DataFrame({"A": []})).std("A") is None
|
||||
# Test edge cases
|
||||
assert ray.data.from_pandas(pd.DataFrame({"A": [3]})).std("A") == 0
|
||||
|
||||
# Test built-in global std aggregation with nans
|
||||
nan_df = pd.DataFrame({"A": xs + [None]})
|
||||
nan_ds = ray.data.from_pandas(nan_df).repartition(num_parts)
|
||||
assert math.isclose(nan_ds.std("A"), df["A"].std())
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.std("A", ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([{"A": None}] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.std("A") is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
def test_groupby_arrow_multicolumn(ray_start_regular_shared, num_parts):
|
||||
|
@ -3421,6 +3670,7 @@ def test_groupby_arrow_multicolumn(ray_start_regular_shared, num_parts):
|
|||
{"A": 1, "mean(B)": 49.0, "mean(C)": 98.0},
|
||||
{"A": 2, "mean(B)": 50.0, "mean(C)": 100.0},
|
||||
]
|
||||
|
||||
# Test that unspecified agg column ==> agg on all columns except for
|
||||
# groupby keys.
|
||||
agg_ds = ray.data.from_pandas(df).repartition(num_parts).groupby("A").mean()
|
||||
|
@ -3430,6 +3680,7 @@ def test_groupby_arrow_multicolumn(ray_start_regular_shared, num_parts):
|
|||
{"A": 1, "mean(B)": 49.0, "mean(C)": 98.0},
|
||||
{"A": 2, "mean(B)": 50.0, "mean(C)": 100.0},
|
||||
]
|
||||
|
||||
# Test built-in global mean aggregation
|
||||
df = pd.DataFrame({"A": xs, "B": [2 * x for x in xs]})
|
||||
result_row = ray.data.from_pandas(df).repartition(num_parts).mean(["A", "B"])
|
||||
|
@ -3511,6 +3762,7 @@ def test_groupby_arrow_multi_agg(ray_start_regular_shared, num_parts):
|
|||
np.testing.assert_array_equal(result, expected)
|
||||
# Test built-in global std aggregation
|
||||
df = pd.DataFrame({"A": xs})
|
||||
|
||||
result_row = (
|
||||
ray.data.from_pandas(df)
|
||||
.repartition(num_parts)
|
||||
|
@ -3549,6 +3801,7 @@ def test_groupby_simple(ray_start_regular_shared):
|
|||
]
|
||||
random.shuffle(xs)
|
||||
ds = ray.data.from_items(xs, parallelism=parallelism)
|
||||
|
||||
# Mean aggregation
|
||||
agg_ds = ds.groupby(lambda r: r[0]).aggregate(
|
||||
AggregateFn(
|
||||
|
@ -3625,9 +3878,50 @@ def test_groupby_simple_sum(ray_start_regular_shared, num_parts):
|
|||
)
|
||||
assert agg_ds.count() == 3
|
||||
assert agg_ds.sort(key=lambda r: r[0]).take(3) == [(0, 1683), (1, 1617), (2, 1650)]
|
||||
|
||||
# Test built-in sum aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(xs + [None])
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: int(x or 0) % 3)
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.sum()
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(3) == [
|
||||
(0, 1683),
|
||||
(1, 1617),
|
||||
(2, 1650),
|
||||
]
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.sum(ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(3) == [
|
||||
(0, None),
|
||||
(1, 1617),
|
||||
(2, 1650),
|
||||
]
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([None] * len(xs))
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: 0)
|
||||
.sum()
|
||||
)
|
||||
assert nan_agg_ds.count() == 1
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)]
|
||||
|
||||
# Test built-in global sum aggregation
|
||||
assert ray.data.from_items(xs).repartition(num_parts).sum() == 4950
|
||||
assert ray.data.range(10).filter(lambda r: r > 10).sum() == 0
|
||||
assert ray.data.range(10).filter(lambda r: r > 10).sum() is None
|
||||
|
||||
# Test built-in global sum aggregation with nans
|
||||
nan_ds = ray.data.from_items(xs + [None]).repartition(num_parts)
|
||||
assert nan_ds.sum() == 4950
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.sum(ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([None] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.sum() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
@ -3643,10 +3937,42 @@ def test_groupby_simple_min(ray_start_regular_shared, num_parts):
|
|||
)
|
||||
assert agg_ds.count() == 3
|
||||
assert agg_ds.sort(key=lambda r: r[0]).take(3) == [(0, 0), (1, 1), (2, 2)]
|
||||
|
||||
# Test built-in min aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(xs + [None])
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: int(x or 0) % 3)
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.min()
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(3) == [(0, 0), (1, 1), (2, 2)]
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.min(ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(3) == [(0, None), (1, 1), (2, 2)]
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([None] * len(xs))
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: 0)
|
||||
.min()
|
||||
)
|
||||
assert nan_agg_ds.count() == 1
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)]
|
||||
|
||||
# Test built-in global min aggregation
|
||||
assert ray.data.from_items(xs).repartition(num_parts).min() == 0
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.range(10).filter(lambda r: r > 10).min()
|
||||
assert ray.data.range(10).filter(lambda r: r > 10).min() is None
|
||||
|
||||
# Test built-in global min aggregation with nans
|
||||
nan_ds = ray.data.from_items(xs + [None]).repartition(num_parts)
|
||||
assert nan_ds.min() == 0
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.min(ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([None] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.min() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
@ -3662,10 +3988,42 @@ def test_groupby_simple_max(ray_start_regular_shared, num_parts):
|
|||
)
|
||||
assert agg_ds.count() == 3
|
||||
assert agg_ds.sort(key=lambda r: r[0]).take(3) == [(0, 99), (1, 97), (2, 98)]
|
||||
|
||||
# Test built-in max aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(xs + [None])
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: int(x or 0) % 3)
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.max()
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(3) == [(0, 99), (1, 97), (2, 98)]
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.max(ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(3) == [(0, None), (1, 97), (2, 98)]
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([None] * len(xs))
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: 0)
|
||||
.max()
|
||||
)
|
||||
assert nan_agg_ds.count() == 1
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)]
|
||||
|
||||
# Test built-in global max aggregation
|
||||
assert ray.data.from_items(xs).repartition(num_parts).max() == 99
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.range(10).filter(lambda r: r > 10).max()
|
||||
assert ray.data.range(10).filter(lambda r: r > 10).max() is None
|
||||
|
||||
# Test built-in global max aggregation with nans
|
||||
nan_ds = ray.data.from_items(xs + [None]).repartition(num_parts)
|
||||
assert nan_ds.max() == 99
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.max(ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([None] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.max() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
@ -3681,10 +4039,51 @@ def test_groupby_simple_mean(ray_start_regular_shared, num_parts):
|
|||
)
|
||||
assert agg_ds.count() == 3
|
||||
assert agg_ds.sort(key=lambda r: r[0]).take(3) == [(0, 49.5), (1, 49.0), (2, 50.0)]
|
||||
|
||||
# Test built-in mean aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(xs + [None])
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: int(x or 0) % 3)
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.mean()
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(3) == [
|
||||
(0, 49.5),
|
||||
(1, 49.0),
|
||||
(2, 50.0),
|
||||
]
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.mean(ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(3) == [
|
||||
(0, None),
|
||||
(1, 49.0),
|
||||
(2, 50.0),
|
||||
]
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([None] * len(xs))
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: 0)
|
||||
.mean()
|
||||
)
|
||||
assert nan_agg_ds.count() == 1
|
||||
assert nan_agg_ds.sort(key=lambda r: r[0]).take(1) == [(0, None)]
|
||||
|
||||
# Test built-in global mean aggregation
|
||||
assert ray.data.from_items(xs).repartition(num_parts).mean() == 49.5
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.range(10).filter(lambda r: r > 10).mean()
|
||||
# Test empty dataset
|
||||
assert ray.data.range(10).filter(lambda r: r > 10).mean() is None
|
||||
|
||||
# Test built-in global mean aggregation with nans
|
||||
nan_ds = ray.data.from_items(xs + [None]).repartition(num_parts)
|
||||
assert nan_ds.mean() == 49.5
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.mean(ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([None] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.mean() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
@ -3721,6 +4120,48 @@ def test_groupby_simple_std(ray_start_regular_shared, num_parts):
|
|||
result_df = pd.DataFrame({"A": list(groups), "B": list(stds)})
|
||||
result_df = result_df.set_index("A")
|
||||
pd.testing.assert_series_equal(result_df["B"], expected)
|
||||
|
||||
# Test built-in std aggregation with nans
|
||||
nan_grouped_ds = (
|
||||
ray.data.from_items(xs + [None])
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: int(x or 0) % 3)
|
||||
)
|
||||
nan_agg_ds = nan_grouped_ds.std()
|
||||
assert nan_agg_ds.count() == 3
|
||||
nan_df = pd.DataFrame({"A": [x % 3 for x in xs] + [0], "B": xs + [None]})
|
||||
expected = nan_df.groupby("A")["B"].std()
|
||||
result = nan_agg_ds.sort(key=lambda r: r[0]).take(3)
|
||||
groups, stds = zip(*result)
|
||||
result_df = pd.DataFrame({"A": list(groups), "B": list(stds)})
|
||||
result_df = result_df.set_index("A")
|
||||
pd.testing.assert_series_equal(result_df["B"], expected)
|
||||
# Test ignore_nulls=False
|
||||
nan_agg_ds = nan_grouped_ds.std(ignore_nulls=False)
|
||||
assert nan_agg_ds.count() == 3
|
||||
expected = nan_df.groupby("A")["B"].std()
|
||||
expected[0] = None
|
||||
result = nan_agg_ds.sort(key=lambda r: r[0]).take(3)
|
||||
groups, stds = zip(*result)
|
||||
result_df = pd.DataFrame({"A": list(groups), "B": list(stds)})
|
||||
result_df = result_df.set_index("A")
|
||||
pd.testing.assert_series_equal(result_df["B"], expected)
|
||||
# Test all nans
|
||||
nan_agg_ds = (
|
||||
ray.data.from_items([None] * len(xs))
|
||||
.repartition(num_parts)
|
||||
.groupby(lambda x: 0)
|
||||
.std(ignore_nulls=False)
|
||||
)
|
||||
assert nan_agg_ds.count() == 1
|
||||
expected = pd.Series([None], name="B")
|
||||
expected.index.rename("A", inplace=True)
|
||||
result = nan_agg_ds.sort(key=lambda r: r[0]).take(1)
|
||||
groups, stds = zip(*result)
|
||||
result_df = pd.DataFrame({"A": list(groups), "B": list(stds)})
|
||||
result_df = result_df.set_index("A")
|
||||
pd.testing.assert_series_equal(result_df["B"], expected)
|
||||
|
||||
# Test built-in global std aggregation
|
||||
assert math.isclose(
|
||||
ray.data.from_items(xs).repartition(num_parts).std(), pd.Series(xs).std()
|
||||
|
@ -3730,11 +4171,21 @@ def test_groupby_simple_std(ray_start_regular_shared, num_parts):
|
|||
ray.data.from_items(xs).repartition(num_parts).std(ddof=0),
|
||||
pd.Series(xs).std(ddof=0),
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.from_items([]).std()
|
||||
|
||||
# Test empty dataset
|
||||
assert ray.data.from_items([]).std() is None
|
||||
# Test edge cases
|
||||
assert ray.data.from_items([3]).std() == 0
|
||||
|
||||
# Test built-in global std aggregation with nans
|
||||
nan_ds = ray.data.from_items(xs + [None]).repartition(num_parts)
|
||||
assert math.isclose(nan_ds.std(), pd.Series(xs).std())
|
||||
# Test ignore_nulls=False
|
||||
assert nan_ds.std(ignore_nulls=False) is None
|
||||
# Test all nans
|
||||
nan_ds = ray.data.from_items([None] * len(xs)).repartition(num_parts)
|
||||
assert nan_ds.std() is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
def test_groupby_simple_multilambda(ray_start_regular_shared, num_parts):
|
||||
|
@ -3760,10 +4211,12 @@ def test_groupby_simple_multilambda(ray_start_regular_shared, num_parts):
|
|||
assert ray.data.from_items([[x, 2 * x] for x in xs]).repartition(num_parts).mean(
|
||||
[lambda x: x[0], lambda x: x[1]]
|
||||
) == (49.5, 99.0)
|
||||
with pytest.raises(ValueError):
|
||||
ray.data.from_items([[x, 2 * x] for x in range(10)]).filter(
|
||||
lambda r: r[0] > 10
|
||||
).mean([lambda x: x[0], lambda x: x[1]])
|
||||
assert (
|
||||
ray.data.from_items([[x, 2 * x] for x in range(10)])
|
||||
.filter(lambda r: r[0] > 10)
|
||||
.mean([lambda x: x[0], lambda x: x[1]])
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_parts", [1, 30])
|
||||
|
|