[Dataset] [DataFrame 2/n] Add pandas block format implementation (partial) (#20988)

This PR adds pandas block format support by implementing `PandasRow`, `PandasBlockBuilder`, and `PandasBlockAccessor`.

Note that `sort_and_partition`, `combine`, `merge_sorted_blocks`, and `aggregate_combined_blocks` in `PandasBlockAccessor` currently redirect to the Arrow block format implementation. Native pandas implementations will follow in a later PR.
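As a quick orientation (not part of the diff), a minimal usage sketch; it assumes the `enable_pandas_block` flag introduced in this PR is left at its default of `True`, and it uses the private `_dataset_format()` helper the same way the new tests do:

    import pandas as pd
    import ray

    df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
    ds = ray.data.from_pandas([df])

    # With the pandas block format enabled, the dataset is backed by pandas
    # DataFrame blocks instead of Arrow tables.
    assert ds._dataset_format() == "pandas"

    # Batches handed to a UDF with batch_format="pandas" are plain DataFrames.
    ds2 = ds.map_batches(lambda batch: batch, batch_size=2, batch_format="pandas")
    print(ds2.take(3))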

Co-authored-by: Clark Zinzow <clarkzinzow@gmail.com>
Co-authored-by: Eric Liang <ekhliang@gmail.com>
Author: Kai Yang, 2022-01-15 17:28:34 +08:00 (committed by GitHub)
commit 4a55d10bb1, parent 26057c433f
10 changed files with 381 additions and 94 deletions


@@ -273,6 +273,7 @@ def inference(dataset, model_cls: type, batch_size: int, result_path: str,
model_cls,
compute="actors",
batch_size=batch_size,
batch_format="pandas",
num_gpus=num_gpus,
num_cpus=0) \
.write_parquet(result_path)
@@ -578,8 +579,8 @@ if __name__ == "__main__":
read_dataset(data_path))
num_columns = len(train_dataset.schema().names)
# remove label column and internal Arrow column.
num_features = num_columns - 2
# remove label column.
num_features = num_columns - 1
NUM_EPOCHS = 2
BATCH_SIZE = 512
@@ -681,9 +682,9 @@ if __name__ == "__main__":
self.model = load_model_func().to(self.device)
def __call__(self, batch) -> "pd.DataFrame":
tensor = torch.FloatTensor(batch.to_pandas().values).to(
self.device)
return pd.DataFrame(self.model(tensor).cpu().detach().numpy())
tensor = torch.FloatTensor(batch.values).to(self.device)
return pd.DataFrame(
self.model(tensor).cpu().detach().numpy(), columns=["value"])
inference_dataset = preprocessor.preprocess_inference_data(
read_dataset(inference_path))


@@ -25,7 +25,7 @@ AggType = TypeVar("AggType")
#
# Block data can be accessed in a uniform way via ``BlockAccessors`` such as
# ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``.
Block = Union[List[T], "pyarrow.Table", bytes]
Block = Union[List[T], "pyarrow.Table", "pandas.DataFrame", bytes]
# A list of block references pending computation by a single task. For example,
# this may be the output of a task reading a file.
@@ -196,11 +196,16 @@ class BlockAccessor(Generic[T]):
"""Create a block accessor for the given block."""
_check_pyarrow_version()
import pyarrow
import pandas
if isinstance(block, pyarrow.Table):
from ray.data.impl.arrow_block import \
ArrowBlockAccessor
return ArrowBlockAccessor(block)
elif isinstance(block, pandas.DataFrame):
from ray.data.impl.pandas_block import \
PandasBlockAccessor
return PandasBlockAccessor(block)
elif isinstance(block, bytes):
from ray.data.impl.arrow_block import \
ArrowBlockAccessor
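A sketch of the dispatch added above: `BlockAccessor.for_block` now resolves a pandas DataFrame to the new `PandasBlockAccessor`, next to the existing Arrow and simple paths (assuming both pyarrow and pandas are installed):

    import pandas as pd
    import pyarrow as pa
    from ray.data.block import BlockAccessor

    # Each supported block type resolves to its matching accessor.
    BlockAccessor.for_block(pa.table({"a": [1]}))      # -> ArrowBlockAccessor
    BlockAccessor.for_block(pd.DataFrame({"a": [1]}))  # -> PandasBlockAccessor (new)
    BlockAccessor.for_block([1, 2, 3])                 # -> SimpleBlockAccessor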


@@ -14,6 +14,10 @@ DEFAULT_TARGET_MAX_BLOCK_SIZE = 2048 * 1024 * 1024
# Whether block splitting is on by default
DEFAULT_BLOCK_SPLITTING_ENABLED = False
# Whether pandas block format is enabled.
# TODO (kfstorm): Remove this once stable.
DEFAULT_ENABLE_PANDAS_BLOCK = True
@DeveloperAPI
class DatasetContext:
@@ -23,12 +27,18 @@ class DatasetContext:
from the driver and remote workers via DatasetContext.get_current().
"""
def __init__(self, block_owner: ray.actor.ActorHandle,
block_splitting_enabled: bool, target_max_block_size: int):
def __init__(
self,
block_owner: ray.actor.ActorHandle,
block_splitting_enabled: bool,
target_max_block_size: int,
enable_pandas_block: bool,
):
"""Private constructor (use get_current() instead)."""
self.block_owner = block_owner
self.block_splitting_enabled = block_splitting_enabled
self.target_max_block_size = target_max_block_size
self.enable_pandas_block = enable_pandas_block
@staticmethod
def get_current() -> "DatasetContext":
@@ -45,7 +55,9 @@ class DatasetContext:
_default_context = DatasetContext(
block_owner=None,
block_splitting_enabled=DEFAULT_BLOCK_SPLITTING_ENABLED,
target_max_block_size=DEFAULT_TARGET_MAX_BLOCK_SIZE)
target_max_block_size=DEFAULT_TARGET_MAX_BLOCK_SIZE,
enable_pandas_block=DEFAULT_ENABLE_PANDAS_BLOCK,
)
if _default_context.block_owner is None:
owner = _DesignatedBlockOwner.options(
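The new flag is read from `DatasetContext`, so it can be flipped at runtime; this is how the updated tests exercise both code paths. A small sketch:

    from ray.data.context import DatasetContext

    ctx = DatasetContext.get_current()
    print(ctx.enable_pandas_block)  # True by default (DEFAULT_ENABLE_PANDAS_BLOCK)

    # Temporarily fall back to Arrow blocks for pandas inputs, e.g. for comparison.
    old_value = ctx.enable_pandas_block
    ctx.enable_pandas_block = False
    try:
        ...  # create datasets with ray.data.from_pandas(...)
    finally:
        ctx.enable_pandas_block = old_value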


@@ -44,6 +44,7 @@ from ray.data.impl.shuffle import simple_shuffle, _shuffle_reduce
from ray.data.impl.sort import sort_impl
from ray.data.impl.block_list import BlockList
from ray.data.impl.lazy_block_list import LazyBlockList
from ray.data.impl.table_block import TableRow
from ray.data.impl.delegating_block_builder import DelegatingBlockBuilder
# An output type of iter_batches() determined by the batch_format parameter.
@@ -230,11 +231,9 @@ class Dataset(Generic[T]):
"or 'pyarrow', got: {}".format(batch_format))
applied = fn(view)
if isinstance(applied, list) or isinstance(applied, pa.Table):
applied = applied
elif isinstance(applied, pd.core.frame.DataFrame):
applied = pa.Table.from_pandas(applied)
else:
if not (isinstance(applied, list)
or isinstance(applied, pa.Table)
or isinstance(applied, pd.core.frame.DataFrame)):
raise ValueError("The map batches UDF returned the value "
f"{applied}, which is not allowed. "
"The return type must be either list, "
@@ -403,12 +402,15 @@ class Dataset(Generic[T]):
# Handle empty blocks.
if len(new_blocks) < num_blocks:
from ray.data.impl.arrow_block import ArrowBlockBuilder
from ray.data.impl.pandas_block import PandasBlockBuilder
from ray.data.impl.simple_block import SimpleBlockBuilder
num_empties = num_blocks - len(new_blocks)
dataset_format = self._dataset_format()
if dataset_format == "arrow":
builder = ArrowBlockBuilder()
elif dataset_format == "pandas":
builder = PandasBlockBuilder()
else:
builder = SimpleBlockBuilder()
empty_block = builder.build()
@@ -938,7 +940,7 @@ class Dataset(Generic[T]):
# Dataset is empty/cleared, let downstream ops handle this.
return on
if dataset_format == "arrow":
if dataset_format == "arrow" or dataset_format == "pandas":
# This should be cached from the ._dataset_format() check, so we
# don't fetch and we assert that the schema is not None.
schema = self.schema(fetch_if_missing=False)
@@ -971,32 +973,35 @@ class Dataset(Generic[T]):
and isinstance(on[0], str)):
raise ValueError(
"Can't aggregate on a column when using a simple Dataset; "
"use a callable `on` argument or use an Arrow Dataset "
"instead of a simple Dataset.")
"use a callable `on` argument or use an Arrow or Pandas"
" Dataset instead of a simple Dataset.")
return on
def _dataset_format(self) -> str:
"""Determine the format of the dataset. Possible values are: "arrow",
"simple".
"pandas", "simple".
This may block; if the schema is unknown, this will synchronously fetch
the schema for the first block.
"""
# We need schema to properly validate, so synchronously
# fetch it if necessary.
schema = self.schema(fetch_if_missing=True)
if schema is None:
raise ValueError(
"Dataset is empty or cleared, can't determine the format of "
"the dataset.")
try:
import pyarrow as pa
except ModuleNotFoundError:
return "simple"
else:
# We need schema to properly validate, so synchronously
# fetch it if necessary.
schema = self.schema(fetch_if_missing=True)
if schema is None:
raise ValueError(
"Dataset is empty or cleared, can't determine the format"
" of the dataset")
if isinstance(schema, pa.Schema):
return "arrow"
return "simple"
except ModuleNotFoundError:
pass
from ray.data.impl.pandas_block import PandasBlockSchema
if isinstance(schema, PandasBlockSchema):
return "pandas"
return "simple"
def _aggregate_on(self, agg_cls: type, on: Optional["AggregateOnTs"],
*args, **kwargs):
@@ -1025,6 +1030,18 @@ class Dataset(Generic[T]):
on = [on]
return [agg_cls(on_, *args, **kwargs) for on_ in on]
def _aggregate_result(self, result: Union[Tuple, TableRow]) -> U:
if len(result) == 1:
if isinstance(result, tuple):
return result[0]
else:
# NOTE (kfstorm): We cannot call `result[0]` directly on
# `PandasRow` because indexing a column with position is not
# supported by pandas.
return list(result.values())[0]
else:
return result
def sum(self, on: Optional["AggregateOnTs"] = None) -> U:
"""Compute sum over entire dataset.
@@ -1075,10 +1092,8 @@ class Dataset(Generic[T]):
ret = self._aggregate_on(Sum, on)
if ret is None:
return 0
elif len(ret) == 1:
return ret[0]
else:
return ret
return self._aggregate_result(ret)
def min(self, on: Optional["AggregateOnTs"] = None) -> U:
"""Compute minimum over entire dataset.
@@ -1130,10 +1145,8 @@ class Dataset(Generic[T]):
ret = self._aggregate_on(Min, on)
if ret is None:
raise ValueError("Cannot compute min on an empty dataset")
elif len(ret) == 1:
return ret[0]
else:
return ret
return self._aggregate_result(ret)
def max(self, on: Optional["AggregateOnTs"] = None) -> U:
"""Compute maximum over entire dataset.
@@ -1185,10 +1198,8 @@ class Dataset(Generic[T]):
ret = self._aggregate_on(Max, on)
if ret is None:
raise ValueError("Cannot compute max on an empty dataset")
elif len(ret) == 1:
return ret[0]
else:
return ret
return self._aggregate_result(ret)
def mean(self, on: Optional["AggregateOnTs"] = None) -> U:
"""Compute mean over entire dataset.
@@ -1240,10 +1251,8 @@ class Dataset(Generic[T]):
ret = self._aggregate_on(Mean, on)
if ret is None:
raise ValueError("Cannot compute mean on an empty dataset")
elif len(ret) == 1:
return ret[0]
else:
return ret
return self._aggregate_result(ret)
def std(self, on: Optional["AggregateOnTs"] = None, ddof: int = 1) -> U:
"""Compute standard deviation over entire dataset.
@@ -1305,10 +1314,8 @@ class Dataset(Generic[T]):
ret = self._aggregate_on(Std, on, ddof=ddof)
if ret is None:
raise ValueError("Cannot compute std on an empty dataset")
elif len(ret) == 1:
return ret[0]
else:
return ret
return self._aggregate_result(ret)
def sort(self,
key: Union[None, str, List[str], Callable[[T], Any]] = None,
@@ -2263,10 +2270,10 @@ Dict[str, List[str]]]): The names of the columns
def to_pandas(self, limit: int = 100000) -> "pandas.DataFrame":
"""Convert this dataset into a single Pandas DataFrame.
This is only supported for datasets convertible to Arrow records. An
error is raised if the number of records exceeds the provided limit.
Note that you can use ``.limit()`` on the dataset beforehand to
truncate the dataset manually.
This is only supported for datasets convertible to Arrow or Pandas
records. An error is raised if the number of records exceeds the
provided limit. Note that you can use ``.limit()`` on the dataset
beforehand to truncate the dataset manually.
Time complexity: O(dataset size)
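Since aggregations now also run on pandas-backed datasets and single-column results are unwrapped by the new `_aggregate_result` helper, a hedged sketch of the resulting behavior:

    import pandas as pd
    import ray

    ds = ray.data.from_pandas(pd.DataFrame({"one": [1, 2, 3], "two": [4, 5, 6]}))

    # A single aggregated column is unwrapped to a scalar.
    print(ds.sum("one"))  # 6

    # Multiple columns yield one result per column (a row-like object).
    print(ds.min(["one", "two"]))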


@@ -79,7 +79,7 @@ class ArrowBlockAccessor(TableBlockAccessor):
view = _copy_table(view)
return view
def random_shuffle(self, random_seed: Optional[int]) -> List[T]:
def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
random = np.random.RandomState(random_seed)
return self._table.take(random.permutation(self.num_rows()))


@@ -4,6 +4,7 @@ from ray.data.block import Block, T, BlockAccessor
from ray.data.impl.block_builder import BlockBuilder
from ray.data.impl.simple_block import SimpleBlockBuilder
from ray.data.impl.arrow_block import ArrowRow, ArrowBlockBuilder
from ray.data.impl.pandas_block import PandasRow, PandasBlockBuilder
class DelegatingBlockBuilder(BlockBuilder[T]):
@@ -13,6 +14,7 @@ class DelegatingBlockBuilder(BlockBuilder[T]):
def add(self, item: Any) -> None:
if self._builder is None:
# TODO (kfstorm): Maybe we can use Pandas block format for dict.
if isinstance(item, dict) or isinstance(item, ArrowRow):
import pyarrow
try:
@@ -22,6 +24,8 @@ class DelegatingBlockBuilder(BlockBuilder[T]):
self._builder = ArrowBlockBuilder()
except (TypeError, pyarrow.lib.ArrowInvalid):
self._builder = SimpleBlockBuilder()
elif isinstance(item, PandasRow):
self._builder = PandasBlockBuilder()
else:
self._builder = SimpleBlockBuilder()
self._builder.add(item)


@@ -0,0 +1,188 @@
from typing import Dict, List, Tuple, Any, TypeVar, Optional, TYPE_CHECKING
import collections
import numpy as np
try:
import pandas
except ImportError:
pandas = None
from ray.data.block import BlockAccessor, BlockMetadata
from ray.data.impl.table_block import TableBlockAccessor, TableRow, \
TableBlockBuilder, SortKeyT, GroupKeyT
from ray.data.impl.arrow_block import ArrowBlockAccessor
from ray.data.aggregate import AggregateFn
if TYPE_CHECKING:
import pyarrow
import pandas
T = TypeVar("T")
class PandasRow(TableRow):
def as_pydict(self) -> dict:
return {k: v[0] for k, v in self._row.to_dict("list").items()}
def __getitem__(self, key: str) -> Any:
assert isinstance(key, str)
col = self._row[key]
if len(col) == 0:
return None
item = col.iloc[0]
try:
# Try to interpret this as a numpy-type value.
# See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types. # noqa: E501
return item.item()
except AttributeError:
# Fallback to the original form.
return item
def __len__(self):
return self._row.shape[1]
class PandasBlockBuilder(TableBlockBuilder[T]):
def __init__(self):
if pandas is None:
raise ImportError("Run `pip install pandas` for Pandas support.")
super().__init__(pandas.DataFrame)
def _table_from_pydict(
self, columns: Dict[str, List[Any]]) -> "pandas.DataFrame":
return pandas.DataFrame(columns)
def _concat_tables(self,
tables: List["pandas.DataFrame"]) -> "pandas.DataFrame":
return pandas.concat(tables, ignore_index=True)
@staticmethod
def _empty_table() -> "pandas.DataFrame":
return pandas.DataFrame()
# This is to be compatible with pyarrow.lib.schema
# TODO (kfstorm): We need a format-independent way to represent schema.
PandasBlockSchema = collections.namedtuple("PandasBlockSchema",
["names", "types"])
class PandasBlockAccessor(TableBlockAccessor):
def __init__(self, table: "pandas.DataFrame"):
if pandas is None:
raise ImportError("Run `pip install pandas` for Pandas support.")
super().__init__(table)
def _create_table_row(self, row: "pandas.DataFrame") -> PandasRow:
return PandasRow(row)
def slice(self, start: int, end: int, copy: bool) -> "pandas.DataFrame":
view = self._table[start:end]
if copy:
view = view.copy(deep=True)
return view
def random_shuffle(self, random_seed: Optional[int]) -> "pandas.DataFrame":
return self._table.sample(frac=1, random_state=random_seed)
def schema(self) -> PandasBlockSchema:
dtypes = self._table.dtypes
schema = PandasBlockSchema(
names=dtypes.index.tolist(), types=dtypes.values.tolist())
# Column names with non-str types in a pandas DataFrame are not
# supported by Ray Dataset.
if any(not isinstance(name, str) for name in schema.names):
raise ValueError(
"A Pandas DataFrame with column names of non-str types"
" is not supported by Ray Dataset. Column names of this"
f" DataFrame: {schema.names!r}.")
return schema
def to_pandas(self) -> "pandas.DataFrame":
return self._table
def to_numpy(self, column: str = None) -> np.ndarray:
if not column:
raise ValueError(
"`column` must be specified when calling .to_numpy() "
"on Pandas blocks.")
if column not in self._table.columns:
raise ValueError(
"Cannot find column {}, available columns: {}".format(
column, self._table.columns.tolist()))
return self._table[column].to_numpy()
def to_arrow(self) -> "pyarrow.Table":
import pyarrow
return pyarrow.table(self._table)
def num_rows(self) -> int:
return self._table.shape[0]
def size_bytes(self) -> int:
return self._table.memory_usage(index=True, deep=True).sum()
def _zip(self, acc: BlockAccessor) -> "pandas.DataFrame":
r = self.to_pandas().copy(deep=False)
s = acc.to_pandas()
for col_name in s.columns:
col = s[col_name]
# Ensure the column names are unique after zip.
if col_name in r.columns:
i = 1
new_name = col_name
while new_name in r.columns:
new_name = "{}_{}".format(col_name, i)
i += 1
col_name = new_name
r[col_name] = col
return r
@staticmethod
def builder() -> PandasBlockBuilder[T]:
return PandasBlockBuilder()
@staticmethod
def _empty_table() -> "pandas.DataFrame":
return PandasBlockBuilder._empty_table()
def _sample(self, n_samples: int, key: SortKeyT) -> "pandas.DataFrame":
return self._table[[k[0] for k in key]].sample(
n_samples, ignore_index=True)
def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
descending: bool) -> List["pandas.DataFrame"]:
# TODO (kfstorm): A workaround to pass tests. Not efficient.
delegated_result = BlockAccessor.for_block(
self.to_arrow()).sort_and_partition(boundaries, key, descending)
return [
BlockAccessor.for_block(_).to_pandas() for _ in delegated_result
]
def combine(self, key: GroupKeyT,
aggs: Tuple[AggregateFn]) -> "pandas.DataFrame":
# TODO (kfstorm): A workaround to pass tests. Not efficient.
return BlockAccessor.for_block(self.to_arrow()).combine(
key, aggs).to_pandas()
@staticmethod
def merge_sorted_blocks(
blocks: List["pandas.DataFrame"], key: SortKeyT,
_descending: bool) -> Tuple["pandas.DataFrame", BlockMetadata]:
# TODO (kfstorm): A workaround to pass tests. Not efficient.
block, metadata = ArrowBlockAccessor.merge_sorted_blocks(
[BlockAccessor.for_block(block).to_arrow() for block in blocks],
key, _descending)
return BlockAccessor.for_block(block).to_pandas(), metadata
@staticmethod
def aggregate_combined_blocks(
blocks: List["pandas.DataFrame"], key: GroupKeyT,
aggs: Tuple[AggregateFn]
) -> Tuple["pandas.DataFrame", BlockMetadata]:
# TODO (kfstorm): A workaround to pass tests. Not efficient.
block, metadata = ArrowBlockAccessor.aggregate_combined_blocks(
[BlockAccessor.for_block(block).to_arrow() for block in blocks],
key, aggs)
return BlockAccessor.for_block(block).to_pandas(), metadata
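To make the new types concrete, a short sketch of how the accessor and row behave (internal APIs, shown only for illustration; `PandasRow` is what row iteration yields when pandas blocks are enabled):

    import pandas as pd
    import ray
    from ray.data.block import BlockAccessor

    block = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
    acc = BlockAccessor.for_block(block)  # -> PandasBlockAccessor

    acc.num_rows()               # 3
    acc.schema()                 # PandasBlockSchema(names=['one', 'two'], types=[...])
    acc.to_numpy(column="one")   # array([1, 2, 3])
    acc.to_arrow()               # pyarrow.Table with the same columns
    acc.slice(0, 2, copy=True)   # first two rows as a new DataFrame

    # Rows taken from a pandas-backed dataset are PandasRow objects.
    ds = ray.data.from_pandas(block)
    row = ds.take(1)[0]
    row["one"]                   # 1, converted to a native Python value via .item()
    len(row)                     # 2, the number of columns
    row.as_pydict()              # {'one': 1, 'two': 'a'}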


@@ -71,7 +71,7 @@ class SimpleBlockAccessor(BlockAccessor):
def to_pandas(self) -> "pandas.DataFrame":
import pandas
return pandas.DataFrame(self._items)
return pandas.DataFrame({"value": self._items})
def to_numpy(self, column: str = None) -> np.ndarray:
if column:
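With this change, simple (list) blocks convert to a single-column DataFrame whose column is named "value" rather than a positional index, matching the updated tests:

    import ray

    ds = ray.data.range(3)  # backed by simple (list) blocks
    batch = next(ds.iter_batches(batch_format="pandas"))
    print(batch.columns.tolist())  # ['value']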


@@ -534,8 +534,22 @@ def from_dask(df: "dask.DataFrame") -> Dataset[ArrowRow]:
partitions = df.to_delayed()
persisted_partitions = dask.persist(*partitions, scheduler=ray_dask_get)
return from_pandas_refs(
[next(iter(part.dask.values())) for part in persisted_partitions])
import pandas
def to_ref(df):
if isinstance(df, pandas.DataFrame):
return ray.put(df)
elif isinstance(df, ray.ObjectRef):
return df
else:
raise ValueError(
"Expected a Ray object ref or a Pandas DataFrame, "
f"got {type(df)}")
return from_pandas_refs([
to_ref(next(iter(part.dask.values()))) for part in persisted_partitions
])
@PublicAPI(stability="beta")
@@ -600,6 +614,21 @@ def from_pandas_refs(dfs: Union[ObjectRef["pandas.DataFrame"], List[ObjectRef[
"""
if isinstance(dfs, ray.ObjectRef):
dfs = [dfs]
elif isinstance(dfs, list):
for df in dfs:
if not isinstance(df, ray.ObjectRef):
raise ValueError("Expected list of Ray object refs, "
f"got list containing {type(df)}")
else:
raise ValueError("Expected Ray object ref or list of Ray object refs, "
f"got {type(df)}")
context = DatasetContext.get_current()
if context.enable_pandas_block:
get_metadata = cached_remote_fn(_get_metadata)
metadata = [get_metadata.remote(df) for df in dfs]
return Dataset(
BlockList(dfs, ray.get(metadata)), 0, DatasetStats.TODO())
df_to_block = cached_remote_fn(_df_to_block, num_returns=2)
@@ -707,7 +736,8 @@ def _ndarray_to_block(ndarray: np.ndarray) -> Block[np.ndarray]:
input_files=None, exec_stats=stats.build()))
def _get_metadata(table: "pyarrow.Table") -> BlockMetadata:
def _get_metadata(
table: Union["pyarrow.Table", "pandas.DataFrame"]) -> BlockMetadata:
stats = BlockExecStats.builder()
return BlockAccessor.for_block(table).get_metadata(
input_files=None, exec_stats=stats.build())
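The stricter input validation added to `from_pandas_refs` (and the pandas block fast path above it) can be exercised with a small sketch; anything other than an object ref, or a list of object refs, is rejected:

    import pandas as pd
    import ray

    df1 = pd.DataFrame({"one": [1, 2, 3]})
    df2 = pd.DataFrame({"one": [4, 5, 6]})

    # Accepted: a single object ref, or a list of object refs.
    ds = ray.data.from_pandas_refs(ray.put(df1))
    ds = ray.data.from_pandas_refs([ray.put(df1), ray.put(df2)])

    # Rejected with a ValueError: a bare DataFrame inside the list.
    try:
        ray.data.from_pandas_refs([df1])
    except ValueError as err:
        print(err)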


@@ -329,7 +329,7 @@ def test_batch_tensors(ray_start_regular_shared):
with pytest.raises(pa.lib.ArrowInvalid):
next(ds.iter_batches(batch_format="pyarrow"))
df = next(ds.iter_batches(batch_format="pandas"))
assert df.to_dict().keys() == {0, 1}
assert df.to_dict().keys() == {"value"}
def test_arrow_block_slice_copy():
@@ -1156,34 +1156,56 @@ def test_repartition_shuffle_arrow(ray_start_regular_shared):
assert large._block_num_rows() == [500] * 20
def test_from_pandas(ray_start_regular_shared):
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
ds = ray.data.from_pandas([df1, df2])
values = [(r["one"], r["two"]) for r in ds.take(6)]
rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
assert values == rows
@pytest.mark.parametrize("enable_pandas_block", [False, True])
def test_from_pandas(ray_start_regular_shared, enable_pandas_block):
ctx = ray.data.context.DatasetContext.get_current()
old_enable_pandas_block = ctx.enable_pandas_block
ctx.enable_pandas_block = enable_pandas_block
try:
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
ds = ray.data.from_pandas([df1, df2])
assert ds._dataset_format(
) == "pandas" if enable_pandas_block else "arrow"
values = [(r["one"], r["two"]) for r in ds.take(6)]
rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
assert values == rows
# test from single pandas dataframe
ds = ray.data.from_pandas(df1)
values = [(r["one"], r["two"]) for r in ds.take(3)]
rows = [(r.one, r.two) for _, r in df1.iterrows()]
assert values == rows
# test from single pandas dataframe
ds = ray.data.from_pandas(df1)
assert ds._dataset_format(
) == "pandas" if enable_pandas_block else "arrow"
values = [(r["one"], r["two"]) for r in ds.take(3)]
rows = [(r.one, r.two) for _, r in df1.iterrows()]
assert values == rows
finally:
ctx.enable_pandas_block = old_enable_pandas_block
def test_from_pandas_refs(ray_start_regular_shared):
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
ds = ray.data.from_pandas_refs([ray.put(df1), ray.put(df2)])
values = [(r["one"], r["two"]) for r in ds.take(6)]
rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
assert values == rows
@pytest.mark.parametrize("enable_pandas_block", [False, True])
def test_from_pandas_refs(ray_start_regular_shared, enable_pandas_block):
ctx = ray.data.context.DatasetContext.get_current()
old_enable_pandas_block = ctx.enable_pandas_block
ctx.enable_pandas_block = enable_pandas_block
try:
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
ds = ray.data.from_pandas_refs([ray.put(df1), ray.put(df2)])
assert ds._dataset_format(
) == "pandas" if enable_pandas_block else "arrow"
values = [(r["one"], r["two"]) for r in ds.take(6)]
rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
assert values == rows
# test from single pandas dataframe ref
ds = ray.data.from_pandas_refs(ray.put(df1))
values = [(r["one"], r["two"]) for r in ds.take(3)]
rows = [(r.one, r.two) for _, r in df1.iterrows()]
assert values == rows
# test from single pandas dataframe ref
ds = ray.data.from_pandas_refs(ray.put(df1))
assert ds._dataset_format(
) == "pandas" if enable_pandas_block else "arrow"
values = [(r["one"], r["two"]) for r in ds.take(3)]
rows = [(r.one, r.two) for _, r in df1.iterrows()]
assert values == rows
finally:
ctx.enable_pandas_block = old_enable_pandas_block
def test_from_numpy(ray_start_regular_shared):
@@ -1294,7 +1316,7 @@ def test_to_arrow_refs(ray_start_regular_shared):
assert df.equals(dfds)
# Conversion.
df = pd.DataFrame({0: list(range(n))})
df = pd.DataFrame({"value": list(range(n))})
ds = ray.data.range(n)
dfds = pd.concat(
[t.to_pandas() for t in ray.get(ds.to_arrow_refs())],
@@ -1677,8 +1699,8 @@ def test_parquet_write_with_udf(ray_start_regular_shared, tmp_path):
df = pd.concat([df1, df2])
ds = ray.data.from_pandas([df1, df2])
def _block_udf(block: pa.Table):
df = block.to_pandas()
def _block_udf(block):
df = BlockAccessor.for_block(block).to_pandas().copy()
df["one"] += 1
return pa.Table.from_pandas(df)
@@ -1865,7 +1887,7 @@ def test_iter_batches_basic(ray_start_regular_shared):
# blocks format.
for batch, df in zip(ds.iter_batches(batch_format="native"), dfs):
assert batch.to_pandas().equals(df)
assert BlockAccessor.for_block(batch).to_pandas().equals(df)
# Batch size.
batch_size = 2
@@ -2027,8 +2049,10 @@ def test_map_batch(ray_start_regular_shared, tmp_path):
table = pa.Table.from_pandas(df)
pq.write_table(table, os.path.join(tmp_path, "test1.parquet"))
ds = ray.data.read_parquet(str(tmp_path))
ds_list = ds.map_batches(
lambda df: df + 1, batch_size=1, batch_format="pandas").take()
ds2 = ds.map_batches(
lambda df: df + 1, batch_size=1, batch_format="pandas")
assert ds2._dataset_format() == "pandas"
ds_list = ds2.take()
values = [s["one"] for s in ds_list]
assert values == [2, 3, 4]
values = [s["two"] for s in ds_list]
@@ -2036,8 +2060,9 @@ def test_map_batch(ray_start_regular_shared, tmp_path):
# Test Pyarrow
ds = ray.data.read_parquet(str(tmp_path))
ds_list = ds.map_batches(
lambda pa: pa, batch_size=1, batch_format="pyarrow").take()
ds2 = ds.map_batches(lambda pa: pa, batch_size=1, batch_format="pyarrow")
assert ds2._dataset_format() == "arrow"
ds_list = ds2.take()
values = [s["one"] for s in ds_list]
assert values == [1, 2, 3]
values = [s["two"] for s in ds_list]
@@ -2046,27 +2071,31 @@ def test_map_batch(ray_start_regular_shared, tmp_path):
# Test batch
size = 300
ds = ray.data.range(size)
ds_list = ds.map_batches(
lambda df: df + 1, batch_size=17,
batch_format="pandas").take(limit=size)
ds2 = ds.map_batches(
lambda df: df + 1, batch_size=17, batch_format="pandas")
assert ds2._dataset_format() == "pandas"
ds_list = ds2.take(limit=size)
for i in range(size):
# The pandas column is "0", and it originally has rows from 0~299.
# The pandas column is "value", and it originally has rows from 0~299.
# After the map batch, it should have 1~300.
row = ds_list[i]
assert row["0"] == i + 1
assert row["value"] == i + 1
assert ds.count() == 300
# Test the lambda returns different types than the batch_format
# pandas => list block
ds = ray.data.read_parquet(str(tmp_path))
ds_list = ds.map_batches(lambda df: [1], batch_size=1).take()
ds2 = ds.map_batches(lambda df: [1], batch_size=1)
assert ds2._dataset_format() == "simple"
ds_list = ds2.take()
assert ds_list == [1, 1, 1]
assert ds.count() == 3
# pyarrow => list block
ds = ray.data.read_parquet(str(tmp_path))
ds_list = ds.map_batches(
lambda df: [1], batch_size=1, batch_format="pyarrow").take()
ds2 = ds.map_batches(lambda df: [1], batch_size=1, batch_format="pyarrow")
assert ds2._dataset_format() == "simple"
ds_list = ds2.take()
assert ds_list == [1, 1, 1]
assert ds.count() == 3
@@ -3633,6 +3662,17 @@ def test_sort_simple(ray_start_regular_shared):
assert ds.count() == 0
def test_column_name_type_check(ray_start_regular_shared):
df = pd.DataFrame({"1": np.random.rand(10), "a": np.random.rand(10)})
ds = ray.data.from_pandas(df)
expected_str = ("Dataset(num_blocks=1, num_rows=10, "
"schema={1: float64, a: float64})")
assert str(ds) == expected_str, str(ds)
df = pd.DataFrame({1: np.random.rand(10), "a": np.random.rand(10)})
with pytest.raises(ValueError):
ray.data.from_pandas(df)
@pytest.mark.parametrize("pipelined", [False, True])
def test_random_shuffle(shutdown_only, pipelined):
def range(n, parallelism=200):