[datasets] Use generators for merge stage in push-based shuffle (#25336)

This uses the generators introduced in #25247 to reduce memory usage during the merge stage in push-based shuffle. These tasks merge groups of map outputs, so it fits a generator pattern where we want to return merged outputs one at a time. Verified that this allows for merging more/larger objects at a time than the current list-based version.

I also tried this for the map stage in random_shuffle, but it didn't seem to make a difference in memory usage for Arrow blocks. I think this is probably because Arrow is already doing some zero-copy optimizations when selecting rows?

Also adds a new line to Dataset stats for memory usage. Unfortunately it's hard to get an accurate reading of physical memory usage in Python and this value will probably be an overestimate in a lot of cases. I didn't see a difference before and after this PR for the merge stage, for example. Arguably this field should be opt-in. For 100MB partitions, for example:
```
        Substage 0 read->random_shuffle_map: 10/10 blocks executed
        * Remote wall time: 1.44s min, 3.32s max, 2.57s mean, 25.74s total
        * Remote cpu time: 1.42s min, 2.53s max, 2.03s mean, 20.25s total
        * Worker memory usage (MB): 462 min, 864 max, 552 mean
        * Output num rows: 12500000 min, 12500000 max, 12500000 mean, 125000000 total
        * Output size bytes: 101562500 min, 101562500 max, 101562500 mean, 1015625000 total
        * Tasks per node: 10 min, 10 max, 10 mean; 1 nodes used

        Substage 1 random_shuffle_reduce: 10/10 blocks executed
        * Remote wall time: 1.47s min, 2.94s max, 2.17s mean, 21.69s total
        * Remote cpu time: 1.45s min, 1.88s max, 1.71s mean, 17.09s total
        * Worker memory usage (MB): 462 min, 1047 max, 831 mean
        * Output num rows: 12500000 min, 12500000 max, 12500000 mean, 125000000 total
        * Output size bytes: 101562500 min, 101562500 max, 101562500 mean, 1015625000 total
        * Tasks per node: 10 min, 10 max, 10 mean; 1 nodes used
```


## Checks

- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
Stephanie Wang 2022-06-17 15:29:24 -04:00 committed by GitHub
parent 293c122302
commit d699351748
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 75 additions and 28 deletions

View file

@ -375,16 +375,29 @@ class PushBasedShufflePlan(ShuffleOp):
len({len(mapper_outputs) for mapper_outputs in all_mapper_outputs}) == 1
), "Received different number of map inputs"
stats = BlockExecStats.builder()
merged_outputs = []
if not reduce_args:
reduce_args = []
for mapper_outputs in zip(*all_mapper_outputs):
num_rows = 0
size_bytes = 0
schema = None
for i, mapper_outputs in enumerate(zip(*all_mapper_outputs)):
block, meta = reduce_fn(*reduce_args, *mapper_outputs)
merged_outputs.append(block)
meta = BlockAccessor.for_block(block).get_metadata(
input_files=None, exec_stats=stats.build()
yield block
block = BlockAccessor.for_block(block)
num_rows += block.num_rows()
size_bytes += block.size_bytes()
schema = block.schema()
del block
yield BlockMetadata(
num_rows=num_rows,
size_bytes=size_bytes,
schema=schema,
input_files=None,
exec_stats=stats.build(),
)
return merged_outputs + [meta]
@staticmethod
def _compute_shuffle_schedule(

View file

@ -1,13 +1,14 @@
from contextlib import contextmanager
from typing import List, Optional, Set, Dict, Tuple, Union
import time
import collections
import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np
import ray
from ray.data._internal.block_list import BlockList
from ray.data.block import BlockMetadata
from ray.data.context import DatasetContext
from ray.data._internal.block_list import BlockList
def fmt(seconds: float) -> str:
@ -315,6 +316,14 @@ class DatasetStats:
fmt(sum([e.cpu_time_s for e in exec_stats])),
)
out += indent
memory_stats = [round(e.max_rss_bytes / 1024 * 1024, 2) for e in exec_stats]
out += "* Peak heap memory usage (MiB): {} min, {} max, {} mean\n".format(
min(memory_stats),
max(memory_stats),
int(np.mean(memory_stats)),
)
output_num_rows = [m.num_rows for m in blocks if m.num_rows is not None]
if output_num_rows:
out += indent

View file

@ -1,32 +1,35 @@
from dataclasses import dataclass
import resource
import time
from dataclasses import dataclass
from typing import (
TypeVar,
List,
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterator,
Tuple,
Any,
Union,
List,
Optional,
Callable,
TYPE_CHECKING,
Tuple,
TypeVar,
Union,
)
import numpy as np
import ray
from ray.data._internal.util import _check_pyarrow_version
from ray.types import ObjectRef
from ray.util.annotations import DeveloperAPI
if TYPE_CHECKING:
import pandas
import pyarrow
from ray.data import Dataset
from ray.data._internal.block_builder import BlockBuilder
from ray.data.aggregate import AggregateFn
from ray.data import Dataset
import ray
from ray.types import ObjectRef
from ray.util.annotations import DeveloperAPI
from ray.data._internal.util import _check_pyarrow_version
T = TypeVar("T")
U = TypeVar("U")
@ -114,6 +117,9 @@ class BlockExecStats:
self.wall_time_s: Optional[float] = None
self.cpu_time_s: Optional[float] = None
self.node_id = ray.runtime_context.get_runtime_context().node_id.hex()
# Max memory usage. May be an overestimate since we do not
# differentiate from previous tasks on the same worker.
self.max_rss_bytes: int = 0
@staticmethod
def builder() -> "_BlockExecStatsBuilder":
@ -144,6 +150,9 @@ class _BlockExecStatsBuilder:
stats = BlockExecStats()
stats.wall_time_s = time.perf_counter() - self.start_time
stats.cpu_time_s = time.process_time() - self.start_cpu
stats.max_rss_bytes = int(
resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1e3
)
return stats
@ -279,8 +288,8 @@ class BlockAccessor(Generic[T]):
def for_block(block: Block) -> "BlockAccessor[T]":
"""Create a block accessor for the given block."""
_check_pyarrow_version()
import pyarrow
import pandas
import pyarrow
if isinstance(block, pyarrow.Table):
from ray.data._internal.arrow_block import ArrowBlockAccessor

View file

@ -1,6 +1,7 @@
import pytest
import re
import pytest
import ray
from ray.data.context import DatasetContext
from ray.tests.conftest import * # noqa
@ -32,6 +33,7 @@ def test_dataset_stats_basic(ray_start_regular_shared):
== """Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -39,6 +41,7 @@ def test_dataset_stats_basic(ray_start_regular_shared):
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -66,6 +69,7 @@ def test_dataset_stats_shuffle(ray_start_regular_shared):
Substage Z read->random_shuffle_map: N/N blocks executed
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -73,6 +77,7 @@ def test_dataset_stats_shuffle(ray_start_regular_shared):
Substage N random_shuffle_reduce: N/N blocks executed
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -82,6 +87,7 @@ Stage N repartition: executed in T
Substage Z repartition_map: N/N blocks executed
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -89,6 +95,7 @@ Stage N repartition: executed in T
Substage N repartition_reduce: N/N blocks executed
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -143,6 +150,7 @@ def test_dataset_stats_read_parquet(ray_start_regular_shared, tmp_path):
== """Stage N read->map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -161,6 +169,7 @@ def test_dataset_split_stats(ray_start_regular_shared, tmp_path):
== """Stage N read->map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -172,6 +181,7 @@ Stage N split: N/N blocks split from parent in T
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -195,6 +205,7 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared):
Stage N read->map_batches: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -202,6 +213,7 @@ Stage N read->map_batches: N/N blocks executed in T
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -212,6 +224,7 @@ Stage N read->map_batches: [execution cached]
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -222,6 +235,7 @@ Stage N read->map_batches: [execution cached]
Stage N map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -281,6 +295,7 @@ def test_dataset_pipeline_split_stats_basic(ray_start_regular_shared):
Stage N read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
@ -289,6 +304,7 @@ Stage N read: N/N blocks executed in T
Stage N read: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used

View file

@ -4,17 +4,17 @@ import logging
import os
import sys
import time
import pytest
import ray
import ray.cluster_utils
from ray._private.test_utils import (
wait_for_pid_to_exit,
client_test_enabled,
run_string_as_driver,
wait_for_pid_to_exit,
)
import ray
logger = logging.getLogger(__name__)