Increase dataset read parallelism by default (#18420)

This commit is contained in:
Eric Liang 2021-09-09 15:07:49 -07:00 committed by GitHub
parent ccc16a46bb
commit 4d2065352b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 7 deletions

View file

@@ -1,7 +1,7 @@
.. _datasets_tensor_support: .. _datasets_tensor_support:
Datasets Tensor Support Dataset Tensor Support
======================= ======================
Tensor-typed values Tensor-typed values
------------------- -------------------

View file

@@ -62,7 +62,7 @@ Datasource Compatibility Matrices
- ✅ - ✅
* - Spark Dataframe * - Spark Dataframe
- ``ray.data.from_spark()`` - ``ray.data.from_spark()``
- (todo) -
* - Dask Dataframe * - Dask Dataframe
- ``ray.data.from_dask()`` - ``ray.data.from_dask()``
- ✅ - ✅
@@ -106,7 +106,7 @@ Datasource Compatibility Matrices
- ✅ - ✅
* - Spark Dataframe * - Spark Dataframe
- ``ds.to_spark()`` - ``ds.to_spark()``
- (todo) -
* - Dask Dataframe * - Dask Dataframe
- ``ds.to_dask()`` - ``ds.to_dask()``
- ✅ - ✅

View file

@@ -216,9 +216,10 @@ class DatasetPipeline(Generic[T]):
time.sleep(self.wait_delay_s) time.sleep(self.wait_delay_s)
tries += 1 tries += 1
if tries > self.warn_threshold: if tries > self.warn_threshold:
print("Warning: shard {} of the pipeline has been " print("Warning: reader on shard {} of the pipeline "
"stalled more than {}s waiting for other shards " "has been blocked more than {}s waiting for "
"to catch up.".format( "other readers to catch up. All pipeline shards "
"must be read from concurrently.".format(
self.split_index, self.split_index,
self.wait_delay_s * self.warn_threshold)) self.wait_delay_s * self.warn_threshold))
self.warn_threshold *= 2 self.warn_threshold *= 2

View file

@@ -155,6 +155,12 @@ def read_datasource(datasource: Datasource[T],
if ray_remote_args is None: if ray_remote_args is None:
ray_remote_args = {} ray_remote_args = {}
# Increase the read parallelism by default to maximize IO throughput. This
# is particularly important when reading from e.g., remote storage.
if "num_cpus" not in ray_remote_args:
# Note that the too many workers warning triggers at 4x subscription,
# so we go at 0.5 to avoid the warning message.
ray_remote_args["num_cpus"] = 0.5
remote_read = cached_remote_fn(remote_read, **ray_remote_args) remote_read = cached_remote_fn(remote_read, **ray_remote_args)
calls: List[Callable[[], ObjectRef[Block]]] = [] calls: List[Callable[[], ObjectRef[Block]]] = []