[dataset] Add datasource API (#16826)
commit e77a964640 (parent f9daf7fa2c)
5 changed files with 374 additions and 75 deletions
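At a glance, this change routes dataset reads and writes through a pluggable
``Datasource`` interface, exposed via ``read_datasource()`` and
``Dataset.write_datasource()``. A minimal usage sketch (not part of the diff),
assuming the experimental module layout introduced here and the example
``RangeDatasource`` added below:

import ray
from ray.experimental.data import read_datasource
from ray.experimental.data.datasource import RangeDatasource

ray.init()

# The datasource splits the request into ReadTasks, which Ray runs in
# parallel remote functions; the result is an ordinary lazy Dataset.
ds = read_datasource(RangeDatasource(), parallelism=2, n=10, use_arrow=False)
assert sorted(ds.iter_rows()) == list(range(10))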
@@ -1,17 +1,23 @@
from ray.experimental.data.read_api import from_items, range, range_arrow, \
    read_parquet, read_json, read_csv, read_binary_files, from_dask, \
    from_modin, from_pandas, from_spark
    from_modin, from_pandas, from_spark, read_datasource
from ray.experimental.data.datasource import Datasource, ReadTask, WriteTask

__all__ = [
    "from_items",
    "range",
    "range_arrow",
    "read_parquet",
    "read_json",
    "read_csv",
    "read_binary_files",
    "Datasource",
    "ReadTask",
    "WriteTask",
    "from_dask",
    "from_items",
    "from_mars",
    "from_modin",
    "from_pandas",
    "from_spark",
    "range",
    "range_arrow",
    "read_binary_files",
    "read_csv",
    "read_datasource",
    "read_json",
    "read_parquet",
]
@@ -7,6 +7,7 @@ import os
if TYPE_CHECKING:
    import pyarrow
    import pandas
    import mars
    import modin
    import dask
    import pyspark

@@ -14,9 +15,12 @@ if TYPE_CHECKING:

import collections
import itertools
import ray
import numpy as np

import ray
from ray.experimental.data.datasource import Datasource, WriteTask
from ray.experimental.data.impl.compute import get_compute
from ray.experimental.data.impl.progress_bar import ProgressBar
from ray.experimental.data.impl.shuffle import simple_shuffle
from ray.experimental.data.impl.block import ObjectRef, Block, SimpleBlock, \
    BlockMetadata

@@ -42,7 +46,8 @@ class Dataset(Generic[T]):
    Since Datasets are just lists of Ray object refs, they can be passed
    between Ray tasks and actors just like any other object. Datasets support
    conversion to/from several more featureful dataframe libraries
    (e.g., Spark, Dask), and are also compatible with TensorFlow / PyTorch.
    (e.g., Spark, Dask, Modin, MARS), and are also compatible with distributed
    TensorFlow / PyTorch.

    Dataset supports parallel transformations such as .map(), .map_batches(),
    and simple repartition, but currently not aggregations and joins.

@@ -100,7 +105,17 @@ class Dataset(Generic[T]):
            # Transform batches in parallel.
            >>> ds.map_batches(lambda batch: [v * 2 for v in batch])

            # Transform batches in parallel on GPUs.
            # Define a batch transform function that persists state across
            # function invocations for efficiency with compute="actors".
            >>> def batch_infer_fn(batch):
            ...     global model
            ...     if model is None:
            ...         model = init_model()
            ...     return model(batch)

            # Apply the transform in parallel on GPUs. Since compute="actors",
            # the transform will be applied on an autoscaling pool of Ray
            # actors, each allocated 1 GPU by Ray.
            >>> ds.map_batches(
            ...     batch_infer_fn,
            ...     batch_size=256, compute="actors", num_gpus=1)

@@ -112,7 +127,11 @@ class Dataset(Generic[T]):
            batch_size: Request a specific batch size, or leave unspecified
                to use entire blocks as batches.
            compute: The compute strategy, either "tasks" to use Ray tasks,
                or "actors" to use an autoscaling Ray actor pool.
                or "actors" to use an autoscaling Ray actor pool. When using
                actors, state can be preserved across function invocations
                in Python global variables. This can be useful for one-time
                setups, e.g., initializing a model once and re-using it across
                many function applications.
            batch_format: Specify "pandas" to select ``pandas.DataFrame`` as
                the batch format, or "pyarrow" to select ``pyarrow.Table``.
            ray_remote_args: Additional resource requirements to request from
@@ -710,6 +729,38 @@ class Dataset(Generic[T]):
        # Block until writing is done.
        ray.get(refs)

    def write_datasource(self, datasource: Datasource[T],
                         **write_args) -> None:
        """Write the dataset to a custom datasource.

        Examples:
            >>> ds.write_datasource(CustomDatasourceImpl(...))

        Time complexity: O(dataset size / parallelism)

        Args:
            datasource: The datasource to write to.
            write_args: Additional write args to pass to the datasource.
        """

        write_tasks = datasource.prepare_write(self._blocks, **write_args)
        progress = ProgressBar("Write Progress", len(write_tasks))

        @ray.remote
        def remote_write(task: WriteTask) -> Any:
            return task()

        write_task_outputs = [remote_write.remote(w) for w in write_tasks]
        try:
            progress.block_until_complete(write_task_outputs)
            datasource.on_write_complete(write_tasks,
                                         ray.get(write_task_outputs))
        except Exception as e:
            datasource.on_write_failed(write_tasks, e)
            raise
        finally:
            progress.close()

    def iter_rows(self, prefetch_blocks: int = 0) -> Iterator[T]:
        """Return a local row iterator over the dataset.

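Caller-side contract of the method above, as a condensed sketch: every
``WriteTask`` must succeed before ``on_write_complete()`` "commits" the job,
and a failing task triggers ``on_write_failed()`` and re-raises to the caller.
This mirrors the new ``test_write_datasource`` test later in this diff and
reuses the ``DummyOutputDatasource`` example from the new ``datasource.py``:

import ray
import ray.experimental.data
from ray.experimental.data.datasource import DummyOutputDatasource

ray.init()
ds = ray.experimental.data.range(10, parallelism=2)

sink = DummyOutputDatasource()
ds.write_datasource(sink)      # all WriteTasks return "ok"
assert sink.num_ok == 1        # on_write_complete() ran exactly once

ray.get(sink.data_sink.set_enabled.remote(False))
try:
    ds.write_datasource(sink)  # the sink now raises inside a WriteTask...
except ValueError:
    pass                       # ...and the error propagates to the caller
assert sink.num_failed == 1    # on_write_failed() was invoked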
@@ -811,6 +862,16 @@ class Dataset(Generic[T]):
        ddf = dd.from_delayed([block_to_df(block) for block in self._blocks])
        return ddf

    def to_mars(self) -> "mars.DataFrame":
        """Convert this dataset into a MARS dataframe.

        Time complexity: O(1)

        Returns:
            A MARS dataframe created from this dataset.
        """
        raise NotImplementedError # P1

    def to_modin(self) -> "modin.DataFrame":
        """Convert this dataset into a Modin dataframe.

python/ray/experimental/data/datasource.py (new file, 219 lines)
@@ -0,0 +1,219 @@
import builtins
from typing import Any, Generic, List, Callable, Union, TypeVar

import numpy as np

import ray
from ray.experimental.data.impl.arrow_block import ArrowRow, ArrowBlock
from ray.experimental.data.impl.block import Block, SimpleBlock
from ray.experimental.data.impl.block_list import BlockList, BlockMetadata

T = TypeVar("T")
WriteResult = Any


class Datasource(Generic[T]):
    """Interface for defining a custom ``ray.data.Dataset`` datasource.

    To read a datasource into a dataset, use ``ray.data.read_datasource()``.
    To write to a writable datasource, use ``Dataset.write_datasource()``.

    See ``RangeDatasource`` and ``DummyOutputDatasource`` below for examples
    of how to implement readable and writable datasources.
    """

    def prepare_read(self, parallelism: int,
                     **read_args) -> List["ReadTask[T]"]:
        """Return the list of tasks needed to perform a read.

        Args:
            parallelism: The requested read parallelism. The number of read
                tasks should be as close to this value as possible.
            read_args: Additional kwargs to pass to the datasource impl.

        Returns:
            A list of read tasks that can be executed to read blocks from the
            datasource in parallel.
        """
        raise NotImplementedError

    def prepare_write(self, blocks: BlockList,
                      **write_args) -> List["WriteTask[T]"]:
        """Return the list of tasks needed to perform a write.

        Args:
            blocks: List of data block references and block metadata. It is
                recommended that one write task be generated per block.
            write_args: Additional kwargs to pass to the datasource impl.

        Returns:
            A list of write tasks that can be executed to write blocks to the
            datasource in parallel.
        """
        raise NotImplementedError

    def on_write_complete(self, write_tasks: List["WriteTask[T]"],
                          write_task_outputs: List[WriteResult],
                          **kwargs) -> None:
        """Callback for when a write job completes.

        This can be used to "commit" a write output. This method must
        succeed prior to ``write_datasource()`` returning to the user. If this
        method fails, then ``on_write_failed()`` will be called.

        Args:
            write_tasks: The list of the original write tasks.
            write_task_outputs: The list of write task outputs.
            kwargs: Forward-compatibility placeholder.
        """
        pass

    def on_write_failed(self, write_tasks: List["WriteTask[T]"],
                        error: Exception, **kwargs) -> None:
        """Callback for when a write job fails.

        This is called on a best-effort basis on write failures.

        Args:
            write_tasks: The list of the original write tasks.
            error: The first error encountered.
            kwargs: Forward-compatibility placeholder.
        """
        pass


class ReadTask(Callable[[], Block[T]]):
    """A function used to read a block of a dataset.

    Read tasks are generated by ``datasource.prepare_read()``, and return
    a ``ray.data.Block`` when called. Metadata about the read operation can
    be retrieved via ``get_metadata()`` prior to executing the read.

    Ray will execute read tasks in remote functions to parallelize execution.
    """

    def __init__(self, read_fn: Callable[[], Block[T]],
                 metadata: BlockMetadata):
        self._metadata = metadata
        self._read_fn = read_fn

    def get_metadata(self) -> BlockMetadata:
        return self._metadata

    def __call__(self) -> Block[T]:
        return self._read_fn()


class WriteTask(Callable[[], WriteResult]):
    """A function used to write a chunk of a dataset.

    Write tasks are generated by ``datasource.prepare_write()``, and return
    a datasource-specific output that is passed to ``on_write_complete()``
    on write completion.

    Ray will execute write tasks in remote functions to parallelize execution.
    """

    def __init__(self, write_fn: Callable[[], WriteResult]):
        self._write_fn = write_fn

    def __call__(self) -> WriteResult:
        return self._write_fn()


class RangeDatasource(Datasource[Union[ArrowRow, int]]):
    """An example datasource that generates ranges of numbers from [0..n).

    Examples:
        >>> source = RangeDatasource()
        >>> ray.data.read_datasource(source, n=10).take()
        ... [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """

    def prepare_read(self, parallelism: int, n: int,
                     use_arrow: bool) -> List[ReadTask]:
        read_tasks: List[ReadTask] = []
        block_size = max(1, n // parallelism)

        # Example of a read task. In a real datasource, this would pull data
        # from an external system instead of generating dummy data.
        def make_block(start: int, count: int) -> Block[Union[ArrowRow, int]]:
            if use_arrow:
                return ArrowBlock(
                    pyarrow.Table.from_arrays(
                        [np.arange(start, start + count)], names=["value"]))
            else:
                return SimpleBlock(list(builtins.range(start, start + count)))

        i = 0
        while i < n:
            count = min(block_size, n - i)
            if use_arrow:
                import pyarrow
                schema = pyarrow.Table.from_pydict({"value": [0]}).schema
            else:
                schema = int
            read_tasks.append(
                ReadTask(
                    lambda i=i, count=count: make_block(i, count),
                    BlockMetadata(
                        num_rows=count,
                        size_bytes=8 * count,
                        schema=schema,
                        input_files=None)))
            i += block_size

        return read_tasks


class DummyOutputDatasource(Datasource[Union[ArrowRow, int]]):
    """An example implementation of a writable datasource for testing.

    Examples:
        >>> output = DummyOutputDatasource()
        >>> ray.data.range(10).write_datasource(output)
        >>> assert output.num_ok == 1
    """

    def __init__(self):
        # Setup a dummy actor to send the data. In a real datasource, write
        # tasks would send data to an external system instead of a Ray actor.
        @ray.remote
        class DataSink:
            def __init__(self):
                self.rows_written = 0
                self.enabled = True

            def write(self, block: Block[T]) -> str:
                if not self.enabled:
                    raise ValueError("disabled")
                self.rows_written += block.num_rows()
                return "ok"

            def get_rows_written(self):
                return self.rows_written

            def set_enabled(self, enabled):
                self.enabled = enabled

        self.data_sink = DataSink.remote()
        self.num_ok = 0
        self.num_failed = 0

    def prepare_write(self, blocks: BlockList,
                      **write_args) -> List["WriteTask[T]"]:
        tasks = []
        for b in blocks:
            tasks.append(
                WriteTask(lambda b=b: ray.get(self.data_sink.write.remote(b))))
        return tasks

    def on_write_complete(self, write_tasks: List["WriteTask[T]"],
                          write_task_outputs: List[WriteResult]) -> None:
        assert len(write_task_outputs) == len(write_tasks)
        assert all(w == "ok" for w in write_task_outputs), write_task_outputs
        self.num_ok += 1

    def on_write_failed(self, write_tasks: List["WriteTask[T]"],
                        error: Exception) -> None:
        self.num_failed += 1
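For orientation, here is a hedged sketch (not part of the diff) of what a
user-defined datasource looks like against this interface: a hypothetical
``InMemoryDatasource`` backed by a plain Python list. The constructor
signatures for ``ReadTask``, ``WriteTask``, ``BlockMetadata`` and
``SimpleBlock`` follow the new file above; everything else is illustrative.

from typing import List

import ray
from ray.experimental.data.datasource import Datasource, ReadTask, WriteTask
from ray.experimental.data.impl.block import SimpleBlock
from ray.experimental.data.impl.block_list import BlockList, BlockMetadata


class InMemoryDatasource(Datasource[int]):
    """Hypothetical example: reads from and writes back to a Python list."""

    def __init__(self, items: List[int]):
        self._items = items
        self.rows_written = 0

    def prepare_read(self, parallelism: int, **read_args) -> List[ReadTask]:
        # Split the list into roughly `parallelism` chunks; one ReadTask each.
        chunk = max(1, len(self._items) // parallelism)
        tasks = []
        for start in range(0, len(self._items), chunk):
            part = self._items[start:start + chunk]
            tasks.append(
                ReadTask(
                    lambda part=part: SimpleBlock(part),
                    BlockMetadata(
                        num_rows=len(part),
                        size_bytes=8 * len(part),
                        schema=int,
                        input_files=None)))
        return tasks

    def prepare_write(self, blocks: BlockList,
                      **write_args) -> List["WriteTask[int]"]:
        # One WriteTask per block ref; the ref is resolved inside the task
        # because write tasks execute in remote functions.
        def write_block(ref):
            return ray.get(ref).num_rows()

        return [WriteTask(lambda ref=ref: write_block(ref)) for ref in blocks]

    def on_write_complete(self, write_tasks, write_task_outputs,
                          **kwargs) -> None:
        # "Commit" step: record the result only after every task succeeded.
        self.rows_written = sum(write_task_outputs)

Reading and writing such a source then goes through ``read_datasource()`` and
``Dataset.write_datasource()`` exactly as with the built-in examples.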
@@ -1,17 +1,22 @@
import logging
import functools
import builtins
from typing import List, Any, Union, Optional, Tuple, Callable, TYPE_CHECKING
import inspect
from typing import List, Any, Union, Optional, Tuple, Callable, TypeVar, \
    TYPE_CHECKING

if TYPE_CHECKING:
    import pyarrow
    import pandas
    import dask
    import mars
    import modin
    import pyspark

import ray
from ray.experimental.data.dataset import Dataset
from ray.experimental.data.datasource import Datasource, RangeDatasource, \
    ReadTask
from ray.experimental.data.impl import reader as _reader
from ray.experimental.data.impl.arrow_block import ArrowBlock, ArrowRow
from ray.experimental.data.impl.block import ObjectRef, SimpleBlock, Block, \

@@ -19,16 +24,19 @@ from ray.experimental.data.impl.block import ObjectRef, SimpleBlock, Block, \
from ray.experimental.data.impl.block_list import BlockList
from ray.experimental.data.impl.lazy_block_list import LazyBlockList

T = TypeVar("T")

logger = logging.getLogger(__name__)


def autoinit_ray(f):
def autoinit_ray(f: Callable) -> Callable:
    @functools.wraps(f)
    def wrapped(*a, **kw):
        if not ray.is_initialized():
            ray.client().connect()
        return f(*a, **kw)

    setattr(wrapped, "__signature__", inspect.signature(f))
    return wrapped


@@ -170,34 +178,8 @@ def range(n: int, parallelism: int = 200) -> Dataset[int]:
    Returns:
        Dataset holding the integers.
    """
    calls: List[Callable[[], ObjectRef[Block]]] = []
    metadata: List[BlockMetadata] = []
    block_size = max(1, n // parallelism)

    @ray.remote
    def gen_block(start: int, count: int) -> SimpleBlock:
        builder = SimpleBlock.builder()
        for value in builtins.range(start, start + count):
            builder.add(value)
        return builder.build()

    i = 0
    while i < n:

        def make_call(start: int, count: int) -> ObjectRef[Block]:
            return lambda: gen_block.remote(start, count)

        count = min(block_size, n - i)
        calls.append(make_call(i, count))
        metadata.append(
            BlockMetadata(
                num_rows=count,
                size_bytes=8 * count,
                schema=int,
                input_files=None))
        i += block_size

    return Dataset(LazyBlockList(calls, metadata))
    return read_datasource(
        RangeDatasource(), parallelism=parallelism, n=n, use_arrow=False)


@autoinit_ray

@@ -217,36 +199,37 @@ def range_arrow(n: int, parallelism: int = 200) -> Dataset[ArrowRow]:
    Returns:
        Dataset holding the integers as Arrow records.
    """
    import pyarrow
    return read_datasource(
        RangeDatasource(), parallelism=parallelism, n=n, use_arrow=True)

    calls: List[Callable[[], ObjectRef[Block]]] = []
    metadata: List[BlockMetadata] = []
    block_size = max(1, n // parallelism)
    i = 0

@autoinit_ray
def read_datasource(datasource: Datasource[T],
                    parallelism: int = 200,
                    **read_args) -> Dataset[T]:
    """Read a dataset from a custom data source.

    Args:
        datasource: The datasource to read data from.
        parallelism: The requested parallelism of the read.
        read_args: Additional kwargs to pass to the datasource impl.

    Returns:
        Dataset holding the data read from the datasource.
    """

    read_tasks = datasource.prepare_read(parallelism, **read_args)

    @ray.remote
    def gen_block(start: int, count: int) -> "ArrowBlock":
        return ArrowBlock(
            pyarrow.Table.from_pydict({
                "value": list(builtins.range(start, start + count))
            }))
    def remote_read(task: ReadTask) -> Block[T]:
        return task()

    while i < n:
    calls: List[Callable[[], ObjectRef[Block[T]]]] = []
    metadata: List[BlockMetadata] = []

        def make_call(start: int, count: int) -> ObjectRef[Block]:
            return lambda: gen_block.remote(start, count)

        start = block_size * i
        count = min(block_size, n - i)
        calls.append(make_call(start, count))
        schema = pyarrow.Table.from_pydict({"value": [0]}).schema
        metadata.append(
            BlockMetadata(
                num_rows=count,
                size_bytes=8 * count,
                schema=schema,
                input_files=None))
        i += block_size
    for task in read_tasks:
        calls.append(lambda task=task: remote_read.remote(task))
        metadata.append(task.get_metadata())

    return Dataset(LazyBlockList(calls, metadata))

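Worth noting in the implementation above: ``prepare_read()`` only plans the
work, so block metadata is available before any task executes, which is what
``LazyBlockList`` relies on. A small sketch (not part of the diff) using the
``RangeDatasource`` example:

from ray.experimental.data.datasource import RangeDatasource

tasks = RangeDatasource().prepare_read(parallelism=4, n=10, use_arrow=False)
assert len(tasks) == 5                    # block size 2, so 5 tasks for n=10
assert [t.get_metadata().num_rows for t in tasks] == [2, 2, 2, 2, 2]
block = tasks[0]()                        # executing a task materializes a block
assert block.num_rows() == 2              # the block holding rows [0, 1]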
@@ -305,12 +288,7 @@ def read_parquet(paths: Union[str, List[str]],
    calls: List[Callable[[], ObjectRef[Block]]] = []
    metadata: List[BlockMetadata] = []
    for pieces in nonempty_tasks:

        def make_call(
                pieces: List[pq.ParquetDatasetPiece]) -> ObjectRef[Block]:
            return lambda: gen_read.remote(pieces)

        calls.append(make_call(pieces))
        calls.append(lambda pieces=pieces: gen_read.remote(pieces))
        piece_metadata = [p.get_metadata() for p in pieces]
        metadata.append(
            BlockMetadata(

@@ -474,7 +452,9 @@ def read_binary_files(
            filesystem=filesystem))


def from_dask(df: "dask.DataFrame") -> Dataset[ArrowRow]:
@autoinit_ray
def from_dask(df: "dask.DataFrame",
              parallelism: int = 200) -> Dataset[ArrowRow]:
    """Create a dataset from a Dask DataFrame.

    Args:

@@ -492,6 +472,21 @@ def from_dask(df: "dask.DataFrame") -> Dataset[ArrowRow]:
        [next(iter(part.dask.values())) for part in persisted_partitions])


@autoinit_ray
def from_mars(df: "mars.DataFrame",
              parallelism: int = 200) -> Dataset[ArrowRow]:
    """Create a dataset from a MARS dataframe.

    Args:
        df: A MARS dataframe, which must be executed by MARS-on-Ray.

    Returns:
        Dataset holding Arrow records read from the dataframe.
    """
    raise NotImplementedError # P1


@autoinit_ray
def from_modin(df: "modin.DataFrame",
               parallelism: int = 200) -> Dataset[ArrowRow]:
    """Create a dataset from a Modin dataframe.

@@ -506,6 +501,7 @@ def from_modin(df: "modin.DataFrame",
    raise NotImplementedError # P1


@autoinit_ray
def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]],
                parallelism: int = 200) -> Dataset[ArrowRow]:
    """Create a dataset from a set of Pandas dataframes.

@@ -529,6 +525,7 @@ def from_pandas(dfs: List[ObjectRef["pandas.DataFrame"]],
    return Dataset(BlockList(blocks, ray.get(list(metadata))))


@autoinit_ray
def from_spark(df: "pyspark.sql.DataFrame",
               parallelism: int = 200) -> Dataset[ArrowRow]:
    """Create a dataset from a Spark dataframe.
@@ -10,8 +10,8 @@ import pytest
import ray

from ray.util.dask import ray_dask_get

from ray.tests.conftest import * # noqa
from ray.experimental.data.datasource import DummyOutputDatasource
import ray.experimental.data.tests.util as util


@@ -29,6 +29,22 @@ def test_basic(ray_start_regular_shared):
    assert sorted(ds.iter_rows()) == [0, 1, 2, 3, 4]


def test_write_datasource(ray_start_regular_shared):
    output = DummyOutputDatasource()
    ds = ray.experimental.data.range(10, parallelism=2)
    ds.write_datasource(output)
    assert output.num_ok == 1
    assert output.num_failed == 0
    assert ray.get(output.data_sink.get_rows_written.remote()) == 10

    ray.get(output.data_sink.set_enabled.remote(False))
    with pytest.raises(ValueError):
        ds.write_datasource(output)
    assert output.num_ok == 1
    assert output.num_failed == 1
    assert ray.get(output.data_sink.get_rows_written.remote()) == 10


def test_empty_dataset(ray_start_regular_shared):
    ds = ray.experimental.data.range(0)
    assert ds.count() == 0