[dataset] Add datasource API (#16826)

Eric Liang 2021-07-01 23:44:30 -07:00 committed by GitHub
parent f9daf7fa2c
commit e77a964640
5 changed files with 374 additions and 75 deletions


@@ -1,17 +1,23 @@
from ray.experimental.data.read_api import from_items, range, range_arrow, \
read_parquet, read_json, read_csv, read_binary_files, from_dask, \
from_modin, from_pandas, from_spark
from_modin, from_pandas, from_spark, read_datasource
from ray.experimental.data.datasource import Datasource, ReadTask, WriteTask
__all__ = [
"from_items",
"range",
"range_arrow",
"read_parquet",
"read_json",
"read_csv",
"read_binary_files",
"Datasource",
"ReadTask",
"WriteTask",
"from_dask",
"from_items",
"from_mars",
"from_modin",
"from_pandas",
"from_spark",
"range",
"range_arrow",
"read_binary_files",
"read_csv",
"read_datasource",
"read_json",
"read_parquet",
]
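A minimal usage sketch of the newly exported surface (assumes only the names in ``__all__`` above plus the example ``RangeDatasource`` defined in the new datasource module below):

# Sketch: read through the new datasource API using the bundled example
# RangeDatasource; n and use_arrow are forwarded to its prepare_read().
import ray
from ray.experimental.data import read_datasource
from ray.experimental.data.datasource import RangeDatasource

ray.init()
ds = read_datasource(RangeDatasource(), parallelism=4, n=10, use_arrow=False)
assert sorted(ds.take()) == list(range(10))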


@@ -7,6 +7,7 @@ import os
if TYPE_CHECKING:
import pyarrow
import pandas
import mars
import modin
import dask
import pyspark
@@ -14,9 +15,12 @@ if TYPE_CHECKING:
import collections
import itertools
import ray
import numpy as np
import ray
from ray.experimental.data.datasource import Datasource, WriteTask
from ray.experimental.data.impl.compute import get_compute
from ray.experimental.data.impl.progress_bar import ProgressBar
from ray.experimental.data.impl.shuffle import simple_shuffle
from ray.experimental.data.impl.block import ObjectRef, Block, SimpleBlock, \
BlockMetadata
@@ -42,7 +46,8 @@ class Dataset(Generic[T]):
Since Datasets are just lists of Ray object refs, they can be passed
between Ray tasks and actors just like any other object. Datasets support
conversion to/from several more featureful dataframe libraries
(e.g., Spark, Dask), and are also compatible with TensorFlow / PyTorch.
(e.g., Spark, Dask, Modin, MARS), and are also compatible with distributed
TensorFlow / PyTorch.
Dataset supports parallel transformations such as .map(), .map_batches(),
and simple repartition, but not yet aggregations or joins.
@@ -100,7 +105,17 @@ class Dataset(Generic[T]):
# Transform batches in parallel.
>>> ds.map_batches(lambda batch: [v * 2 for v in batch])
# Transform batches in parallel on GPUs.
# Define a batch transform function that persists state across
# function invocations for efficiency with compute="actors".
>>> model = None
>>> def batch_infer_fn(batch):
... global model
... if model is None:
... model = init_model()
... return model(batch)
# Apply the transform in parallel on GPUs. Since compute="actors",
# the transform will be applied on an autoscaling pool of Ray
# actors, each allocated 1 GPU by Ray.
>>> ds.map_batches(
... batch_infer_fn,
... batch_size=256, compute="actors", num_gpus=1)
@@ -112,7 +127,11 @@ class Dataset(Generic[T]):
batch_size: Request a specific batch size, or leave unspecified
to use entire blocks as batches.
compute: The compute strategy, either "tasks" to use Ray tasks,
or "actors" to use an autoscaling Ray actor pool.
or "actors" to use an autoscaling Ray actor pool. When using
actors, state can be preserved across function invocations
in Python global variables. This can be useful for one-time
setups, e.g., initializing a model once and re-using it across
many function applications.
batch_format: Specify "pandas" to select ``pandas.DataFrame`` as
the batch format, or "pyarrow" to select ``pyarrow.Table``.
ray_remote_args: Additional resource requirements to request from
@@ -710,6 +729,38 @@ class Dataset(Generic[T]):
# Block until writing is done.
ray.get(refs)
def write_datasource(self, datasource: Datasource[T],
**write_args) -> None:
"""Write the dataset to a custom datasource.
Examples:
>>> ds.write_datasource(CustomDatasourceImpl(...))
Time complexity: O(dataset size / parallelism)
Args:
datasource: The datasource to write to.
write_args: Additional write args to pass to the datasource.
"""
write_tasks = datasource.prepare_write(self._blocks, **write_args)
progress = ProgressBar("Write Progress", len(write_tasks))
@ray.remote
def remote_write(task: WriteTask) -> Any:
return task()
write_task_outputs = [remote_write.remote(w) for w in write_tasks]
try:
progress.block_until_complete(write_task_outputs)
datasource.on_write_complete(write_tasks,
ray.get(write_task_outputs))
except Exception as e:
datasource.on_write_failed(write_tasks, e)
raise
finally:
progress.close()
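For comparison with ``DummyOutputDatasource`` further down, a hedged sketch of a writable datasource (``RowCountDatasource`` and its row-counting sink are hypothetical, not part of this commit) that follows the prepare_write / WriteTask / on_write_complete flow implemented above:

import ray
import ray.experimental.data
from ray.experimental.data.datasource import Datasource, WriteTask

class RowCountDatasource(Datasource[int]):
    """Hypothetical sink: each write task reports its block's row count."""

    def __init__(self):
        self.total_rows = 0

    def prepare_write(self, blocks, **write_args):
        @ray.remote
        def count_rows(block) -> int:
            # Ray resolves the block ObjectRef before this task runs.
            return block.num_rows()

        return [
            WriteTask(lambda b=b: ray.get(count_rows.remote(b)))
            for b in blocks
        ]

    def on_write_complete(self, write_tasks, write_task_outputs) -> None:
        # Called on the driver by write_datasource(); "commit" = sum counts.
        self.total_rows = sum(write_task_outputs)

ray.init()
sink = RowCountDatasource()
ray.experimental.data.range(100, parallelism=4).write_datasource(sink)
assert sink.total_rows == 100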
def iter_rows(self, prefetch_blocks: int = 0) -> Iterator[T]:
"""Return a local row iterator over the dataset.
@@ -811,6 +862,16 @@ class Dataset(Generic[T]):
ddf = dd.from_delayed([block_to_df(block) for block in self._blocks])
return ddf
def to_mars(self) -> "mars.DataFrame":
"""Convert this dataset into a MARS dataframe.
Time complexity: O(1)
Returns:
A MARS dataframe created from this dataset.
"""
raise NotImplementedError # P1
def to_modin(self) -> "modin.DataFrame":
"""Convert this dataset into a Modin dataframe.


@@ -0,0 +1,219 @@
import builtins
from typing import Any, Generic, List, Callable, Union, TypeVar
import numpy as np
import ray
from ray.experimental.data.impl.arrow_block import ArrowRow, ArrowBlock
from ray.experimental.data.impl.block import Block, SimpleBlock
from ray.experimental.data.impl.block_list import BlockList, BlockMetadata
T = TypeVar("T")
WriteResult = Any
class Datasource(Generic[T]):
"""Interface for defining a custom ``ray.data.Dataset`` datasource.
To read a datasource into a dataset, use ``ray.data.read_datasource()``.
To write to a writable datasource, use ``Dataset.write_datasource()``.
See ``RangeDatasource`` and ``DummyOutputDatasource`` below for examples
of how to implement readable and writable datasources.
"""
def prepare_read(self, parallelism: int,
**read_args) -> List["ReadTask[T]"]:
"""Return the list of tasks needed to perform a read.
Args:
parallelism: The requested read parallelism. The number of read
tasks should be as close to this value as possible.
read_args: Additional kwargs to pass to the datasource impl.
Returns:
A list of read tasks that can be executed to read blocks from the
datasource in parallel.
"""
raise NotImplementedError
def prepare_write(self, blocks: BlockList,
**write_args) -> List["WriteTask[T]"]:
"""Return the list of tasks needed to perform a write.
Args:
blocks: List of data block references and block metadata. It is
recommended that one write task be generated per block.
write_args: Additional kwargs to pass to the datasource impl.
Returns:
A list of write tasks that can be executed to write blocks to the
datasource in parallel.
"""
raise NotImplementedError
def on_write_complete(self, write_tasks: List["WriteTask[T]"],
write_task_outputs: List[WriteResult],
**kwargs) -> None:
"""Callback for when a write job completes.
This can be used to "commit" a write output. This method must
succeed prior to ``write_datasource()`` returning to the user. If this
method fails, then ``on_write_failed()`` will be called.
Args:
write_tasks: The list of the original write tasks.
write_task_outputs: The list of write task outputs.
kwargs: Forward-compatibility placeholder.
"""
pass
def on_write_failed(self, write_tasks: List["WriteTask[T]"],
error: Exception, **kwargs) -> None:
"""Callback for when a write job fails.
This is called on a best-effort basis on write failures.
Args:
write_tasks: The list of the original write tasks.
error: The first error encountered.
kwargs: Forward-compatibility placeholder.
"""
pass
class ReadTask(Callable[[], Block[T]]):
"""A function used to read a block of a dataset.
Read tasks are generated by ``datasource.prepare_read()``, and return
a ``ray.data.Block`` when called. Metadata about the read operation can
be retrieved via ``get_metadata()`` prior to executing the read.
Ray will execute read tasks in remote functions to parallelize execution.
"""
def __init__(self, read_fn: Callable[[], Block[T]],
metadata: BlockMetadata):
self._metadata = metadata
self._read_fn = read_fn
def get_metadata(self) -> BlockMetadata:
return self._metadata
def __call__(self) -> Block[T]:
return self._read_fn()
class WriteTask(Callable[[], WriteResult]):
"""A function used to write a chunk of a dataset.
Write tasks are generated by ``datasource.prepare_write()``, and return
a datasource-specific output that is passed to ``on_write_complete()``
on write completion.
Ray will execute write tasks in remote functions to parallelize execution.
"""
def __init__(self, write_fn: Callable[[], WriteResult]):
self._write_fn = write_fn
def __call__(self) -> WriteResult:
return self._write_fn()
class RangeDatasource(Datasource[Union[ArrowRow, int]]):
"""An example datasource that generates ranges of numbers from [0..n).
Examples:
>>> source = RangeDatasource()
>>> ray.data.read_datasource(source, n=10, use_arrow=False).take()
... [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
def prepare_read(self, parallelism: int, n: int,
use_arrow: bool) -> List[ReadTask]:
read_tasks: List[ReadTask] = []
block_size = max(1, n // parallelism)
# Example of a read task. In a real datasource, this would pull data
# from an external system instead of generating dummy data.
def make_block(start: int, count: int) -> Block[Union[ArrowRow, int]]:
if use_arrow:
return ArrowBlock(
pyarrow.Table.from_arrays(
[np.arange(start, start + count)], names=["value"]))
else:
return SimpleBlock(list(builtins.range(start, start + count)))
i = 0
while i < n:
count = min(block_size, n - i)
if use_arrow:
import pyarrow
schema = pyarrow.Table.from_pydict({"value": [0]}).schema
else:
schema = int
read_tasks.append(
ReadTask(
lambda i=i, count=count: make_block(i, count),
BlockMetadata(
num_rows=count,
size_bytes=8 * count,
schema=schema,
input_files=None)))
i += block_size
return read_tasks
class DummyOutputDatasource(Datasource[Union[ArrowRow, int]]):
"""An example implementation of a writable datasource for testing.
Examples:
>>> output = DummyOutputDatasource()
>>> ray.data.range(10).write_datasource(output)
>>> assert output.num_ok == 1
"""
def __init__(self):
# Set up a dummy actor to send the data. In a real datasource, write
# tasks would send data to an external system instead of a Ray actor.
@ray.remote
class DataSink:
def __init__(self):
self.rows_written = 0
self.enabled = True
def write(self, block: Block[T]) -> str:
if not self.enabled:
raise ValueError("disabled")
self.rows_written += block.num_rows()
return "ok"
def get_rows_written(self):
return self.rows_written
def set_enabled(self, enabled):
self.enabled = enabled
self.data_sink = DataSink.remote()
self.num_ok = 0
self.num_failed = 0
def prepare_write(self, blocks: BlockList,
**write_args) -> List["WriteTask[T]"]:
tasks = []
for b in blocks:
tasks.append(
WriteTask(lambda b=b: ray.get(self.data_sink.write.remote(b))))
return tasks
def on_write_complete(self, write_tasks: List["WriteTask[T]"],
write_task_outputs: List[WriteResult]) -> None:
assert len(write_task_outputs) == len(write_tasks)
assert all(w == "ok" for w in write_task_outputs), write_task_outputs
self.num_ok += 1
def on_write_failed(self, write_tasks: List["WriteTask[T]"],
error: Exception) -> None:
self.num_failed += 1


@@ -1,17 +1,22 @@
import logging
import functools
import builtins
from typing import List, Any, Union, Optional, Tuple, Callable, TYPE_CHECKING
import inspect
from typing import List, Any, Union, Optional, Tuple, Callable, TypeVar, \
TYPE_CHECKING
if TYPE_CHECKING:
import pyarrow
import pandas
import dask
import mars
import modin
import pyspark
import ray
from ray.experimental.data.dataset import Dataset
from ray.experimental.data.datasource import Datasource, RangeDatasource, \
ReadTask
from ray.experimental.data.impl import reader as _reader
from ray.experimental.data.impl.arrow_block import ArrowBlock, ArrowRow
from ray.experimental.data.impl.block import ObjectRef, SimpleBlock, Block, \
@@ -19,16 +24,19 @@ from ray.experimental.data.impl.block import ObjectRef, SimpleBlock, Block, \
from ray.experimental.data.impl.block_list import BlockList
from ray.experimental.data.impl.lazy_block_list import LazyBlockList
T = TypeVar("T")
logger = logging.getLogger(__name__)
def autoinit_ray(f):
def autoinit_ray(f: Callable) -> Callable:
@functools.wraps(f)
def wrapped(*a, **kw):
if not ray.is_initialized():
ray.client().connect()
return f(*a, **kw)
setattr(wrapped, "__signature__", inspect.signature(f))
return wrapped
@@ -170,34 +178,8 @@ def range(n: int, parallelism: int = 200) -> Dataset[int]:
Returns:
Dataset holding the integers.
"""
calls: List[Callable[[], ObjectRef[Block]]] = []
metadata: List[BlockMetadata] = []
block_size = max(1, n // parallelism)
@ray.remote
def gen_block(start: int, count: int) -> SimpleBlock:
builder = SimpleBlock.builder()
for value in builtins.range(start, start + count):
builder.add(value)
return builder.build()
i = 0
while i < n:
def make_call(start: int, count: int) -> ObjectRef[Block]:
return lambda: gen_block.remote(start, count)
count = min(block_size, n - i)
calls.append(make_call(i, count))
metadata.append(
BlockMetadata(
num_rows=count,
size_bytes=8 * count,
schema=int,
input_files=None))
i += block_size
return Dataset(LazyBlockList(calls, metadata))
return read_datasource(
RangeDatasource(), parallelism=parallelism, n=n, use_arrow=False)
@autoinit_ray
@@ -217,36 +199,37 @@ def range_arrow(n: int, parallelism: int = 200) -> Dataset[ArrowRow]:
Returns:
Dataset holding the integers as Arrow records.
"""
import pyarrow
return read_datasource(
RangeDatasource(), parallelism=parallelism, n=n, use_arrow=True)
calls: List[Callable[[], ObjectRef[Block]]] = []
metadata: List[BlockMetadata] = []
block_size = max(1, n // parallelism)
i = 0
@autoinit_ray
def read_datasource(datasource: Datasource[T],
parallelism: int = 200,
**read_args) -> Dataset[T]:
"""Read a dataset from a custom data source.
Args:
datasource: The datasource to read data from.
parallelism: The requested parallelism of the read.
read_args: Additional kwargs to pass to the datasource impl.
Returns:
Dataset holding the data read from the datasource.
"""
read_tasks = datasource.prepare_read(parallelism, **read_args)
@ray.remote
def gen_block(start: int, count: int) -> "ArrowBlock":
return ArrowBlock(
pyarrow.Table.from_pydict({
"value": list(builtins.range(start, start + count))
}))
def remote_read(task: ReadTask) -> Block[T]:
return task()
while i < n:
calls: List[Callable[[], ObjectRef[Block[T]]]] = []
metadata: List[BlockMetadata] = []
def make_call(start: int, count: int) -> ObjectRef[Block]:
return lambda: gen_block.remote(start, count)
start = block_size * i
count = min(block_size, n - i)
calls.append(make_call(start, count))
schema = pyarrow.Table.from_pydict({"value": [0]}).schema
metadata.append(
BlockMetadata(
num_rows=count,
size_bytes=8 * count,
schema=schema,
input_files=None))
i += block_size
for task in read_tasks:
calls.append(lambda task=task: remote_read.remote(task))
metadata.append(task.get_metadata())
return Dataset(LazyBlockList(calls, metadata))
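As a counterpart to ``RangeDatasource``, a hedged sketch of a read-only datasource that serves an in-memory Python list (``ListDatasource`` and its ``items`` read arg are hypothetical); it exercises the prepare_read / ReadTask path that ``read_datasource()`` drives above:

import ray
from ray.experimental.data import read_datasource
from ray.experimental.data.datasource import Datasource, ReadTask
from ray.experimental.data.impl.block import SimpleBlock
from ray.experimental.data.impl.block_list import BlockMetadata

class ListDatasource(Datasource):
    """Hypothetical source: splits a Python list into SimpleBlock chunks."""

    def prepare_read(self, parallelism, items):
        read_tasks = []
        block_size = max(1, len(items) // parallelism)
        i = 0
        while i < len(items):
            chunk = items[i:i + block_size]
            read_tasks.append(
                ReadTask(
                    lambda chunk=chunk: SimpleBlock(list(chunk)),
                    BlockMetadata(
                        num_rows=len(chunk),
                        size_bytes=sum(len(str(x)) for x in chunk),  # rough
                        schema=type(chunk[0]),
                        input_files=None)))
            i += block_size
        return read_tasks

ray.init()
ds = read_datasource(ListDatasource(), parallelism=2, items=["a", "b", "c"])
assert ds.take() == ["a", "b", "c"]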
@@ -305,12 +288,7 @@ def read_parquet(paths: Union[str, List[str]],
calls: List[Callable[[], ObjectRef[Block]]] = []
metadata: List[BlockMetadata] = []
for pieces in nonempty_tasks:
def make_call(
pieces: List[pq.ParquetDatasetPiece]) -> ObjectRef[Block]:
return lambda: gen_read.remote(pieces)
calls.append(make_call(pieces))
calls.append(lambda pieces=pieces: gen_read.remote(pieces))
piece_metadata = [p.get_metadata() for p in pieces]
metadata.append(
BlockMetadata(
@@ -474,7 +452,9 @@ def read_binary_files(
filesystem=filesystem))
def from_dask(df: "dask.DataFrame") -> Dataset[ArrowRow]:
@autoinit_ray
def from_dask(df: "dask.DataFrame",
parallelism: int = 200) -> Dataset[ArrowRow]:
"""Create a dataset from a Dask DataFrame.
Args:
@@ -492,6 +472,21 @@ def from_dask(df: "dask.DataFrame") -> Dataset[ArrowRow]:
[next(iter(part.dask.values())) for part in persisted_partitions])
@autoinit_ray
def from_mars(df: "mars.DataFrame",
parallelism: int = 200) -> Dataset[ArrowRow]:
"""Create a dataset from a MARS dataframe.
Args:
df: A MARS dataframe, which must be executed by MARS-on-Ray.
Returns:
Dataset holding Arrow records read from the dataframe.
"""
raise NotImplementedError # P1
@autoinit_ray
def from_modin(df: "modin.DataFrame",
parallelism: int = 200) -> Dataset[ArrowRow]:
"""Create a dataset from a Modin dataframe.
@@ -506,6 +501,7 @@ def from_modin(df: "modin.DataFrame",
raise NotImplementedError # P1
@autoinit_ray
def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]],
parallelism: int = 200) -> Dataset[ArrowRow]:
"""Create a dataset from a set of Pandas dataframes.
@@ -529,6 +525,7 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]],
return Dataset(BlockList(blocks, ray.get(list(metadata))))
@autoinit_ray
def from_spark(df: "pyspark.sql.DataFrame",
parallelism: int = 200) -> Dataset[ArrowRow]:
"""Create a dataset from a Spark dataframe.


@@ -10,8 +10,8 @@ import pytest
import ray
from ray.util.dask import ray_dask_get
from ray.tests.conftest import * # noqa
from ray.experimental.data.datasource import DummyOutputDatasource
import ray.experimental.data.tests.util as util
@@ -29,6 +29,22 @@ def test_basic(ray_start_regular_shared):
assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4]
def test_write_datasource(ray_start_regular_shared):
output = DummyOutputDatasource()
ds = ray.experimental.data.range(10, parallelism=2)
ds.write_datasource(output)
assert output.num_ok == 1
assert output.num_failed == 0
assert ray.get(output.data_sink.get_rows_written.remote()) == 10
ray.get(output.data_sink.set_enabled.remote(False))
with pytest.raises(ValueError):
ds.write_datasource(output)
assert output.num_ok == 1
assert output.num_failed == 1
assert ray.get(output.data_sink.get_rows_written.remote()) == 10
def test_empty_dataset(ray_start_regular_shared):
ds = ray.experimental.data.range(0)
assert ds.count() == 0