mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
Revert "[Dataset] [DataFrame 2/n] Add pandas block format implementation (partial) (#20988) (#21661)
This reverts commit 4a55d10bb1
.
parent 1315293dd8
commit fa5c167717
10 changed files with 94 additions and 381 deletions
(release test script for Datasets training/inference; file path not preserved by the mirror)

@@ -273,7 +273,6 @@ def inference(dataset, model_cls: type, batch_size: int, result_path: str,
         model_cls,
         compute="actors",
         batch_size=batch_size,
-        batch_format="pandas",
         num_gpus=num_gpus,
         num_cpus=0) \
             .write_parquet(result_path)

@@ -579,8 +578,8 @@ if __name__ == "__main__":
         read_dataset(data_path))

     num_columns = len(train_dataset.schema().names)
-    # remove label column.
-    num_features = num_columns - 1
+    # remove label column and internal Arrow column.
+    num_features = num_columns - 2

     NUM_EPOCHS = 2
     BATCH_SIZE = 512

@@ -682,9 +681,9 @@ if __name__ == "__main__":
             self.model = load_model_func().to(self.device)

         def __call__(self, batch) -> "pd.DataFrame":
-            tensor = torch.FloatTensor(batch.values).to(self.device)
-            return pd.DataFrame(
-                self.model(tensor).cpu().detach().numpy(), columns=["value"])
+            tensor = torch.FloatTensor(batch.to_pandas().values).to(
+                self.device)
+            return pd.DataFrame(self.model(tensor).cpu().detach().numpy())

     inference_dataset = preprocessor.preprocess_inference_data(
         read_dataset(inference_path))
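The substance of this change: with the pandas block format, a batch UDF received a pandas.DataFrame directly; after the revert it receives an Arrow table and must convert it itself. A minimal self-contained sketch of the two calling conventions (the row sum is a stand-in for the model call; only the .to_pandas() conversion comes from the diff):

    import pandas as pd
    import pyarrow as pa

    def udf_pandas_block(batch: pd.DataFrame) -> pd.DataFrame:
        # pandas block format: the batch is already a DataFrame.
        return pd.DataFrame({"value": batch.values.sum(axis=1)})

    def udf_arrow_block(batch: pa.Table) -> pd.DataFrame:
        # Arrow block format: convert before using .values.
        df = batch.to_pandas()
        return pd.DataFrame(df.values.sum(axis=1))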
python/ray/data/block.py:

@@ -25,7 +25,7 @@ AggType = TypeVar("AggType")
 #
 # Block data can be accessed in a uniform way via ``BlockAccessors`` such as
 # ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``.
-Block = Union[List[T], "pyarrow.Table", "pandas.DataFrame", bytes]
+Block = Union[List[T], "pyarrow.Table", bytes]

 # A list of block references pending computation by a single task. For example,
 # this may be the output of a task reading a file.

@@ -196,16 +196,11 @@ class BlockAccessor(Generic[T]):
         """Create a block accessor for the given block."""
         _check_pyarrow_version()
         import pyarrow
-        import pandas

         if isinstance(block, pyarrow.Table):
             from ray.data.impl.arrow_block import \
                 ArrowBlockAccessor
             return ArrowBlockAccessor(block)
-        elif isinstance(block, pandas.DataFrame):
-            from ray.data.impl.pandas_block import \
-                PandasBlockAccessor
-            return PandasBlockAccessor(block)
         elif isinstance(block, bytes):
             from ray.data.impl.arrow_block import \
                 ArrowBlockAccessor
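For orientation, BlockAccessor.for_block() is a type-dispatching factory: the runtime type of the block selects the accessor implementation, and this revert removes the pandas.DataFrame branch. A reduced sketch of the dispatch (returning names instead of accessor instances to stay self-contained):

    from typing import Any

    def accessor_for(block: Any) -> str:
        import pyarrow
        if isinstance(block, pyarrow.Table):
            return "ArrowBlockAccessor"
        elif isinstance(block, bytes):
            # bytes blocks hold serialized Arrow data (the hunk above is
            # truncated at this branch).
            return "ArrowBlockAccessor"
        elif isinstance(block, list):
            return "SimpleBlockAccessor"
        raise TypeError(f"Not a block type: {block!r}")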
python/ray/data/context.py:

@@ -14,10 +14,6 @@ DEFAULT_TARGET_MAX_BLOCK_SIZE = 2048 * 1024 * 1024
 # Whether block splitting is on by default
 DEFAULT_BLOCK_SPLITTING_ENABLED = False

-# Whether pandas block format is enabled.
-# TODO (kfstorm): Remove this once stable.
-DEFAULT_ENABLE_PANDAS_BLOCK = True
-

 @DeveloperAPI
 class DatasetContext:

@@ -27,18 +23,12 @@ class DatasetContext:
     from the driver and remote workers via DatasetContext.get_current().
     """

-    def __init__(
-            self,
-            block_owner: ray.actor.ActorHandle,
-            block_splitting_enabled: bool,
-            target_max_block_size: int,
-            enable_pandas_block: bool,
-    ):
+    def __init__(self, block_owner: ray.actor.ActorHandle,
+                 block_splitting_enabled: bool, target_max_block_size: int):
         """Private constructor (use get_current() instead)."""
         self.block_owner = block_owner
         self.block_splitting_enabled = block_splitting_enabled
         self.target_max_block_size = target_max_block_size
-        self.enable_pandas_block = enable_pandas_block

     @staticmethod
     def get_current() -> "DatasetContext":

@@ -55,9 +45,7 @@ class DatasetContext:
                 _default_context = DatasetContext(
                     block_owner=None,
                     block_splitting_enabled=DEFAULT_BLOCK_SPLITTING_ENABLED,
-                    target_max_block_size=DEFAULT_TARGET_MAX_BLOCK_SIZE,
-                    enable_pandas_block=DEFAULT_ENABLE_PANDAS_BLOCK,
-                )
+                    target_max_block_size=DEFAULT_TARGET_MAX_BLOCK_SIZE)

             if _default_context.block_owner is None:
                 owner = _DesignatedBlockOwner.options(
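While the enable_pandas_block flag existed, it was toggled the way the tests removed near the bottom of this diff do: fetch the singleton context, flip the attribute, and restore it afterwards. A sketch of that pattern (valid only against the pre-revert code):

    import ray
    import pandas as pd
    from ray.data.context import DatasetContext

    ctx = DatasetContext.get_current()
    old_value = ctx.enable_pandas_block
    ctx.enable_pandas_block = False  # force Arrow blocks even for pandas input
    try:
        ds = ray.data.from_pandas(pd.DataFrame({"one": [1, 2, 3]}))
    finally:
        ctx.enable_pandas_block = old_value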
python/ray/data/dataset.py:

@@ -44,7 +44,6 @@ from ray.data.impl.shuffle import simple_shuffle, _shuffle_reduce
 from ray.data.impl.sort import sort_impl
 from ray.data.impl.block_list import BlockList
 from ray.data.impl.lazy_block_list import LazyBlockList
-from ray.data.impl.table_block import TableRow
 from ray.data.impl.delegating_block_builder import DelegatingBlockBuilder

 # An output type of iter_batches() determined by the batch_format parameter.

@@ -231,9 +230,11 @@ class Dataset(Generic[T]):
                         "or 'pyarrow', got: {}".format(batch_format))

                 applied = fn(view)
-                if not (isinstance(applied, list)
-                        or isinstance(applied, pa.Table)
-                        or isinstance(applied, pd.core.frame.DataFrame)):
+                if isinstance(applied, list) or isinstance(applied, pa.Table):
+                    applied = applied
+                elif isinstance(applied, pd.core.frame.DataFrame):
+                    applied = pa.Table.from_pandas(applied)
+                else:
                     raise ValueError("The map batches UDF returned the value "
                                      f"{applied}, which is not allowed. "
                                      "The return type must be either list, "

@@ -402,15 +403,12 @@ class Dataset(Generic[T]):
             # Handle empty blocks.
             if len(new_blocks) < num_blocks:
                 from ray.data.impl.arrow_block import ArrowBlockBuilder
-                from ray.data.impl.pandas_block import PandasBlockBuilder
                 from ray.data.impl.simple_block import SimpleBlockBuilder

                 num_empties = num_blocks - len(new_blocks)
                 dataset_format = self._dataset_format()
                 if dataset_format == "arrow":
                     builder = ArrowBlockBuilder()
-                elif dataset_format == "pandas":
-                    builder = PandasBlockBuilder()
                 else:
                     builder = SimpleBlockBuilder()
                 empty_block = builder.build()
@@ -941,7 +939,7 @@ class Dataset(Generic[T]):
             # Dataset is empty/cleared, let downstream ops handle this.
             return on

-        if dataset_format == "arrow" or dataset_format == "pandas":
+        if dataset_format == "arrow":
             # This should be cached from the ._dataset_format() check, so we
             # don't fetch and we assert that the schema is not None.
             schema = self.schema(fetch_if_missing=False)

@@ -974,34 +972,31 @@
                 and isinstance(on[0], str)):
             raise ValueError(
                 "Can't aggregate on a column when using a simple Dataset; "
-                "use a callable `on` argument or use an Arrow or Pandas"
-                " Dataset instead of a simple Dataset.")
+                "use a callable `on` argument or use an Arrow Dataset "
+                "instead of a simple Dataset.")
         return on

     def _dataset_format(self) -> str:
         """Determine the format of the dataset. Possible values are: "arrow",
-        "pandas", "simple".
+        "simple".

         This may block; if the schema is unknown, this will synchronously fetch
         the schema for the first block.
         """
-        # We need schema to properly validate, so synchronously
-        # fetch it if necessary.
-        schema = self.schema(fetch_if_missing=True)
-        if schema is None:
-            raise ValueError(
-                "Dataset is empty or cleared, can't determine the format of "
-                "the dataset.")
-
-        try:
-            import pyarrow as pa
-            if isinstance(schema, pa.Schema):
-                return "arrow"
-        except ModuleNotFoundError:
-            pass
-        from ray.data.impl.pandas_block import PandasBlockSchema
-        if isinstance(schema, PandasBlockSchema):
-            return "pandas"
-        return "simple"
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError:
+            return "simple"
+        else:
+            # We need schema to properly validate, so synchronously
+            # fetch it if necessary.
+            schema = self.schema(fetch_if_missing=True)
+            if schema is None:
+                raise ValueError(
+                    "Dataset is empty or cleared, can't determine the format"
+                    " of the dataset")
+            if isinstance(schema, pa.Schema):
+                return "arrow"
+            return "simple"

     def _aggregate_on(self, agg_cls: type, on: Optional["AggregateOnTs"],

@@ -1031,18 +1026,6 @@ class Dataset(Generic[T]):
             on = [on]
         return [agg_cls(on_, *args, **kwargs) for on_ in on]

-    def _aggregate_result(self, result: Union[Tuple, TableRow]) -> U:
-        if len(result) == 1:
-            if isinstance(result, tuple):
-                return result[0]
-            else:
-                # NOTE (kfstorm): We cannot call `result[0]` directly on
-                # `PandasRow` because indexing a column with position is not
-                # supported by pandas.
-                return list(result.values())[0]
-        else:
-            return result
-
     def sum(self, on: Optional["AggregateOnTs"] = None) -> U:
         """Compute sum over entire dataset.
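The removed _aggregate_result() helper existed because a pandas-backed row cannot be indexed by position: result[0] works on a plain tuple but not on a PandasRow, whose __getitem__ (see the deleted pandas_block.py below) asserts a string key. The NOTE in the removed code can be reproduced with plain pandas:

    import pandas as pd

    one_row = pd.DataFrame({"sum(one)": [6]})
    # one_row[0] raises KeyError: column access is by name, not position.
    value = list(one_row.to_dict("list").values())[0][0]
    assert value == 6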
@@ -1093,8 +1076,10 @@ class Dataset(Generic[T]):
         ret = self._aggregate_on(Sum, on)
         if ret is None:
             return 0
+        elif len(ret) == 1:
+            return ret[0]
         else:
-            return self._aggregate_result(ret)
+            return ret

     def min(self, on: Optional["AggregateOnTs"] = None) -> U:
         """Compute minimum over entire dataset.

@@ -1146,8 +1131,10 @@ class Dataset(Generic[T]):
         ret = self._aggregate_on(Min, on)
         if ret is None:
             raise ValueError("Cannot compute min on an empty dataset")
+        elif len(ret) == 1:
+            return ret[0]
         else:
-            return self._aggregate_result(ret)
+            return ret

     def max(self, on: Optional["AggregateOnTs"] = None) -> U:
         """Compute maximum over entire dataset.

@@ -1199,8 +1186,10 @@ class Dataset(Generic[T]):
         ret = self._aggregate_on(Max, on)
         if ret is None:
             raise ValueError("Cannot compute max on an empty dataset")
+        elif len(ret) == 1:
+            return ret[0]
         else:
-            return self._aggregate_result(ret)
+            return ret

     def mean(self, on: Optional["AggregateOnTs"] = None) -> U:
         """Compute mean over entire dataset.

@@ -1252,8 +1241,10 @@ class Dataset(Generic[T]):
         ret = self._aggregate_on(Mean, on)
         if ret is None:
             raise ValueError("Cannot compute mean on an empty dataset")
+        elif len(ret) == 1:
+            return ret[0]
         else:
-            return self._aggregate_result(ret)
+            return ret

     def std(self, on: Optional["AggregateOnTs"] = None, ddof: int = 1) -> U:
         """Compute standard deviation over entire dataset.

@@ -1315,8 +1306,10 @@ class Dataset(Generic[T]):
         ret = self._aggregate_on(Std, on, ddof=ddof)
         if ret is None:
             raise ValueError("Cannot compute std on an empty dataset")
+        elif len(ret) == 1:
+            return ret[0]
         else:
-            return self._aggregate_result(ret)
+            return ret

     def sort(self,
              key: Union[None, str, List[str], Callable[[T], Any]] = None,
@@ -2271,10 +2264,10 @@ Dict[str, List[str]]]): The names of the columns
     def to_pandas(self, limit: int = 100000) -> "pandas.DataFrame":
         """Convert this dataset into a single Pandas DataFrame.

-        This is only supported for datasets convertible to Arrow or Pandas
-        records. An error is raised if the number of records exceeds the
-        provided limit. Note that you can use ``.limit()`` on the dataset
-        beforehand to truncate the dataset manually.
+        This is only supported for datasets convertible to Arrow records. An
+        error is raised if the number of records exceeds the provided limit.
+        Note that you can use ``.limit()`` on the dataset beforehand to
+        truncate the dataset manually.

         Time complexity: O(dataset size)

python/ray/data/impl/arrow_block.py:

@@ -79,7 +79,7 @@ class ArrowBlockAccessor(TableBlockAccessor):
             view = _copy_table(view)
         return view

-    def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
+    def random_shuffle(self, random_seed: Optional[int]) -> List[T]:
         random = np.random.RandomState(random_seed)
         return self._table.take(random.permutation(self.num_rows()))
python/ray/data/impl/delegating_block_builder.py:

@@ -4,7 +4,6 @@ from ray.data.block import Block, T, BlockAccessor
 from ray.data.impl.block_builder import BlockBuilder
 from ray.data.impl.simple_block import SimpleBlockBuilder
 from ray.data.impl.arrow_block import ArrowRow, ArrowBlockBuilder
-from ray.data.impl.pandas_block import PandasRow, PandasBlockBuilder


 class DelegatingBlockBuilder(BlockBuilder[T]):

@@ -14,7 +13,6 @@ class DelegatingBlockBuilder(BlockBuilder[T]):
     def add(self, item: Any) -> None:

         if self._builder is None:
-            # TODO (kfstorm): Maybe we can use Pandas block format for dict.
             if isinstance(item, dict) or isinstance(item, ArrowRow):
                 import pyarrow
                 try:

@@ -24,8 +22,6 @@ class DelegatingBlockBuilder(BlockBuilder[T]):
                     self._builder = ArrowBlockBuilder()
                 except (TypeError, pyarrow.lib.ArrowInvalid):
                     self._builder = SimpleBlockBuilder()
-            elif isinstance(item, PandasRow):
-                self._builder = PandasBlockBuilder()
             else:
                 self._builder = SimpleBlockBuilder()
         self._builder.add(item)
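The dispatch that survives the revert, reduced to its core: the first item added picks the concrete builder, and dict/ArrowRow items are probed against Arrow with a fallback to a simple (Python list) block when Arrow rejects them. A self-contained sketch of that probe (the string results stand in for the real builder classes):

    import pyarrow

    def pick_builder(item) -> str:
        if isinstance(item, dict):
            try:
                # Probe: can Arrow represent this item as a one-row table?
                pyarrow.Table.from_pydict({k: [v] for k, v in item.items()})
                return "arrow"
            except (TypeError, pyarrow.lib.ArrowInvalid):
                return "simple"
        return "simple"

    assert pick_builder({"a": 1}) == "arrow"
    assert pick_builder(42) == "simple"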
python/ray/data/impl/pandas_block.py (deleted by this commit):

@@ -1,188 +0,0 @@
-from typing import Dict, List, Tuple, Any, TypeVar, Optional, TYPE_CHECKING
-
-import collections
-import numpy as np
-
-try:
-    import pandas
-except ImportError:
-    pandas = None
-
-from ray.data.block import BlockAccessor, BlockMetadata
-from ray.data.impl.table_block import TableBlockAccessor, TableRow, \
-    TableBlockBuilder, SortKeyT, GroupKeyT
-from ray.data.impl.arrow_block import ArrowBlockAccessor
-from ray.data.aggregate import AggregateFn
-
-if TYPE_CHECKING:
-    import pyarrow
-    import pandas
-
-T = TypeVar("T")
-
-
-class PandasRow(TableRow):
-    def as_pydict(self) -> dict:
-        return {k: v[0] for k, v in self._row.to_dict("list").items()}
-
-    def __getitem__(self, key: str) -> Any:
-        assert isinstance(key, str)
-        col = self._row[key]
-        if len(col) == 0:
-            return None
-        item = col.iloc[0]
-        try:
-            # Try to interpret this as a numpy-type value.
-            # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types.  # noqa: E501
-            return item.item()
-        except AttributeError:
-            # Fallback to the original form.
-            return item
-
-    def __len__(self):
-        return self._row.shape[1]
-
-
-class PandasBlockBuilder(TableBlockBuilder[T]):
-    def __init__(self):
-        if pandas is None:
-            raise ImportError("Run `pip install pandas` for Pandas support.")
-        super().__init__(pandas.DataFrame)
-
-    def _table_from_pydict(
-            self, columns: Dict[str, List[Any]]) -> "pandas.DataFrame":
-        return pandas.DataFrame(columns)
-
-    def _concat_tables(self,
-                       tables: List["pandas.DataFrame"]) -> "pandas.DataFrame":
-        return pandas.concat(tables, ignore_index=True)
-
-    @staticmethod
-    def _empty_table() -> "pandas.DataFrame":
-        return pandas.DataFrame()
-
-
-# This is to be compatible with pyarrow.lib.schema
-# TODO (kfstorm): We need a format-independent way to represent schema.
-PandasBlockSchema = collections.namedtuple("PandasBlockSchema",
-                                           ["names", "types"])
-
-
-class PandasBlockAccessor(TableBlockAccessor):
-    def __init__(self, table: "pandas.DataFrame"):
-        if pandas is None:
-            raise ImportError("Run `pip install pandas` for Pandas support.")
-        super().__init__(table)
-
-    def _create_table_row(self, row: "pandas.DataFrame") -> PandasRow:
-        return PandasRow(row)
-
-    def slice(self, start: int, end: int, copy: bool) -> "pandas.DataFrame":
-        view = self._table[start:end]
-        if copy:
-            view = view.copy(deep=True)
-        return view
-
-    def random_shuffle(self, random_seed: Optional[int]) -> "pandas.DataFrame":
-        return self._table.sample(frac=1, random_state=random_seed)
-
-    def schema(self) -> PandasBlockSchema:
-        dtypes = self._table.dtypes
-        schema = PandasBlockSchema(
-            names=dtypes.index.tolist(), types=dtypes.values.tolist())
-        # Column names with non-str types of a pandas DataFrame is not
-        # supported by Ray Dataset.
-        if any(not isinstance(name, str) for name in schema.names):
-            raise ValueError(
-                "A Pandas DataFrame with column names of non-str types"
-                " is not supported by Ray Dataset. Column names of this"
-                f" DataFrame: {schema.names!r}.")
-        return schema
-
-    def to_pandas(self) -> "pandas.DataFrame":
-        return self._table
-
-    def to_numpy(self, column: str = None) -> np.ndarray:
-        if not column:
-            raise ValueError(
-                "`column` must be specified when calling .to_numpy() "
-                "on Pandas blocks.")
-        if column not in self._table.columns:
-            raise ValueError(
-                "Cannot find column {}, available columns: {}".format(
-                    column, self._table.columns.tolist()))
-        return self._table[column].to_numpy()
-
-    def to_arrow(self) -> "pyarrow.Table":
-        import pyarrow
-        return pyarrow.table(self._table)
-
-    def num_rows(self) -> int:
-        return self._table.shape[0]
-
-    def size_bytes(self) -> int:
-        return self._table.memory_usage(index=True, deep=True).sum()
-
-    def _zip(self, acc: BlockAccessor) -> "pandas.DataFrame":
-        r = self.to_pandas().copy(deep=False)
-        s = acc.to_pandas()
-        for col_name in s.columns:
-            col = s[col_name]
-            # Ensure the column names are unique after zip.
-            if col_name in r.column_names:
-                i = 1
-                new_name = col_name
-                while new_name in r.column_names:
-                    new_name = "{}_{}".format(col_name, i)
-                    i += 1
-                col_name = new_name
-            r[col_name] = col
-        return r
-
-    @staticmethod
-    def builder() -> PandasBlockBuilder[T]:
-        return PandasBlockBuilder()
-
-    @staticmethod
-    def _empty_table() -> "pandas.DataFrame":
-        return PandasBlockBuilder._empty_table()
-
-    def _sample(self, n_samples: int, key: SortKeyT) -> "pandas.DataFrame":
-        return self._table[[k[0] for k in key]].sample(
-            n_samples, ignore_index=True)
-
-    def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
-                           descending: bool) -> List["pandas.DataFrame"]:
-        # TODO (kfstorm): A workaround to pass tests. Not efficient.
-        delegated_result = BlockAccessor.for_block(
-            self.to_arrow()).sort_and_partition(boundaries, key, descending)
-        return [
-            BlockAccessor.for_block(_).to_pandas() for _ in delegated_result
-        ]
-
-    def combine(self, key: GroupKeyT,
-                aggs: Tuple[AggregateFn]) -> "pandas.DataFrame":
-        # TODO (kfstorm): A workaround to pass tests. Not efficient.
-        return BlockAccessor.for_block(self.to_arrow()).combine(
-            key, aggs).to_pandas()
-
-    @staticmethod
-    def merge_sorted_blocks(
-            blocks: List["pandas.DataFrame"], key: SortKeyT,
-            _descending: bool) -> Tuple["pandas.DataFrame", BlockMetadata]:
-        # TODO (kfstorm): A workaround to pass tests. Not efficient.
-        block, metadata = ArrowBlockAccessor.merge_sorted_blocks(
-            [BlockAccessor.for_block(block).to_arrow() for block in blocks],
-            key, _descending)
-        return BlockAccessor.for_block(block).to_pandas(), metadata
-
-    @staticmethod
-    def aggregate_combined_blocks(
-            blocks: List["pandas.DataFrame"], key: GroupKeyT,
-            aggs: Tuple[AggregateFn]
-    ) -> Tuple["pandas.DataFrame", BlockMetadata]:
-        # TODO (kfstorm): A workaround to pass tests. Not efficient.
-        block, metadata = ArrowBlockAccessor.aggregate_combined_blocks(
-            [BlockAccessor.for_block(block).to_arrow() for block in blocks],
-            key, aggs)
-        return BlockAccessor.for_block(block).to_pandas(), metadata
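One detail worth flagging in the deleted file: _zip() tests `col_name in r.column_names`, but column_names is a pyarrow.Table attribute; a pandas.DataFrame exposes `columns`, so that branch would have failed with AttributeError had it been hit. The intended uniquify-and-assign logic, sketched against plain pandas:

    import pandas as pd

    def zip_frames(r: pd.DataFrame, s: pd.DataFrame) -> pd.DataFrame:
        r = r.copy(deep=False)
        for col_name in s.columns:
            new_name, i = col_name, 1
            while new_name in r.columns:  # ensure unique names after zip
                new_name = "{}_{}".format(col_name, i)
                i += 1
            r[new_name] = s[col_name]
        return r

    out = zip_frames(pd.DataFrame({"a": [1]}), pd.DataFrame({"a": [2]}))
    assert list(out.columns) == ["a", "a_1"]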
python/ray/data/impl/simple_block.py:

@@ -71,7 +71,7 @@ class SimpleBlockAccessor(BlockAccessor):

     def to_pandas(self) -> "pandas.DataFrame":
         import pandas
-        return pandas.DataFrame({"value": self._items})
+        return pandas.DataFrame(self._items)

     def to_numpy(self, column: str = None) -> np.ndarray:
         if column:
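This one-line change is why several tests further down flip from "value" to integer column keys: pd.DataFrame(items) falls back to pandas' default RangeIndex column names, whereas the removed version named the single column explicitly:

    import pandas as pd

    items = [1, 2, 3]
    assert pd.DataFrame({"value": items}).columns.tolist() == ["value"]
    assert pd.DataFrame(items).columns.tolist() == [0]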
python/ray/data/read_api.py:

@@ -534,22 +534,8 @@ def from_dask(df: "dask.DataFrame") -> Dataset[ArrowRow]:

     partitions = df.to_delayed()
     persisted_partitions = dask.persist(*partitions, scheduler=ray_dask_get)
-
-    import pandas
-
-    def to_ref(df):
-        if isinstance(df, pandas.DataFrame):
-            return ray.put(df)
-        elif isinstance(df, ray.ObjectRef):
-            return df
-        else:
-            raise ValueError(
-                "Expected a Ray object ref or a Pandas DataFrame, "
-                f"got {type(df)}")
-
-    return from_pandas_refs([
-        to_ref(next(iter(part.dask.values()))) for part in persisted_partitions
-    ])
+    return from_pandas_refs(
+        [next(iter(part.dask.values())) for part in persisted_partitions])


 @PublicAPI(stability="beta")

@@ -614,21 +600,6 @@ def from_pandas_refs(dfs: Union[ObjectRef["pandas.DataFrame"], List[ObjectRef[
     """
     if isinstance(dfs, ray.ObjectRef):
         dfs = [dfs]
-    elif isinstance(dfs, list):
-        for df in dfs:
-            if not isinstance(df, ray.ObjectRef):
-                raise ValueError("Expected list of Ray object refs, "
-                                 f"got list containing {type(df)}")
-    else:
-        raise ValueError("Expected Ray object ref or list of Ray object refs, "
-                         f"got {type(df)}")
-
-    context = DatasetContext.get_current()
-    if context.enable_pandas_block:
-        get_metadata = cached_remote_fn(_get_metadata)
-        metadata = [get_metadata.remote(df) for df in dfs]
-        return Dataset(
-            BlockList(dfs, ray.get(metadata)), 0, DatasetStats.TODO())

     df_to_block = cached_remote_fn(_df_to_block, num_returns=2)

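In outline, the removed branch stored the pandas DataFrame object refs directly as blocks, computing only metadata remotely; the surviving path sends every ref through _df_to_block instead, i.e. a pandas-to-Arrow conversion. The conversion step in isolation (assuming _df_to_block wraps pa.Table.from_pandas, which the name suggests but this diff does not show):

    import pandas as pd
    import pyarrow as pa

    df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
    table = pa.Table.from_pandas(df)
    assert table.num_rows == 3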
@@ -736,8 +707,7 @@ def _ndarray_to_block(ndarray: np.ndarray) -> Block[np.ndarray]:
             input_files=None, exec_stats=stats.build()))


-def _get_metadata(
-        table: Union["pyarrow.Table", "pandas.DataFrame"]) -> BlockMetadata:
+def _get_metadata(table: "pyarrow.Table") -> BlockMetadata:
     stats = BlockExecStats.builder()
     return BlockAccessor.for_block(table).get_metadata(
         input_files=None, exec_stats=stats.build())
python/ray/data/tests/test_dataset.py:

@@ -329,7 +329,7 @@ def test_batch_tensors(ray_start_regular_shared):
     with pytest.raises(pa.lib.ArrowInvalid):
         next(ds.iter_batches(batch_format="pyarrow"))
     df = next(ds.iter_batches(batch_format="pandas"))
-    assert df.to_dict().keys() == {"value"}
+    assert df.to_dict().keys() == {0, 1}


 def test_arrow_block_slice_copy():
@@ -1156,56 +1156,34 @@ def test_repartition_shuffle_arrow(ray_start_regular_shared):
     assert large._block_num_rows() == [500] * 20


-@pytest.mark.parametrize("enable_pandas_block", [False, True])
-def test_from_pandas(ray_start_regular_shared, enable_pandas_block):
-    ctx = ray.data.context.DatasetContext.get_current()
-    old_enable_pandas_block = ctx.enable_pandas_block
-    ctx.enable_pandas_block = enable_pandas_block
-    try:
+def test_from_pandas(ray_start_regular_shared):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
     ds = ray.data.from_pandas([df1, df2])
-        assert ds._dataset_format(
-        ) == "pandas" if enable_pandas_block else "arrow"
     values = [(r["one"], r["two"]) for r in ds.take(6)]
     rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
     assert values == rows

     # test from single pandas dataframe
     ds = ray.data.from_pandas(df1)
-        assert ds._dataset_format(
-        ) == "pandas" if enable_pandas_block else "arrow"
     values = [(r["one"], r["two"]) for r in ds.take(3)]
     rows = [(r.one, r.two) for _, r in df1.iterrows()]
     assert values == rows
-    finally:
-        ctx.enable_pandas_block = old_enable_pandas_block


-@pytest.mark.parametrize("enable_pandas_block", [False, True])
-def test_from_pandas_refs(ray_start_regular_shared, enable_pandas_block):
-    ctx = ray.data.context.DatasetContext.get_current()
-    old_enable_pandas_block = ctx.enable_pandas_block
-    ctx.enable_pandas_block = enable_pandas_block
-    try:
+def test_from_pandas_refs(ray_start_regular_shared):
     df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
     df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
     ds = ray.data.from_pandas_refs([ray.put(df1), ray.put(df2)])
-        assert ds._dataset_format(
-        ) == "pandas" if enable_pandas_block else "arrow"
     values = [(r["one"], r["two"]) for r in ds.take(6)]
     rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
     assert values == rows

     # test from single pandas dataframe ref
     ds = ray.data.from_pandas_refs(ray.put(df1))
-        assert ds._dataset_format(
-        ) == "pandas" if enable_pandas_block else "arrow"
     values = [(r["one"], r["two"]) for r in ds.take(3)]
     rows = [(r.one, r.two) for _, r in df1.iterrows()]
     assert values == rows
-    finally:
-        ctx.enable_pandas_block = old_enable_pandas_block


 def test_from_numpy(ray_start_regular_shared):
@@ -1316,7 +1294,7 @@ def test_to_arrow_refs(ray_start_regular_shared):
     assert df.equals(dfds)

     # Conversion.
-    df = pd.DataFrame({"value": list(range(n))})
+    df = pd.DataFrame({0: list(range(n))})
     ds = ray.data.range(n)
     dfds = pd.concat(
         [t.to_pandas() for t in ray.get(ds.to_arrow_refs())],
@@ -1699,8 +1677,8 @@ def test_parquet_write_with_udf(ray_start_regular_shared, tmp_path):
     df = pd.concat([df1, df2])
     ds = ray.data.from_pandas([df1, df2])

-    def _block_udf(block):
-        df = BlockAccessor.for_block(block).to_pandas().copy()
+    def _block_udf(block: pa.Table):
+        df = block.to_pandas()
         df["one"] += 1
         return pa.Table.from_pandas(df)

@@ -1887,7 +1865,7 @@ def test_iter_batches_basic(ray_start_regular_shared):

     # blocks format.
     for batch, df in zip(ds.iter_batches(batch_format="native"), dfs):
-        assert BlockAccessor.for_block(batch).to_pandas().equals(df)
+        assert batch.to_pandas().equals(df)

     # Batch size.
     batch_size = 2
@@ -2049,10 +2027,8 @@ def test_map_batch(ray_start_regular_shared, tmp_path):
     table = pa.Table.from_pandas(df)
     pq.write_table(table, os.path.join(tmp_path, "test1.parquet"))
     ds = ray.data.read_parquet(str(tmp_path))
-    ds2 = ds.map_batches(
-        lambda df: df + 1, batch_size=1, batch_format="pandas")
-    assert ds2._dataset_format() == "pandas"
-    ds_list = ds2.take()
+    ds_list = ds.map_batches(
+        lambda df: df + 1, batch_size=1, batch_format="pandas").take()
     values = [s["one"] for s in ds_list]
     assert values == [2, 3, 4]
     values = [s["two"] for s in ds_list]

@@ -2060,9 +2036,8 @@ def test_map_batch(ray_start_regular_shared, tmp_path):

     # Test Pyarrow
     ds = ray.data.read_parquet(str(tmp_path))
-    ds2 = ds.map_batches(lambda pa: pa, batch_size=1, batch_format="pyarrow")
-    assert ds2._dataset_format() == "arrow"
-    ds_list = ds2.take()
+    ds_list = ds.map_batches(
+        lambda pa: pa, batch_size=1, batch_format="pyarrow").take()
     values = [s["one"] for s in ds_list]
     assert values == [1, 2, 3]
     values = [s["two"] for s in ds_list]

@@ -2071,31 +2046,27 @@ def test_map_batch(ray_start_regular_shared, tmp_path):
     # Test batch
     size = 300
     ds = ray.data.range(size)
-    ds2 = ds.map_batches(
-        lambda df: df + 1, batch_size=17, batch_format="pandas")
-    assert ds2._dataset_format() == "pandas"
-    ds_list = ds2.take(limit=size)
+    ds_list = ds.map_batches(
+        lambda df: df + 1, batch_size=17,
+        batch_format="pandas").take(limit=size)
     for i in range(size):
-        # The pandas column is "value", and it originally has rows from 0~299.
+        # The pandas column is "0", and it originally has rows from 0~299.
         # After the map batch, it should have 1~300.
         row = ds_list[i]
-        assert row["value"] == i + 1
+        assert row["0"] == i + 1
     assert ds.count() == 300

     # Test the lambda returns different types than the batch_format
     # pandas => list block
     ds = ray.data.read_parquet(str(tmp_path))
-    ds2 = ds.map_batches(lambda df: [1], batch_size=1)
-    assert ds2._dataset_format() == "simple"
-    ds_list = ds2.take()
+    ds_list = ds.map_batches(lambda df: [1], batch_size=1).take()
     assert ds_list == [1, 1, 1]
     assert ds.count() == 3

     # pyarrow => list block
     ds = ray.data.read_parquet(str(tmp_path))
-    ds2 = ds.map_batches(lambda df: [1], batch_size=1, batch_format="pyarrow")
-    assert ds2._dataset_format() == "simple"
-    ds_list = ds2.take()
+    ds_list = ds.map_batches(
+        lambda df: [1], batch_size=1, batch_format="pyarrow").take()
     assert ds_list == [1, 1, 1]
     assert ds.count() == 3
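Net effect on user code, as these tests pin it down: map_batches(..., batch_format="pandas") still hands the UDF a DataFrame, but the returned DataFrame is converted back into an Arrow block rather than kept as a pandas block, so _dataset_format() can no longer report "pandas". A usage sketch against the Ray version of this commit:

    import ray

    ds = ray.data.range(8)
    out = ds.map_batches(lambda df: df + 1, batch_size=4, batch_format="pandas")
    # Rows surface through Arrow/simple blocks, not pandas blocks.
    print(out.take(3))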
@@ -3676,17 +3647,6 @@ def test_sort_simple(ray_start_regular_shared):
     assert ds.count() == 0


-def test_column_name_type_check(ray_start_regular_shared):
-    df = pd.DataFrame({"1": np.random.rand(10), "a": np.random.rand(10)})
-    ds = ray.data.from_pandas(df)
-    expected_str = ("Dataset(num_blocks=1, num_rows=10, "
-                    "schema={1: float64, a: float64})")
-    assert str(ds) == expected_str, str(ds)
-    df = pd.DataFrame({1: np.random.rand(10), "a": np.random.rand(10)})
-    with pytest.raises(ValueError):
-        ray.data.from_pandas(df)
-
-
 @pytest.mark.parametrize("pipelined", [False, True])
 def test_random_shuffle(shutdown_only, pipelined):
     def range(n, parallelism=200):