[Datasets] Preserve cached block metadata on LazyBlockList splits. (#25745)

Preserves cached block metadata when a LazyBlockList is split. Before this PR, the cached metadata was not carried over to the resulting block lists, so all block metadata had to be re-fetched after a split.
This commit is contained in:
Clark Zinzow 2022-06-16 12:36:25 -07:00 committed by GitHub
parent d98adbc448
commit 04280d6e4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,11 +1,14 @@
import math
from typing import List, Iterator, Tuple, Optional, Dict, Any
import uuid
from typing import Any, Dict, Iterator, List, Optional, Tuple
import numpy as np
import ray
from ray.types import ObjectRef
from ray.data._internal.block_list import BlockList
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.stats import DatasetStats, _get_or_create_stats_actor
from ray.data.block import (
Block,
BlockAccessor,
@ -16,10 +19,7 @@ from ray.data.block import (
)
from ray.data.context import DatasetContext
from ray.data.datasource import ReadTask
from ray.data._internal.block_list import BlockList
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.stats import DatasetStats, _get_or_create_stats_actor
from ray.types import ObjectRef
class LazyBlockList(BlockList):
@ -148,13 +148,17 @@ class LazyBlockList(BlockList):
block_partition_meta_refs = np.array_split(
self._block_partition_meta_refs, num_splits
)
cached_metadata = np.array_split(self._cached_metadata, num_splits)
output = []
for t, b, m in zip(tasks, block_partition_refs, block_partition_meta_refs):
for t, b, m, c in zip(
tasks, block_partition_refs, block_partition_meta_refs, cached_metadata
):
output.append(
LazyBlockList(
t.tolist(),
b.tolist(),
m.tolist(),
c.tolist(),
)
)
return output
@ -162,12 +166,13 @@ class LazyBlockList(BlockList):
# Note: does not force execution prior to splitting.
def split_by_bytes(self, bytes_per_split: int) -> List["BlockList"]:
output = []
cur_tasks, cur_blocks, cur_blocks_meta = [], [], []
cur_tasks, cur_blocks, cur_blocks_meta, cur_cached_meta = [], [], [], []
cur_size = 0
for t, b, bm in zip(
for t, b, bm, c in zip(
self._tasks,
self._block_partition_refs,
self._block_partition_meta_refs,
self._cached_metadata,
):
m = t.get_metadata()
if m.size_bytes is None:
@ -177,16 +182,24 @@ class LazyBlockList(BlockList):
size = m.size_bytes
if cur_blocks and cur_size + size > bytes_per_split:
output.append(
LazyBlockList(cur_tasks, cur_blocks, cur_blocks_meta),
LazyBlockList(
cur_tasks,
cur_blocks,
cur_blocks_meta,
cur_cached_meta,
),
)
cur_tasks, cur_blocks, cur_blocks_meta = [], [], []
cur_tasks, cur_blocks, cur_blocks_meta, cur_cached_meta = [], [], [], []
cur_size = 0
cur_tasks.append(t)
cur_blocks.append(b)
cur_blocks_meta.append(bm)
cur_cached_meta.append(c)
cur_size += size
if cur_blocks:
output.append(LazyBlockList(cur_tasks, cur_blocks, cur_blocks_meta))
output.append(
LazyBlockList(cur_tasks, cur_blocks, cur_blocks_meta, cur_cached_meta)
)
return output
# Note: does not force execution prior to division.
@ -195,11 +208,13 @@ class LazyBlockList(BlockList):
self._tasks[:part_idx],
self._block_partition_refs[:part_idx],
self._block_partition_meta_refs[:part_idx],
self._cached_metadata[:part_idx],
)
right = LazyBlockList(
self._tasks[part_idx:],
self._block_partition_refs[part_idx:],
self._block_partition_meta_refs[part_idx:],
self._cached_metadata[part_idx:],
)
return left, right