Rename Dataset.pipeline to Dataset.window (#19050)
parent 3dc176c42e
commit 032a420ee6
7 changed files with 66 additions and 34 deletions
@@ -11,7 +11,11 @@ A DatasetPipeline is an unified iterator over a (potentially infinite) sequence
Creating a DatasetPipeline
~~~~~~~~~~~~~~~~~~~~~~~~~~

-A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.pipeline``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example:
+A DatasetPipeline can be constructed in two ways: either by pipelining the execution of an existing Dataset (via ``Dataset.window``), or generating repeats of an existing Dataset (via ``Dataset.repeat``). Similar to Datasets, you can freely pass DatasetPipelines between Ray tasks, actors, and libraries. Get started with this synthetic data example:
+
+.. tip::
+
+   The "window size" of a pipeline is defined as the number of blocks per Dataset in the pipeline.

.. code-block:: python

@@ -30,7 +34,7 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu
    base = ray.data.range(1000000)
    print(base)
    # -> Dataset(num_blocks=200, num_rows=1000000, schema=<class 'int'>)
-   pipe = base.pipeline(parallelism=10)
+   pipe = base.window(blocks_per_window=10)
    print(pipe)
    # -> DatasetPipeline(length=20, num_stages=1)
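As a quick check on the numbers printed above: each window becomes one Dataset in the pipeline, so the pipeline length is the block count divided by the window size (200 blocks at 10 blocks per window gives the ``length=20`` shown). A minimal sketch of that arithmetic, assuming ``num_blocks()`` reports the 200 blocks from the printed repr:

.. code-block:: python

    import math

    import ray

    base = ray.data.range(1000000)   # 200 blocks, per the example above
    blocks_per_window = 10

    # One Dataset per window, so length = ceil(num_blocks / blocks_per_window).
    expected_length = math.ceil(base.num_blocks() / blocks_per_window)
    print(expected_length)  # -> 20, matching DatasetPipeline(length=20, num_stages=1)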
@@ -53,8 +57,7 @@ A DatasetPipeline can be constructed in two ways: either by pipelining the execu
    print("Total num rows", num_rows)
    # -> Total num rows 1000000

-You can also create a DatasetPipeline from a custom iterator over dataset creators using ``DatasetPipeline.from_iterable``. For example, this is how you would implement ``Dataset.repeat`` and ``Dataset.pipeline`` using ``from_iterable``:
+You can also create a DatasetPipeline from a custom iterator over dataset creators using ``DatasetPipeline.from_iterable``. For example, this is how you would implement ``Dataset.repeat`` and ``Dataset.window`` using ``from_iterable``:

.. code-block:: python

@@ -66,7 +69,7 @@ You can also create a DatasetPipeline from a custom iterator over dataset creato
    pipe = DatasetPipeline.from_iterable(
        [lambda: source, lambda: source, lambda: source, lambda: source])

-   # Equivalent to ray.data.range(1000).pipeline(parallelism=10)
+   # Equivalent to ray.data.range(1000).window(blocks_per_window=10)
    splits = ray.data.range(1000, parallelism=200).split(20)
    pipe = DatasetPipeline.from_iterable([lambda s=s: s for s in splits])
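For reference, a self-contained sketch that re-assembles the fragments in this hunk; the original ``source`` dataset is defined outside the visible context, so a small range dataset is assumed here, and the ``DatasetPipeline`` import path is assumed to be ``ray.data.dataset_pipeline``:

.. code-block:: python

    import ray
    from ray.data.dataset_pipeline import DatasetPipeline

    source = ray.data.range(1000)  # assumed stand-in for the `source` above

    # Roughly what Dataset.repeat does: yield the same dataset several times.
    repeat_pipe = DatasetPipeline.from_iterable(
        [lambda: source, lambda: source, lambda: source, lambda: source])

    # Equivalent to ray.data.range(1000).window(blocks_per_window=10):
    # 200 blocks split into 20 groups of 10, one group per window.
    splits = ray.data.range(1000, parallelism=200).split(20)
    window_pipe = DatasetPipeline.from_iterable([lambda s=s: s for s in splits])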
@@ -109,28 +112,28 @@ Ignoring the output, the above script has three separate stages: loading, prepro
Enabling Pipelining
~~~~~~~~~~~~~~~~~~~

-We can optimize this by *pipelining* the execution of the dataset with the ``.pipeline()`` call, which returns a DatasetPIpeline instead of a Dataset object. The pipeline supports similar transformations to the original Dataset:
+We can optimize this by *pipelining* the execution of the dataset with the ``.window()`` call, which returns a DatasetPipeline instead of a Dataset object. The pipeline supports similar transformations to the original Dataset:

.. code-block:: python

    # Convert the Dataset into a DatasetPipeline.
    pipe: DatasetPipeline = ray.data \
        .read_binary_files("s3://bucket/image-dir") \
-       .pipeline(parallelism=2)
+       .window(blocks_per_window=2)

    # The remainder of the steps do not change.
    pipe = pipe.map(preprocess)
    pipe = pipe.map_batches(BatchInferModel, compute="actors", batch_size=256, num_gpus=1)
    pipe.write_json("/tmp/results")

-Here we specified ``parallelism=2``, which means that the Dataset is split into smaller sub-Datasets of two blocks each. Each transformation or *stage* of the pipeline is operating over these two-block Datasets in parallel. This means batch inference processing can start as soon as two blocks are read and preprocessed, greatly reducing the GPU idle time:
+Here we specified ``blocks_per_window=2``, which means that the Dataset is split into smaller sub-Datasets of two blocks each. Each transformation or *stage* of the pipeline is operating over these two-block Datasets in parallel. This means batch inference processing can start as soon as two blocks are read and preprocessed, greatly reducing the GPU idle time:

.. image:: dataset-pipeline-2.svg

Tuning Parallelism
~~~~~~~~~~~~~~~~~~

-Tune the throughput vs latency of your pipeline with the ``parallelism`` setting. As a rule of thumb, higher parallelism settings perform better, however ``parallelism == num_blocks`` effectively disables pipelining, since the DatasetPipeline will only contain a single Dataset. The other extreme is setting ``parallelism=1``, which minimizes the latency to initial output but only allows one concurrent transformation task per stage:
+Tune the throughput vs latency of your pipeline with the ``blocks_per_window`` setting. As a rule of thumb, higher parallelism settings perform better, however ``blocks_per_window == num_blocks`` effectively disables pipelining, since the DatasetPipeline will only contain a single Dataset. The other extreme is setting ``blocks_per_window=1``, which minimizes the latency to initial output but only allows one concurrent transformation task per stage:

.. image:: dataset-pipeline-3.svg
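To make the trade-off above concrete, here is a minimal sketch contrasting the two extremes of ``blocks_per_window``, assuming the 200-block dataset from the earlier example:

.. code-block:: python

    import ray

    ds = ray.data.range(1000000)  # 200 blocks, as in the earlier example

    # blocks_per_window=1: lowest latency to the first output, but each
    # stage runs only one block-sized task at a time per window.
    low_latency = ds.window(blocks_per_window=1)      # DatasetPipeline(length=200, ...)

    # blocks_per_window == num_blocks: a single-Dataset pipeline,
    # which effectively disables pipelining.
    no_pipelining = ds.window(blocks_per_window=200)  # DatasetPipeline(length=1, ...)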
@@ -638,7 +638,7 @@ Underneath the hood, RaySGD will automatically shard the given dataset.
        return model

    trainer = Trainer(num_workers=8, backend="torch")
-   dataset = ray.data.read_csv("...").filter().pipeline(length=50)
+   dataset = ray.data.read_csv("...").filter().window(blocks_per_window=50)

    result = trainer.run(
        train_func,
@@ -741,7 +741,7 @@ A couple caveats:
    # Declare the specification for training.
    trainer = Trainer(backend="torch", num_workers=12, use_gpu=True)
-   dataset = ray.dataset.pipeline()
+   dataset = ray.dataset.window()

    # Convert this to a trainable.
    trainable = trainer.to_tune_trainable(training_func, dataset=dataset)
@@ -1645,11 +1645,14 @@ class Dataset(Generic[T]):
        return DatasetPipeline(Iterable(self), length=times)

    def pipeline(self, *, parallelism: int = 10) -> "DatasetPipeline[T]":
-        """Pipeline the dataset execution by splitting its blocks into groups.
+        raise DeprecationWarning("Use .window(n) instead of .pipeline(n)")
+
+    def window(self, *, blocks_per_window: int = 10) -> "DatasetPipeline[T]":
+        """Convert this into a DatasetPipeline by windowing over data blocks.

-        Transformations prior to the call to ``pipeline()`` are evaluated in
+        Transformations prior to the call to ``window()`` are evaluated in
        bulk on the entire dataset. Transformations done on the returned
-        pipeline are evaluated incrementally per group of blocks as data is
+        pipeline are evaluated incrementally per window of blocks as data is
        read from the output of the pipeline.

        Pipelining execution allows for output to be read sooner without
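A small sketch of what this rename means for callers, assuming the patched ``dataset.py`` above (the old name now raises immediately rather than emitting a warning):

.. code-block:: python

    import ray

    ds = ray.data.range(1000)

    # The old entry point is kept only as a pointer to the new name.
    try:
        ds.pipeline(parallelism=10)
    except DeprecationWarning as exc:
        print(exc)  # -> Use .window(n) instead of .pipeline(n)

    # The renamed API; blocks_per_window plays the role parallelism used to.
    pipe = ds.window(blocks_per_window=10)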
@@ -1673,11 +1676,11 @@ class Dataset(Generic[T]):
        Examples:
            >>> # Create an inference pipeline.
            >>> ds = ray.data.read_binary_files(dir)
-            >>> pipe = ds.pipeline(parallelism=10).map(infer)
+            >>> pipe = ds.window(blocks_per_window=10).map(infer)
            DatasetPipeline(num_stages=2, length=40)

            >>> # The higher the stage parallelism, the shorter the pipeline.
-            >>> pipe = ds.pipeline(parallelism=20).map(infer)
+            >>> pipe = ds.window(blocks_per_window=20).map(infer)
            DatasetPipeline(num_stages=2, length=20)

            >>> # Outputs can be incrementally read from the pipeline.
@@ -1685,8 +1688,8 @@ class Dataset(Generic[T]):
            ...     print(item)

        Args:
-            parallelism: The parallelism (number of blocks) per stage.
-                Increasing parallelism increases pipeline throughput, but also
+            blocks_per_window: The window size (parallelism) in blocks.
+                Increasing window size increases pipeline throughput, but also
                increases the latency to initial output, since it decreases the
                length of the pipeline. Setting this to infinity effectively
                disables pipelining.
@@ -1710,7 +1713,7 @@ class Dataset(Generic[T]):

        class Iterable:
            def __init__(self, blocks):
-                self._splits = blocks.split(split_size=parallelism)
+                self._splits = blocks.split(split_size=blocks_per_window)

            def __iter__(self):
                return Iterator(self._splits)
@@ -40,7 +40,7 @@ class DatasetPipeline(Generic[T]):

    A DatasetPipeline can be created by either repeating a Dataset
    (``ds.repeat(times=None)``), by turning a single Dataset into a pipeline
-    (``ds.pipeline(parallelism=10)``), or defined explicitly using
+    (``ds.window(blocks_per_window=10)``), or defined explicitly using
    ``DatasetPipeline.from_iterable()``.

    DatasetPipeline supports the all the per-record transforms of Datasets

@@ -57,7 +57,7 @@ class DatasetPipeline(Generic[T]):
        """Construct a DatasetPipeline (internal API).

        The constructor is not part of the DatasetPipeline API. Use the
-        ``Dataset.repeat()``, ``Dataset.pipeline()``, or
+        ``Dataset.repeat()``, ``Dataset.window()``, or
        ``DatasetPipeline.from_iterable()`` methods to construct a pipeline.
        """
        self._base_iterable = base_iterable
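For reference, a minimal sketch of the three construction paths the class docstring above lists; the ``DatasetPipeline`` import path is assumed to be ``ray.data.dataset_pipeline``:

.. code-block:: python

    import ray
    from ray.data.dataset_pipeline import DatasetPipeline

    ds = ray.data.range(1000, parallelism=200)

    pipe_from_repeat = ds.repeat(times=5)               # repeat a finite Dataset
    pipe_from_window = ds.window(blocks_per_window=10)  # window over data blocks
    pipe_from_iter = DatasetPipeline.from_iterable(     # explicit dataset creators
        [lambda: ds, lambda: ds])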
@@ -240,6 +240,32 @@ class DatasetPipeline(Generic[T]):
                for idx in range(n)
            ]

+    def window(self, *, blocks_per_window: int) -> "DatasetPipeline[T]":
+        """Change the windowing (blocks per dataset) of this pipeline.
+
+        Changes the windowing of this pipeline to the specified size. For
+        example, if the current pipeline has two blocks per dataset, and
+        `.window(4)` is requested, adjacent datasets will be merged until each
+        dataset is 4 blocks. If `.window(1)` was requested the datasets will
+        be split into smaller windows.
+
+        Args:
+            blocks_per_window: The new target blocks per window.
+        """
+        raise NotImplementedError
+
+    def repeat(self, times: int = None) -> "DatasetPipeline[T]":
+        """Repeat this pipeline a given number or times, or indefinitely.
+
+        This operation is only allowed for pipelines of a finite length. An
+        error will be raised for pipelines of infinite or unknown length.
+
+        Args:
+            times: The number of times to loop over this pipeline, or None
+                to repeat indefinitely.
+        """
+        raise NotImplementedError
+
    def schema(self) -> Union[type, "pyarrow.lib.Schema"]:
        """Return the schema of the dataset pipeline.
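The new ``DatasetPipeline.window`` docstring above only describes the intended re-windowing semantics; in this commit the body is still ``raise NotImplementedError``. A back-of-the-envelope sketch of the described behavior, assuming a hypothetical pipeline of 10 two-block windows:

.. code-block:: python

    import math

    total_blocks = 10 * 2  # assumed: 10 windows of 2 blocks each

    # Re-windowing to 4 blocks per window merges adjacent windows...
    print(math.ceil(total_blocks / 4))  # -> 5 windows of 4 blocks

    # ...while re-windowing to 1 block per window splits them apart.
    print(math.ceil(total_blocks / 1))  # -> 20 single-block windows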
@@ -18,7 +18,7 @@ class Model:
        return x


-ds = ds.pipeline(parallelism=10) \
+ds = ds.window(blocks_per_window=10) \
    .map(preprocess) \
    .map(Model, compute="actors", num_gpus=1)
@@ -31,7 +31,7 @@ from ray.data.tests.conftest import * # noqa

def maybe_pipeline(ds, enabled):
    if enabled:
-        return ds.pipeline(parallelism=1)
+        return ds.window(blocks_per_window=1)
    else:
        return ds
@@ -30,14 +30,14 @@ def test_incremental_take(shutdown_only):
            time.sleep(999999)
        return x

-    pipe = ray.data.range(2).pipeline(parallelism=1)
+    pipe = ray.data.range(2).window(blocks_per_window=1)
    pipe = pipe.map(block_on_ones)
    assert pipe.take(1) == [0]


def test_cannot_read_twice(ray_start_regular_shared):
    ds = ray.data.range(10)
-    pipe = ds.pipeline(parallelism=1)
+    pipe = ds.window(blocks_per_window=1)
    assert pipe.count() == 10
    with pytest.raises(RuntimeError):
        pipe.count()
@@ -52,15 +52,15 @@ def test_cannot_read_twice(ray_start_regular_shared):
def test_basic_pipeline(ray_start_regular_shared):
    ds = ray.data.range(10)

-    pipe = ds.pipeline(parallelism=1)
+    pipe = ds.window(blocks_per_window=1)
    assert str(pipe) == "DatasetPipeline(length=10, num_stages=1)"
    assert pipe.count() == 10

-    pipe = ds.pipeline(parallelism=1).map(lambda x: x).map(lambda x: x)
+    pipe = ds.window(blocks_per_window=1).map(lambda x: x).map(lambda x: x)
    assert str(pipe) == "DatasetPipeline(length=10, num_stages=3)"
    assert pipe.take() == list(range(10))

-    pipe = ds.pipeline(parallelism=999)
+    pipe = ds.window(blocks_per_window=999)
    assert str(pipe) == "DatasetPipeline(length=1, num_stages=1)"
    assert pipe.count() == 10
@@ -97,30 +97,30 @@ def test_repartition(ray_start_regular_shared):


def test_iter_batches(ray_start_regular_shared):
-    pipe = ray.data.range(10).pipeline(parallelism=2)
+    pipe = ray.data.range(10).window(blocks_per_window=2)
    batches = list(pipe.iter_batches())
    assert len(batches) == 10
    assert all(len(e) == 1 for e in batches)


def test_iter_datasets(ray_start_regular_shared):
-    pipe = ray.data.range(10).pipeline(parallelism=2)
+    pipe = ray.data.range(10).window(blocks_per_window=2)
    ds = list(pipe.iter_datasets())
    assert len(ds) == 5

-    pipe = ray.data.range(10).pipeline(parallelism=5)
+    pipe = ray.data.range(10).window(blocks_per_window=5)
    ds = list(pipe.iter_datasets())
    assert len(ds) == 2


def test_foreach_dataset(ray_start_regular_shared):
-    pipe = ray.data.range(5).pipeline(parallelism=2)
+    pipe = ray.data.range(5).window(blocks_per_window=2)
    pipe = pipe.foreach_dataset(lambda ds: ds.map(lambda x: x * 2))
    assert pipe.take() == [0, 2, 4, 6, 8]


def test_schema(ray_start_regular_shared):
-    pipe = ray.data.range(5).pipeline(parallelism=2)
+    pipe = ray.data.range(5).window(blocks_per_window=2)
    assert pipe.schema() == int
@@ -179,7 +179,7 @@ def test_parquet_write(ray_start_regular_shared, tmp_path):
    df2 = pd.DataFrame({"one": [4, 5, 6], "two": ["e", "f", "g"]})
    df = pd.concat([df1, df2])
    ds = ray.data.from_pandas([df1, df2])
-    ds = ds.pipeline(parallelism=1)
+    ds = ds.window(blocks_per_window=1)
    path = os.path.join(tmp_path, "test_parquet_dir")
    os.mkdir(path)
    ds._set_uuid("data")