[Datasets] Print hierarchical stats for multi-stage operations. (#24119)

Currently, the total execution time for a multi-stage operation is logged once per stage in the dataset stats, which is [confusing to users](https://github.com/ray-project/ray/issues/23915): it makes it look like each stage in the operation took the same amount of time. This PR changes the stats output for multi-stage operations so that the total execution time is printed once, as a top-level op stats line, with the stats for each substage indented beneath it and the repeated total execution time removed.

This also opens the door for other op-level stats (e.g. peak memory utilization) and per-substage stats (e.g. total substage execution time).
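
For illustration, a minimal way to produce the new hierarchical output (a sketch; the dataset size, parallelism, and resulting timings are arbitrary):

```python
import ray

# random_shuffle() is a multi-stage op: a shuffle-map substage plus a
# shuffle-reduce substage, grouped under one logical stage.
ds = ray.data.range(1000, parallelism=10).random_shuffle()

# The total execution time now appears once, on the top-level
# "Stage N read->random_shuffle" line; each substage's stats are
# indented beneath it without repeating the total.
print(ds.stats())
```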
Clark Zinzow 2022-04-22 16:33:11 -07:00 committed by GitHub
parent 9ee24530ab
commit ea791ab0a0
2 changed files with 74 additions and 42 deletions


@@ -63,6 +63,7 @@ class _DatasetStatsBuilder:
         stats = DatasetStats(
             stages=stage_infos,
             parent=self.parent,
+            base_name=self.stage_name,
         )
         stats.time_total_s = time.perf_counter() - self.start_time
         return stats
@@ -144,8 +145,9 @@ class DatasetStats:
         *,
         stages: Dict[str, List[BlockMetadata]],
         parent: Union[Optional["DatasetStats"], List["DatasetStats"]],
-        needs_stats_actor=False,
-        stats_uuid=None
+        needs_stats_actor: bool = False,
+        stats_uuid: str = None,
+        base_name: str = None,
     ):
         """Create dataset stats.
@@ -159,6 +161,7 @@ class DatasetStats:
                 datasource (i.e. a LazyBlockList).
             stats_uuid: The uuid for the stats, used to fetch the right stats
                 from the stats actor.
+            base_name: The name of the base operation for a multi-stage operation.
         """

         self.stages: Dict[str, List[BlockMetadata]] = stages
@@ -171,6 +174,7 @@ class DatasetStats:
         self.number: int = (
             0 if not self.parents else max(p.number for p in self.parents) + 1
         )
+        self.base_name = base_name
         self.dataset_uuid: str = None
         self.time_total_s: float = 0
         self.needs_stats_actor = needs_stats_actor
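
For context, a sketch of how the new parameter fits together (the variable names below are illustrative, not from this diff): a multi-stage op records one `DatasetStats` whose `stages` dict holds each substage's block metadata and whose `base_name` is the user-facing op name.

```python
# Hypothetical call site; the real ones live in the shuffle/repartition paths.
stats = DatasetStats(
    stages={
        "random_shuffle_map": map_stage_metadata,        # List[BlockMetadata]
        "random_shuffle_reduce": reduce_stage_metadata,  # List[BlockMetadata]
    },
    parent=input_stats,
    base_name="random_shuffle",  # printed on the top-level op stats line
)
```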
@@ -219,19 +223,32 @@ class DatasetStats:
         if parent_sum:
             out += parent_sum
             out += "\n"
-        first = True
-        for stage_name, metadata in self.stages.items():
+        if len(self.stages) == 1:
+            stage_name, metadata = next(iter(self.stages.items()))
             stage_uuid = self.dataset_uuid + stage_name
-            if first:
-                first = False
-            else:
-                out += "\n"
             out += "Stage {} {}: ".format(self.number, stage_name)
             if stage_uuid in already_printed:
                 out += "[execution cached]"
             else:
                 already_printed.add(stage_uuid)
-                out += self._summarize_blocks(metadata)
+                out += self._summarize_blocks(metadata, is_substage=False)
+        elif len(self.stages) > 1:
+            rounded_total = round(self.time_total_s, 2)
+            if rounded_total <= 0:
+                # Handle -0.0 case.
+                rounded_total = 0
+            out += "Stage {} {}: executed in {}s\n".format(
+                self.number, self.base_name, rounded_total
+            )
+            for n, (stage_name, metadata) in enumerate(self.stages.items()):
+                stage_uuid = self.dataset_uuid + stage_name
+                out += "\n"
+                out += "\tSubstage {} {}: ".format(n, stage_name)
+                if stage_uuid in already_printed:
+                    out += "\t[execution cached]"
+                else:
+                    already_printed.add(stage_uuid)
+                    out += self._summarize_blocks(metadata, is_substage=True)
         out += self._summarize_iter()
         return out
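
Rendered, the multi-stage branch above produces output shaped like the following (values are placeholders; the substage lines are tab-indented in the raw output):

```
Stage 1 read->random_shuffle: executed in 0.52s

	Substage 0 read->random_shuffle_map: 10/10 blocks executed
	* Remote wall time: ... min, ... max, ... mean, ... total

	Substage 1 random_shuffle_reduce: 10/10 blocks executed
	* Remote wall time: ... min, ... max, ... mean, ... total
```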
@@ -253,17 +270,22 @@ class DatasetStats:
             out += "* Total time: {}\n".format(fmt(self.iter_total_s.get()))
         return out

-    def _summarize_blocks(self, blocks: List[BlockMetadata]) -> str:
+    def _summarize_blocks(self, blocks: List[BlockMetadata], is_substage: bool) -> str:
         exec_stats = [m.exec_stats for m in blocks if m.exec_stats is not None]
-        rounded_total = round(self.time_total_s, 2)
-        if rounded_total <= 0:
-            # Handle -0.0 case.
-            rounded_total = 0
-        out = "{}/{} blocks executed in {}s\n".format(
-            len(exec_stats), len(blocks), rounded_total
-        )
+        indent = "\t" if is_substage else ""
+        if is_substage:
+            out = "{}/{} blocks executed\n".format(len(exec_stats), len(blocks))
+        else:
+            rounded_total = round(self.time_total_s, 2)
+            if rounded_total <= 0:
+                # Handle -0.0 case.
+                rounded_total = 0
+            out = "{}/{} blocks executed in {}s\n".format(
+                len(exec_stats), len(blocks), rounded_total
+            )
         if exec_stats:
+            out += indent
             out += "* Remote wall time: {} min, {} max, {} mean, {} total\n".format(
                 fmt(min([e.wall_time_s for e in exec_stats])),
                 fmt(max([e.wall_time_s for e in exec_stats])),
@@ -271,6 +293,7 @@ class DatasetStats:
                 fmt(sum([e.wall_time_s for e in exec_stats])),
             )
+            out += indent
             out += "* Remote cpu time: {} min, {} max, {} mean, {} total\n".format(
                 fmt(min([e.cpu_time_s for e in exec_stats])),
                 fmt(max([e.cpu_time_s for e in exec_stats])),
@@ -280,6 +303,7 @@ class DatasetStats:
         output_num_rows = [m.num_rows for m in blocks if m.num_rows is not None]
         if output_num_rows:
+            out += indent
             out += "* Output num rows: {} min, {} max, {} mean, {} total\n".format(
                 min(output_num_rows),
                 max(output_num_rows),
@@ -289,6 +313,7 @@ class DatasetStats:
         output_size_bytes = [m.size_bytes for m in blocks if m.size_bytes is not None]
         if output_size_bytes:
+            out += indent
             out += "* Output size bytes: {} min, {} max, {} mean, {} total\n".format(
                 min(output_size_bytes),
                 max(output_size_bytes),
@@ -300,6 +325,7 @@ class DatasetStats:
             node_counts = collections.defaultdict(int)
             for s in exec_stats:
                 node_counts[s.node_id] += 1
+            out += indent
             out += "* Tasks per node: {} min, {} max, {} mean; {} nodes used\n".format(
                 min(node_counts.values()),
                 max(node_counts.values()),


@@ -13,7 +13,9 @@ def canonicalize(stats: str) -> str:
     s2 = re.sub(" [0]+(\.[0]+)?", " Z", s1)
     # Other numerics.
     s3 = re.sub("[0-9]+(\.[0-9]+)?", "N", s2)
-    return s3
+    # Replace tabs with spaces.
+    s4 = re.sub("\t", "    ", s3)
+    return s4
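
For a quick sanity check, here is what the updated `canonicalize` does to two stats lines (inputs invented for this example, assuming the four-space tab replacement above):

```python
# A top-level op line: times become "T", other numerics become "N".
line = "Stage 1 read->random_shuffle: executed in 0.52s\n"
assert canonicalize(line) == "Stage N read->random_shuffle: executed in T\n"

# A substage line: the leading tab becomes four spaces, and the zero in
# "Substage 0" canonicalizes to "Z" rather than "N".
sub = "\tSubstage 0 read->random_shuffle_map: 10/10 blocks executed\n"
assert canonicalize(sub) == "    Substage Z read->random_shuffle_map: N/N blocks executed\n"
```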
def test_dataset_stats_basic(ray_start_regular_shared):
@@ -59,33 +61,37 @@ def test_dataset_stats_shuffle(ray_start_regular_shared):
     stats = canonicalize(ds.stats())
     assert (
         stats
-        == """Stage N read->random_shuffle_map: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-
-Stage N random_shuffle_reduce: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-
-Stage N repartition_map: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
-
-Stage N repartition_reduce: N/N blocks executed in T
-* Remote wall time: T min, T max, T mean, T total
-* Remote cpu time: T min, T max, T mean, T total
-* Output num rows: N min, N max, N mean, N total
-* Output size bytes: N min, N max, N mean, N total
-* Tasks per node: N min, N max, N mean; N nodes used
+        == """Stage N read->random_shuffle: executed in T
+
+    Substage Z read->random_shuffle_map: N/N blocks executed
+    * Remote wall time: T min, T max, T mean, T total
+    * Remote cpu time: T min, T max, T mean, T total
+    * Output num rows: N min, N max, N mean, N total
+    * Output size bytes: N min, N max, N mean, N total
+    * Tasks per node: N min, N max, N mean; N nodes used
+
+    Substage N random_shuffle_reduce: N/N blocks executed
+    * Remote wall time: T min, T max, T mean, T total
+    * Remote cpu time: T min, T max, T mean, T total
+    * Output num rows: N min, N max, N mean, N total
+    * Output size bytes: N min, N max, N mean, N total
+    * Tasks per node: N min, N max, N mean; N nodes used
+
+Stage N repartition: executed in T
+
+    Substage Z repartition_map: N/N blocks executed
+    * Remote wall time: T min, T max, T mean, T total
+    * Remote cpu time: T min, T max, T mean, T total
+    * Output num rows: N min, N max, N mean, N total
+    * Output size bytes: N min, N max, N mean, N total
+    * Tasks per node: N min, N max, N mean; N nodes used
+
+    Substage N repartition_reduce: N/N blocks executed
+    * Remote wall time: T min, T max, T mean, T total
+    * Remote cpu time: T min, T max, T mean, T total
+    * Output num rows: N min, N max, N mean, N total
+    * Output size bytes: N min, N max, N mean, N total
+    * Tasks per node: N min, N max, N mean; N nodes used
 """
     )