[dataset] Fix conversion to pyarrow tables in several transforms (#16916)

Eric Liang 2021-07-06 20:40:57 -07:00 committed by GitHub
parent 23088bd7ea
commit ca083e16d4
6 changed files with 109 additions and 35 deletions

View file

@@ -12,6 +12,8 @@ if TYPE_CHECKING:
     import dask
     import pyspark
     import ray.util.sgd
+    import torch
+    import tensorflow as tf

 import collections
 import itertools
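For context: torch and tensorflow are imported only under TYPE_CHECKING, so they stay optional at runtime, and the new return annotations below reference them as string literals. A minimal sketch of that pattern (the function name here is purely illustrative):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by type checkers; never imported at runtime.
        import torch

    def as_torch_iterable(ds) -> "torch.utils.data.IterableDataset":
        # The quoted annotation keeps torch off the import path for users
        # who never call this function.
        ...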
@@ -57,7 +59,9 @@ class Dataset(Generic[T]):
         self._blocks: BlockList[T] = blocks
         assert isinstance(self._blocks, BlockList), self._blocks

-    def map(self, fn: Callable[[T], U], compute="tasks",
+    def map(self,
+            fn: Callable[[T], U],
+            compute: Optional[str] = None,
             **ray_remote_args) -> "Dataset[U]":
         """Apply the given function to each record of this dataset.

@@ -75,8 +79,8 @@ class Dataset(Generic[T]):

         Args:
             fn: The function to apply to each record.
-            compute: The compute strategy, either "tasks" to use Ray tasks,
-                or "actors" to use an autoscaling Ray actor pool.
+            compute: The compute strategy, either "tasks" (default) to use Ray
+                tasks, or "actors" to use an autoscaling Ray actor pool.
             ray_remote_args: Additional resource requirements to request from
                 ray (e.g., num_gpus=1 to request GPUs for the map tasks).
         """
@@ -94,7 +98,7 @@ class Dataset(Generic[T]):
     def map_batches(self,
                     fn: Callable[[BatchType], BatchType],
                     batch_size: int = None,
-                    compute: str = "tasks",
+                    compute: Optional[str] = None,
                     batch_format: str = "pandas",
                     **ray_remote_args) -> "Dataset[Any]":
         """Apply the given function to batches of records of this dataset.
@@ -126,12 +130,12 @@ class Dataset(Generic[T]):
             fn: The function to apply to each record batch.
             batch_size: Request a specific batch size, or leave unspecified
                 to use entire blocks as batches.
-            compute: The compute strategy, either "tasks" to use Ray tasks,
-                or "actors" to use an autoscaling Ray actor pool. When using
-                actors, state can be preserved across function invocations
-                in Python global variables. This can be useful for one-time
-                setups, e.g., initializing a model once and re-using it across
-                many function applications.
+            compute: The compute strategy, either "tasks" (default) to use Ray
+                tasks, or "actors" to use an autoscaling Ray actor pool. When
+                using actors, state can be preserved across function
+                invocations in Python global variables. This can be useful for
+                one-time setups, e.g., initializing a model once and re-using
+                it across many function applications.
             batch_format: Specify "pandas" to select ``pandas.DataFrame`` as
                 the batch format, or "pyarrow" to select ``pyarrow.Table``.
             ray_remote_args: Additional resource requirements to request from
@@ -157,7 +161,7 @@ class Dataset(Generic[T]):
             if batch_format == "pandas":
                 view = view.to_pandas()
             elif batch_format == "pyarrow":
-                view = view._table
+                view = view.to_arrow_table()
             else:
                 raise ValueError(
                     f"The given batch format: {batch_format} "
@@ -185,7 +189,7 @@ class Dataset(Generic[T]):

     def flat_map(self,
                  fn: Callable[[T], Iterable[U]],
-                 compute="tasks",
+                 compute: Optional[str] = None,
                  **ray_remote_args) -> "Dataset[U]":
         """Apply the given function to each record and then flatten results.

@@ -199,8 +203,8 @@ class Dataset(Generic[T]):

         Args:
             fn: The function to apply to each record.
-            compute: The compute strategy, either "tasks" to use Ray tasks,
-                or "actors" to use an autoscaling Ray actor pool.
+            compute: The compute strategy, either "tasks" (default) to use Ray
+                tasks, or "actors" to use an autoscaling Ray actor pool.
             ray_remote_args: Additional resource requirements to request from
                 ray (e.g., num_gpus=1 to request GPUs for the map tasks).
         """
@@ -218,7 +222,7 @@ class Dataset(Generic[T]):

     def filter(self,
                fn: Callable[[T], bool],
-               compute="tasks",
+               compute: Optional[str] = None,
                **ray_remote_args) -> "Dataset[T]":
         """Filter out records that do not satisfy the given predicate.

@@ -232,8 +236,8 @@ class Dataset(Generic[T]):

         Args:
             fn: The predicate function to apply to each record.
-            compute: The compute strategy, either "tasks" to use Ray tasks,
-                or "actors" to use an autoscaling Ray actor pool.
+            compute: The compute strategy, either "tasks" (default) to use Ray
+                tasks, or "actors" to use an autoscaling Ray actor pool.
             ray_remote_args: Additional resource requirements to request from
                 ray (e.g., num_gpus=1 to request GPUs for the map tasks).
         """
@@ -654,8 +658,9 @@ class Dataset(Generic[T]):
         def parquet_write(write_path, block):
             logger.debug(
                 f"Writing {block.num_rows()} records to {write_path}.")
-            with pq.ParquetWriter(write_path, block._table.schema) as writer:
-                writer.write_table(block._table)
+            table = block.to_arrow_table()
+            with pq.ParquetWriter(write_path, table.schema) as writer:
+                writer.write_table(table)

         refs = [
             parquet_write.remote(
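The fix routes every block through to_arrow_table() before writing, so the Parquet path no longer depends on the private _table attribute of ArrowBlock. A standalone sketch of that write path using only pyarrow (the output file name here is made up):

    import pyarrow as pa
    import pyarrow.parquet as pq

    # What the task body now does, minus Ray: get an Arrow table and stream it out.
    table = pa.Table.from_pydict({"value": list(range(100))})
    with pq.ParquetWriter("/tmp/data_000000.parquet", table.schema) as writer:
        writer.write_table(table)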
@@ -808,23 +813,29 @@ class Dataset(Generic[T]):

         raise NotImplementedError  # P1

-    def to_torch(self, **todo) -> "ray.util.sgd.torch.TorchMLDataset":
-        """Return a dataset that can be used for Torch distributed training.
+    def to_torch(self, **todo) -> "torch.utils.data.IterableDataset":
+        """Return a Torch data iterator over this dataset.
+
+        Note that you probably want to call ``.split()`` on this dataset if
+        there are to be multiple Torch workers consuming the data.

         Time complexity: O(1)

         Returns:
-            A TorchMLDataset.
+            A torch IterableDataset.
         """
         raise NotImplementedError  # P1

-    def to_tf(self, **todo) -> "ray.util.sgd.tf.TFMLDataset":
-        """Return a dataset that can be used for TF distributed training.
+    def to_tf(self, **todo) -> "tf.data.Dataset":
+        """Return a TF data iterator over this dataset.
+
+        Note that you probably want to call ``.split()`` on this dataset if
+        there are to be multiple TensorFlow workers consuming the data.

         Time complexity: O(1)

         Returns:
-            A TFMLDataset.
+            A tf.data.Dataset.
         """
         raise NotImplementedError  # P1

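Both converters are still stubs (marked P1), but the new signatures pin down the intended shape: a per-worker iterator rather than an SGD-specific wrapper. A hedged sketch of the kind of adapter to_torch() could eventually return, assuming the caller already has an iterable of rows (e.g. one shard produced by the ``.split()`` the docstring mentions):

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class RowIterable(IterableDataset):
        """Hypothetical adapter: turns any iterable of numeric rows into a torch dataset."""

        def __init__(self, rows):
            self._rows = rows

        def __iter__(self):
            for row in self._rows:
                yield torch.as_tensor(row)

    # One such iterable per Torch worker, in the spirit of the .split() note above.
    loader = DataLoader(RowIterable(range(100)), batch_size=10)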
@@ -855,7 +866,7 @@ class Dataset(Generic[T]):
                     "Dataset.to_dask() must be used with Dask-on-Ray, please "
                     "set the Dask scheduler to ray_dask_get (located in "
                     "ray.util.dask).")
-            return block._table.to_pandas()
+            return block.to_pandas()

         # TODO(Clark): Give Dask a Pandas-esque schema via the Pyarrow schema,
         # once that's implemented.
@@ -895,7 +906,7 @@ class Dataset(Generic[T]):

         @ray.remote
         def block_to_df(block: ArrowBlock):
-            return block._table.to_pandas()
+            return block.to_pandas()

         return [block_to_df.remote(block) for block in self._blocks]


View file

@@ -148,6 +148,9 @@ class ArrowBlock(Block):
     def to_pandas(self) -> "pandas.DataFrame":
         return self._table.to_pandas()

+    def to_arrow_table(self) -> "pyarrow.Table":
+        return self._table
+
     def num_rows(self) -> int:
         return self._table.num_rows


View file

@ -12,14 +12,32 @@ T = TypeVar("T")
class BlockBuilder(Generic[T]): class BlockBuilder(Generic[T]):
"""A builder class for blocks."""
def add(self, item: T) -> None: def add(self, item: T) -> None:
"""Append a single row to the block being built."""
raise NotImplementedError
def add_block(self, block: "Block[T]") -> None:
"""Append an entire block to the block being built."""
raise NotImplementedError raise NotImplementedError
def build(self) -> "Block[T]": def build(self) -> "Block[T]":
"""Build the block."""
raise NotImplementedError raise NotImplementedError
class BlockMetadata: class BlockMetadata:
"""Metadata about the block.
Attributes:
num_rows: The number of rows contained in this block, or None.
size_bytes: The approximate size in bytes of this block, or None.
schema: The pyarrow schema or types of the block elements, or None.
input_files: The list of file paths used to generate this block, or
the empty list if indeterminate.
"""
def __init__(self, *, num_rows: Optional[int], size_bytes: Optional[int], def __init__(self, *, num_rows: Optional[int], size_bytes: Optional[int],
schema: Union[type, "pyarrow.lib.Schema"], schema: Union[type, "pyarrow.lib.Schema"],
input_files: List[str]): input_files: List[str]):
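The new docstring documents the keyword-only constructor kept as context at the end of the hunk above. A small sketch constructing metadata by hand (the values are made up; the import path is the one used in compute.py below):

    import pyarrow as pa

    from ray.experimental.data.impl.block import BlockMetadata

    meta = BlockMetadata(
        num_rows=100,
        size_bytes=800,
        schema=pa.schema([("value", pa.int64())]),
        input_files=[])  # empty list when the block did not come from files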
@@ -32,25 +50,51 @@ class BlockMetadata:


 class Block(Generic[T]):
+    """Represents a batch of rows to be stored in the Ray object store.
+
+    There are two types of blocks: ``SimpleBlock``, which is backed by a plain
+    Python list, and ``ArrowBlock``, which is backed by a ``pyarrow.Table``.
+    """
+
     def num_rows(self) -> int:
+        """Return the number of rows contained in this block."""
         raise NotImplementedError

     def iter_rows(self) -> Iterator[T]:
+        """Iterate over the rows of this block."""
         raise NotImplementedError

     def slice(self, start: int, end: int, copy: bool) -> "Block[T]":
+        """Return a slice of this block.
+
+        Args:
+            start: The starting index of the slice.
+            end: The ending index of the slice.
+            copy: Whether to perform a data copy for the slice.
+
+        Returns:
+            The sliced block result.
+        """
         raise NotImplementedError

     def to_pandas(self) -> "pandas.DataFrame":
+        """Convert this block into a Pandas dataframe."""
+        raise NotImplementedError
+
+    def to_arrow_table(self) -> "pyarrow.Table":
+        """Convert this block into an Arrow table."""
         raise NotImplementedError

     def size_bytes(self) -> int:
+        """Return the approximate size in bytes of this block."""
         raise NotImplementedError

-    def schema(self) -> Any:
+    def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
+        """Return the Python type or pyarrow schema of this block."""
         raise NotImplementedError

     def get_metadata(self, input_files: List[str]) -> BlockMetadata:
+        """Create a metadata object from this block."""
         return BlockMetadata(
             num_rows=self.num_rows(),
             size_bytes=self.size_bytes(),
@@ -59,6 +103,7 @@ class Block(Generic[T]):

     @staticmethod
     def builder() -> BlockBuilder[T]:
+        """Create a builder for this block type."""
         raise NotImplementedError


@@ -96,6 +141,10 @@ class SimpleBlock(Block):
         import pandas
         return pandas.DataFrame(self._items)

+    def to_arrow_table(self) -> "pyarrow.Table":
+        import pyarrow
+        return pyarrow.Table.from_pandas(self.to_pandas())
+
     def size_bytes(self) -> int:
         return sys.getsizeof(self._items)

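The two implementations now converge on the same type: ArrowBlock returns its backing table directly, while SimpleBlock converts via pandas. The conversion SimpleBlock relies on, shown in isolation:

    import pandas as pd
    import pyarrow as pa

    items = [{"value": i} for i in range(5)]           # what a SimpleBlock holds
    table = pa.Table.from_pandas(pd.DataFrame(items))  # the to_arrow_table() path
    assert table.num_rows == 5
    assert "value" in table.column_names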

View file

@@ -1,4 +1,4 @@
-from typing import TypeVar, Iterable, Any
+from typing import TypeVar, Iterable, Any, Union

 import ray
 from ray.experimental.data.impl.block import Block, BlockMetadata, ObjectRef
@ -9,13 +9,13 @@ T = TypeVar("T")
U = TypeVar("U") U = TypeVar("U")
class ComputePool: class ComputeStrategy:
def apply(self, fn: Any, def apply(self, fn: Any,
blocks: Iterable[Block[T]]) -> Iterable[ObjectRef[Block]]: blocks: Iterable[Block[T]]) -> Iterable[ObjectRef[Block]]:
raise NotImplementedError raise NotImplementedError
class TaskPool(ComputePool): class TaskPool(ComputeStrategy):
def apply(self, fn: Any, remote_args: dict, def apply(self, fn: Any, remote_args: dict,
blocks: BlockList[Any]) -> BlockList[Any]: blocks: BlockList[Any]) -> BlockList[Any]:
map_bar = ProgressBar("Map Progress", total=len(blocks)) map_bar = ProgressBar("Map Progress", total=len(blocks))
@@ -44,7 +44,7 @@ class TaskPool(ComputePool):
         return BlockList(list(new_blocks), list(new_metadata))


-class ActorPool(ComputePool):
+class ActorPool(ComputeStrategy):
     def apply(self, fn: Any, remote_args: dict,
               blocks: Iterable[Block[T]]) -> Iterable[ObjectRef[Block]]:
@@ -113,10 +113,12 @@ class ActorPool(ComputePool):
         return BlockList(blocks_out, new_metadata)


-def get_compute(compute_spec: str) -> ComputePool:
-    if compute_spec == "tasks":
+def get_compute(compute_spec: Union[str, ComputeStrategy]) -> ComputeStrategy:
+    if not compute_spec or compute_spec == "tasks":
         return TaskPool()
     elif compute_spec == "actors":
         return ActorPool()
+    elif isinstance(compute_spec, ComputeStrategy):
+        return compute_spec
     else:
         raise ValueError("compute must be one of [`tasks`, `actors`]")

View file

@@ -308,7 +308,7 @@ def read_parquet(paths: Union[str, List[str]],
     @ray.remote
     def gen_read(pieces: List["pyarrow._dataset.ParquetFileFragment"]):
         import pyarrow
-        print("Reading {} parquet pieces".format(len(pieces)))
+        logger.debug("Reading {} parquet pieces".format(len(pieces)))
         tables = [piece.to_table() for piece in pieces]
         if len(tables) > 1:
             table = pyarrow.concat_tables(tables)
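With the print() gone, the per-task piece count only appears when debug logging is enabled, e.g. something along these lines for local experimentation (the logger name is an assumption; the module presumably logs under its own __name__):

    import logging

    # Assumed logger name; adjust to the actual module's logger.
    logging.getLogger("ray.experimental.data").setLevel(logging.DEBUG)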

View file

@@ -229,6 +229,15 @@ def test_parquet_write(ray_start_regular_shared, tmp_path):
     assert df.equals(dfds)


+def test_convert_to_pyarrow(ray_start_regular_shared, tmp_path):
+    ds = ray.experimental.data.range(100)
+    assert ds.to_dask().sum().compute()[0] == 4950
+    path = os.path.join(tmp_path, "test_parquet_dir")
+    os.mkdir(path)
+    ds.write_parquet(path)
+    assert ray.experimental.data.read_parquet(path).count() == 100
+
+
 def test_pyarrow(ray_start_regular_shared):
     ds = ray.experimental.data.range_arrow(5)
     assert ds.map(lambda x: {"b": x["value"] + 2}).take() == \