[Data][split] use split_at_indices for equal split without locality hints (#26641)

This PR replaces the dataset.split(.., equal=True) implementation with dataset.split_at_indices(). My experiments (the script) showed that dataset.split_at_indices() has more predictable performance than dataset.split(…).

Concretely, on 10 m5.4xlarge nodes with 5000 IOPS disks:

Calling ds.split(81) on a 200GB dataset with 400 blocks: split takes 20-40 seconds, split_at_indices takes ~12 seconds.

Calling ds.split(163) on a 200GB dataset with 400 blocks: split takes 40-100 seconds, split_at_indices takes ~24 seconds.

I don’t have much insight into the dataset.split implementation, but with dataset.split_at_indices() we are just doing SPREAD to num_split_at_indices tasks, which yields much more stable performance.
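
The index arithmetic behind the new equal-split path can be illustrated without Ray. The following is a minimal sketch, not the Ray implementation: `equal_split_indices`, `equal_split`, and the list-based `split_at_indices` are hypothetical stand-ins that only mimic the cut-point logic on a plain Python list.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def equal_split_indices(count: int, n: int) -> List[int]:
    """Return the n cut points for an equal n-way split of `count` rows."""
    if n <= 0:
        raise ValueError(f"The number of splits {n} is not positive.")
    rows_per_split = count // n
    return [rows_per_split * i for i in range(1, n + 1)]


def split_at_indices(rows: Sequence[T], indices: Sequence[int]) -> List[List[T]]:
    """Hypothetical list-based stand-in for Dataset.split_at_indices.

    k cut points produce k + 1 pieces.
    """
    bounds = [0, *indices, len(rows)]
    return [list(rows[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]


def equal_split(rows: Sequence[T], n: int) -> List[List[T]]:
    """Equal n-way split: cut at n indices, then drop the (n + 1)-th remainder piece."""
    pieces = split_at_indices(rows, equal_split_indices(len(rows), n))
    # The trailing piece holds count % n rows (at most n - 1), which are dropped.
    return pieces[:n]


shards = equal_split(list(range(10)), 3)
print(shards)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

With 10 rows and n = 3 the cut points are 3, 6, and 9, so the one leftover row lands in a fourth piece that is discarded, which matches the "may drop records" caveat in the `equal` docstring.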

Note: clean up the usage of experimental locality_hints in #26647
Chen Shen 2022-07-17 22:17:47 -07:00 committed by GitHub
parent 98a07920d3
commit 5ce06ce2c4
2 changed files with 17 additions and 4 deletions

@@ -867,8 +867,8 @@ class Dataset(Generic[T]):
             equal: Whether to guarantee each split has an equal
                 number of records. This may drop records if they cannot be
                 divided equally among the splits.
-            locality_hints: A list of Ray actor handles of size ``n``. The
-                system will try to co-locate the blocks of the i-th dataset
+            locality_hints: [Experimental] A list of Ray actor handles of size ``n``.
+                The system will try to co-locate the blocks of the i-th dataset
                 with the i-th actor to maximize data locality.
         Returns:
@@ -877,6 +877,19 @@ class Dataset(Generic[T]):
         if n <= 0:
             raise ValueError(f"The number of splits {n} is not positive.")
+        # Fall back to split_at_indices for equal split without locality hints.
+        # Simple benchmarks show split_at_indices yields more stable performance.
+        # See https://github.com/ray-project/ray/pull/26641 for more context.
+        if equal and locality_hints is None:
+            count = self.count()
+            split_index = count // n
+            # We create n split_indices, which generate n + 1 splits;
+            # the last split contains at most (n - 1) rows,
+            # which can be safely dropped.
+            split_indices = [split_index * i for i in range(1, n + 1)]
+            shards = self.split_at_indices(split_indices)
+            return shards[:n]
         if locality_hints and len(locality_hints) != n:
             raise ValueError(
                 f"The length of locality_hints {len(locality_hints)} "

@@ -257,8 +257,8 @@ class DatasetPipeline(Generic[T]):
             equal: Whether to guarantee each split has an equal
                 number of records. This may drop records if they cannot be
                 divided equally among the splits.
-            locality_hints: A list of Ray actor handles of size ``n``. The
-                system will try to co-locate the blocks of the ith pipeline
+            locality_hints: [Experimental] A list of Ray actor handles of size ``n``.
+                The system will try to co-locate the blocks of the ith pipeline
                 shard with the ith actor to maximize data locality.
         Returns: