[Datasets] Add basic stats instrumentation of split_at_indices(). (#24179)

This PR adds basic stats instrumentation of split_at_indices(), the first stage in fully instrumenting split operations. See https://github.com/ray-project/ray/issues/24178 for future steps.
Clark Zinzow 2022-04-26 09:49:48 -07:00 committed by GitHub
parent 27e7c284ee
commit 07112b4146
3 changed files with 82 additions and 11 deletions
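
For context, a minimal usage sketch of the operation being instrumented (standard Ray Datasets API; the exact stats report text is pinned down by the test added at the bottom of this diff):

```python
import ray

# Split a 10-block dataset into two datasets at row index 50.
ds = ray.data.range(100, parallelism=10)
left, right = ds.split_at_indices([50])

# With this change, each returned dataset records a "split" stage in its
# stats, so per-block row and byte counts for the split appear in the report.
print(left.stats())
```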


@@ -2958,6 +2958,7 @@ List[str]]]): The names of the columns to use as the features. Can be a list of
     def _split(
         self, index: int, return_right_half: bool
     ) -> ("Dataset[T]", "Dataset[T]"):
+        start_time = time.perf_counter()
         get_num_rows = cached_remote_fn(_get_num_rows)
         split_block = cached_remote_fn(_split_block, num_returns=4)
@@ -2993,19 +2994,50 @@ List[str]]]): The names of the columns to use as the features. Can be a list of
                 right_metadata.append(ray.get(m1))
             count += num_rows
+        split_duration = time.perf_counter() - start_time
+        left_meta_for_stats = [
+            BlockMetadata(
+                num_rows=m.num_rows,
+                size_bytes=m.size_bytes,
+                schema=m.schema,
+                input_files=m.input_files,
+                exec_stats=None,
+            )
+            for m in left_metadata
+        ]
+        left_dataset_stats = DatasetStats(
+            stages={"split": left_meta_for_stats},
+            parent=self._plan.stats(),
+        )
+        left_dataset_stats.time_total_s = split_duration
         left = Dataset(
             ExecutionPlan(
                 BlockList(left_blocks, left_metadata),
-                self._plan.stats().child_TODO("split"),
+                left_dataset_stats,
             ),
             self._epoch,
             self._lazy,
         )
         if return_right_half:
+            right_meta_for_stats = [
+                BlockMetadata(
+                    num_rows=m.num_rows,
+                    size_bytes=m.size_bytes,
+                    schema=m.schema,
+                    input_files=m.input_files,
+                    exec_stats=None,
+                )
+                for m in right_metadata
+            ]
+            right_dataset_stats = DatasetStats(
+                stages={"split": right_meta_for_stats},
+                parent=self._plan.stats(),
+            )
+            right_dataset_stats.time_total_s = split_duration
             right = Dataset(
                 ExecutionPlan(
                     BlockList(right_blocks, right_metadata),
-                    self._plan.stats().child_TODO("split"),
+                    right_dataset_stats,
                 ),
                 self._epoch,
                 self._lazy,
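
The pattern above, copying each block's metadata with exec_stats=None so the "split" stage reports row/byte counts but no task execution stats, could be written as a small helper. A rough sketch for illustration only (strip_exec_stats is a hypothetical name; the diff inlines the list comprehension for each half):

```python
from typing import List

from ray.data.block import BlockMetadata


def strip_exec_stats(metadata: List[BlockMetadata]) -> List[BlockMetadata]:
    # Blocks in a split are inherited from the parent dataset rather than
    # produced by new tasks, so only row/byte counts are attributed to the
    # "split" stage; execution stats remain with the parent's stages.
    return [
        BlockMetadata(
            num_rows=m.num_rows,
            size_bytes=m.size_bytes,
            schema=m.schema,
            input_files=m.input_files,
            exec_stats=None,
        )
        for m in metadata
    ]
```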


@@ -165,12 +165,9 @@ class DatasetStats:
         """
         self.stages: Dict[str, List[BlockMetadata]] = stages
-        self.parents: List["DatasetStats"] = []
-        if parent:
-            if isinstance(parent, list):
-                self.parents.extend(parent)
-            else:
-                self.parents.append(parent)
+        if parent is not None and not isinstance(parent, list):
+            parent = [parent]
+        self.parents: List["DatasetStats"] = parent
         self.number: int = (
             0 if not self.parents else max(p.number for p in self.parents) + 1
         )
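
The constructor change above means parent may be a single DatasetStats, a list of them, or None, and the first two cases are collapsed into a list. A tiny standalone illustration of that normalization (normalize_parents is a hypothetical helper, not part of the module):

```python
def normalize_parents(parent):
    # A single parent is wrapped in a one-element list; lists and None pass
    # through unchanged, so the parents attribute is a list whenever a
    # parent was given.
    if parent is not None and not isinstance(parent, list):
        parent = [parent]
    return parent


assert normalize_parents("p") == ["p"]
assert normalize_parents(["p1", "p2"]) == ["p1", "p2"]
assert normalize_parents(None) is None
```
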
@@ -280,9 +277,22 @@ class DatasetStats:
             if rounded_total <= 0:
                 # Handle -0.0 case.
                 rounded_total = 0
-            out = "{}/{} blocks executed in {}s\n".format(
-                len(exec_stats), len(blocks), rounded_total
-            )
+            if exec_stats:
+                out = "{}/{} blocks executed in {}s".format(
+                    len(exec_stats), len(blocks), rounded_total
+                )
+            else:
+                out = ""
+            if len(exec_stats) < len(blocks):
+                if exec_stats:
+                    out += ", "
+                num_inherited = len(blocks) - len(exec_stats)
+                out += "{}/{} blocks split from parent".format(
+                    num_inherited, len(blocks)
+                )
+                if not exec_stats:
+                    out += " in {}s".format(rounded_total)
+            out += "\n"
             if exec_stats:
                 out += indent
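
To illustrate the branching above, a hypothetical standalone helper (summary_line, not part of the module) that produces the three possible summary lines:

```python
def summary_line(num_exec: int, num_blocks: int, total_s: float) -> str:
    # Mirrors the logic in the diff: blocks executed in this stage, blocks
    # inherited from the parent via a split, or a mix of both on one line.
    out = ""
    if num_exec:
        out = "{}/{} blocks executed in {}s".format(num_exec, num_blocks, total_s)
    if num_exec < num_blocks:
        if num_exec:
            out += ", "
        out += "{}/{} blocks split from parent".format(
            num_blocks - num_exec, num_blocks
        )
        if not num_exec:
            out += " in {}s".format(total_s)
    return out


# summary_line(5, 5, 1.2) -> "5/5 blocks executed in 1.2s"
# summary_line(0, 5, 1.2) -> "5/5 blocks split from parent in 1.2s"
# summary_line(3, 5, 1.2) -> "3/5 blocks executed in 1.2s, 2/5 blocks split from parent"
```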


@@ -150,6 +150,35 @@ def test_dataset_stats_read_parquet(ray_start_regular_shared, tmp_path):
     )
+
+
+def test_dataset_split_stats(ray_start_regular_shared, tmp_path):
+    ds = ray.data.range(100, parallelism=10).map(lambda x: x + 1)
+    dses = ds.split_at_indices([50])
+    dses = [ds.map(lambda x: x + 1) for ds in dses]
+    for ds_ in dses:
+        stats = canonicalize(ds_.stats())
+        assert (
+            stats
+            == """Stage N read->map: N/N blocks executed in T
+* Remote wall time: T min, T max, T mean, T total
+* Remote cpu time: T min, T max, T mean, T total
+* Output num rows: N min, N max, N mean, N total
+* Output size bytes: N min, N max, N mean, N total
+* Tasks per node: N min, N max, N mean; N nodes used
+Stage N split: N/N blocks split from parent in T
+* Output num rows: N min, N max, N mean, N total
+* Output size bytes: N min, N max, N mean, N total
+Stage N map: N/N blocks executed in T
+* Remote wall time: T min, T max, T mean, T total
+* Remote cpu time: T min, T max, T mean, T total
+* Output num rows: N min, N max, N mean, N total
+* Output size bytes: N min, N max, N mean, N total
+* Tasks per node: N min, N max, N mean; N nodes used
+"""
+        )
+
+
 def test_dataset_pipeline_stats_basic(ray_start_regular_shared):
     context = DatasetContext.get_current()
     context.optimize_fuse_stages = True