[data] Cleanup Block type by dropping Generic[T] (#17276)

* wip

* update

* update

* quotes
Eric Liang 2021-07-23 09:23:06 -07:00 committed by GitHub
parent ded239205f
commit df7fe8dd6d
11 changed files with 44 additions and 52 deletions

(Two image files also changed; their diffs are suppressed as too long to render. Before/after sizes: 142 KiB -> 148 KiB and 660 KiB -> 670 KiB.)

View file

@@ -10,18 +10,11 @@ from ray.util.annotations import DeveloperAPI

 T = TypeVar("T")

-# TODO(ekl) this is a dummy generic ref type for documentation purposes only.
-# It adds Generic[T] to pyarrow.Table so we can define Block[T] below.
-class _ArrowTable(Generic[T]):
-    pass
-
 # Represents a batch of rows to be stored in the Ray object store.
 #
 # Block data can be accessed in a uniform way via ``BlockAccessors`` such as
 # ``SimpleBlockAccessor`` and ``ArrowBlockAccessor``.
-Block = Union[List[T], _ArrowTable[T]]
+Block = Union[List[T], "pyarrow.Table"]

 @DeveloperAPI
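The hunk above is the heart of the cleanup: ``pyarrow.Table`` is not a generic class, so the dummy ``_ArrowTable`` wrapper existed only to let documentation spell ``Block[T]``. The new alias names ``pyarrow.Table`` as a string, i.e. a forward reference, so the module does not have to import pyarrow at load time. A minimal standalone sketch of the before and after (the names BlockBefore and BlockAfter are illustrative, not from the commit):

    from typing import Generic, List, TypeVar, Union

    T = TypeVar("T")

    # Before: a dummy wrapper bolted Generic[T] onto pyarrow.Table.
    class _ArrowTable(Generic[T]):
        pass

    BlockBefore = Union[List[T], _ArrowTable[T]]

    # After: a plain union. The quoted name is a forward reference that is
    # only resolved if a type checker or typing.get_type_hints() asks for it.
    BlockAfter = Union[List[T], "pyarrow.Table"]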
@@ -68,7 +61,7 @@ class BlockAccessor(Generic[T]):
         """Iterate over the rows of this block."""
         raise NotImplementedError

-    def slice(self, start: int, end: int, copy: bool) -> "Block[T]":
+    def slice(self, start: int, end: int, copy: bool) -> Block:
         """Return a slice of this block.

         Args:
@@ -111,7 +104,7 @@ class BlockAccessor(Generic[T]):
         raise NotImplementedError

     @staticmethod
-    def for_block(block: Block[T]) -> "BlockAccessor[T]":
+    def for_block(block: Block) -> "BlockAccessor[T]":
         """Create a block accessor for the given block."""
         import pyarrow
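Accessors are chosen by the runtime type of the block, not by a static type parameter, which is why ``Block[T]`` bought no real checking here. A runnable usage sketch of the accessor interface over both block kinds (assumes ray and pyarrow are installed; import path as in this commit):

    import pyarrow
    from ray.experimental.data.block import BlockAccessor

    list_block = [1, 2, 3]
    table_block = pyarrow.Table.from_pydict({"value": [1, 2, 3]})

    # The same interface wraps either representation.
    for block in (list_block, table_block):
        accessor = BlockAccessor.for_block(block)
        print(accessor.num_rows(), list(accessor.iter_rows()))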

View file

@@ -45,7 +45,7 @@ logger = logging.getLogger(__name__)

 class Dataset(Generic[T]):
     """Implements a distributed Arrow dataset.

-    Datasets are implemented as a list of ``ObjectRef[Block[T]]``. The block
+    Datasets are implemented as a list of ``ObjectRef[Block]``. The block
     also determines the unit of parallelism. The default block type is the
     ``pyarrow.Table``. Arrow-incompatible objects are held in ``list`` blocks.
@@ -105,7 +105,7 @@ class Dataset(Generic[T]):

         fn = cache_wrapper(fn)

-        def transform(block: Block[T]) -> Block[U]:
+        def transform(block: Block) -> Block:
             block = BlockAccessor.for_block(block)
             builder = DelegatingArrowBlockBuilder()
             for row in block.iter_rows():
@@ -166,7 +166,7 @@ class Dataset(Generic[T]):

         fn = cache_wrapper(fn)

-        def transform(block: Block[T]) -> Block[U]:
+        def transform(block: Block) -> Block:
             block = BlockAccessor.for_block(block)
             total_rows = block.num_rows()
             max_batch_size = batch_size
@@ -233,7 +233,7 @@ class Dataset(Generic[T]):

         fn = cache_wrapper(fn)

-        def transform(block: Block[T]) -> Block[U]:
+        def transform(block: Block) -> Block:
             block = BlockAccessor.for_block(block)
             builder = DelegatingArrowBlockBuilder()
             for row in block.iter_rows():
@@ -270,7 +270,7 @@ class Dataset(Generic[T]):

         fn = cache_wrapper(fn)

-        def transform(block: Block[T]) -> Block[T]:
+        def transform(block: Block) -> Block:
             block = BlockAccessor.for_block(block)
             builder = block.builder()
             for row in block.iter_rows():
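The map-style operators above all follow the same per-block pattern: wrap the incoming block in an accessor, feed rows through the user function, and let a delegating builder decide the output representation. A hedged sketch of that pattern (the builder's import path is assumed from this commit's other imports; the real transform closures capture fn rather than taking it as an argument):

    from ray.experimental.data.block import Block, BlockAccessor
    from ray.experimental.data.impl.arrow_block import \
        DelegatingArrowBlockBuilder

    def transform(block: Block, fn) -> Block:
        accessor = BlockAccessor.for_block(block)
        builder = DelegatingArrowBlockBuilder()
        for row in accessor.iter_rows():
            builder.add(fn(row))  # assumed body; the diff truncates here
        return builder.build()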
@@ -497,13 +497,13 @@ class Dataset(Generic[T]):
         """

         @ray.remote
-        def get_num_rows(block: Block[T]) -> int:
+        def get_num_rows(block: Block) -> int:
             block = BlockAccessor.for_block(block)
             return block.num_rows()

         @ray.remote(num_returns=2)
-        def truncate(block: Block[T], meta: BlockMetadata,
-                     count: int) -> (Block[T], BlockMetadata):
+        def truncate(block: Block, meta: BlockMetadata,
+                     count: int) -> (Block, BlockMetadata):
             block = BlockAccessor.for_block(block)
             logger.debug("Truncating last block to size: {}".format(count))
             new_block = block.slice(0, count, copy=True)
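truncate relies on ``ray.remote(num_returns=2)``: the returned 2-tuple is split into two object refs, one per element. A runnable standalone sketch of that mechanism, with a list standing in for a block and a dict standing in for BlockMetadata:

    import ray

    ray.init()

    @ray.remote(num_returns=2)
    def truncate(block, count):
        # Stand-ins: a plain slice for BlockAccessor.slice(), a dict for
        # BlockMetadata.
        return block[:count], {"num_rows": count}

    block_ref, meta_ref = truncate.remote(list(range(10)), 5)
    print(ray.get(block_ref), ray.get(meta_ref))
    # [0, 1, 2, 3, 4] {'num_rows': 5}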
@@ -583,7 +583,7 @@ class Dataset(Generic[T]):
             return meta_count

         @ray.remote
-        def count(block: Block[T]) -> int:
+        def count(block: Block) -> int:
             block = BlockAccessor.for_block(block)
             return block.num_rows()
@@ -599,7 +599,7 @@ class Dataset(Generic[T]):
         """

         @ray.remote
-        def agg(block: Block[T]) -> int:
+        def agg(block: Block) -> int:
             block = BlockAccessor.for_block(block)
             return sum(block.iter_rows())
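Both count() and the aggregation reduce the dataset by launching one small task per block and combining the scalars on the driver. A runnable sketch of the same per-block reduction, with plain lists standing in for blocks:

    import ray

    ray.init()

    @ray.remote
    def count(block) -> int:
        return len(block)  # stand-in for BlockAccessor.for_block(...).num_rows()

    blocks = [ray.put(list(range(10))) for _ in range(4)]
    print(sum(ray.get([count.remote(b) for b in blocks])))  # 40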
@@ -1151,7 +1151,7 @@ class Dataset(Generic[T]):
             import pyarrow
             return isinstance(block, pyarrow.Table)

-        blocks: List[ObjectRef[Block[T]]] = list(self._blocks)
+        blocks: List[ObjectRef[Block]] = list(self._blocks)
         is_arrow = ray.get(check_is_arrow.remote(blocks[0]))

         if is_arrow:
@@ -1165,7 +1165,7 @@ class Dataset(Generic[T]):
         return [block_to_df.remote(block) for block in self._blocks]

     @DeveloperAPI
-    def get_blocks(self) -> List[ObjectRef["Block"]]:
+    def get_blocks(self) -> List[ObjectRef[Block]]:
         """Get a list of references to the underlying blocks of this dataset.

         This function can be used for zero-copy access to the data.
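This hunk (and the "* quotes" entry in the commit message) drops the string quoting: once Block is an ordinary module-level alias, annotations can reference it directly, and quoting is only needed before the name exists. A minimal illustration (the functions are hypothetical, not from the commit):

    from typing import List, Union

    Block = Union[list, "pyarrow.Table"]

    def get_blocks_old() -> List["Block"]:  # forward reference, still valid
        return []

    def get_blocks_new() -> List[Block]:    # direct reference, now possible
        return []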
@@ -1202,7 +1202,7 @@ class Dataset(Generic[T]):
     def _block_sizes(self) -> List[int]:
         @ray.remote
-        def query(block: Block[T]) -> int:
+        def query(block: Block) -> int:
             block = BlockAccessor.for_block(block)
             return block.num_rows()

View file

@@ -5,7 +5,8 @@ import numpy as np

 import ray
 from ray.types import ObjectRef
-from ray.experimental.data.block import Block, BlockAccessor, BlockMetadata, T
+from ray.experimental.data.block import Block, BlockAccessor, \
+    BlockMetadata, T
 from ray.experimental.data.impl.arrow_block import ArrowRow
 from ray.util.annotations import PublicAPI
@@ -38,7 +39,7 @@ class Datasource(Generic[T]):
         """
         raise NotImplementedError

-    def prepare_write(self, blocks: List[ObjectRef[Block[T]]],
+    def prepare_write(self, blocks: List[ObjectRef[Block]],
                       metadata: List[BlockMetadata],
                       **write_args) -> List["WriteTask[T]"]:
         """Return the list of tasks needed to perform a write.
@@ -86,7 +87,7 @@ class Datasource(Generic[T]):


 @PublicAPI(stability="beta")
-class ReadTask(Callable[[], Block[T]]):
+class ReadTask(Callable[[], Block]):
     """A function used to read a block of a dataset.

     Read tasks are generated by ``datasource.prepare_read()``, and return
@@ -96,15 +97,14 @@ class ReadTask(Callable[[], Block[T]]):
     Ray will execute read tasks in remote functions to parallelize execution.
     """

-    def __init__(self, read_fn: Callable[[], Block[T]],
-                 metadata: BlockMetadata):
+    def __init__(self, read_fn: Callable[[], Block], metadata: BlockMetadata):
         self._metadata = metadata
         self._read_fn = read_fn

     def get_metadata(self) -> BlockMetadata:
         return self._metadata

-    def __call__(self) -> Block[T]:
+    def __call__(self) -> Block:
         return self._read_fn()
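After the change a ReadTask is simply a zero-argument callable returning a Block, with its metadata computed up front. A runnable sketch of the kind of read function a datasource hands back (pure Python; constructing the ReadTask wrapper itself is not shown, since the class wraps exactly this pair):

    import pyarrow

    def make_read_fn(start: int, count: int):
        # Deferred read: nothing is materialized until the task is called,
        # typically inside a remote function.
        def read_fn() -> pyarrow.Table:
            return pyarrow.Table.from_pydict(
                {"value": list(range(start, start + count))})
        return read_fn

    task = make_read_fn(0, 4)
    print(task().num_rows)  # 4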
@@ -142,7 +142,7 @@ class RangeDatasource(Datasource[Union[ArrowRow, int]]):

         # Example of a read task. In a real datasource, this would pull data
         # from an external system instead of generating dummy data.
-        def make_block(start: int, count: int) -> Block[Union[ArrowRow, int]]:
+        def make_block(start: int, count: int) -> Block:
             if use_arrow:
                 return pyarrow.Table.from_arrays(
                     [np.arange(start, start + count)], names=["value"])
@@ -188,7 +188,7 @@ class DummyOutputDatasource(Datasource[Union[ArrowRow, int]]):
                 self.rows_written = 0
                 self.enabled = True

-            def write(self, block: Block[T]) -> str:
+            def write(self, block: Block) -> str:
                 block = BlockAccessor.for_block(block)
                 if not self.enabled:
                     raise ValueError("disabled")
@@ -205,7 +205,7 @@ class DummyOutputDatasource(Datasource[Union[ArrowRow, int]]):
         self.num_ok = 0
         self.num_failed = 0

-    def prepare_write(self, blocks: List[ObjectRef[Block[T]]],
+    def prepare_write(self, blocks: List[ObjectRef[Block]],
                       metadata: List[BlockMetadata],
                       **write_args) -> List["WriteTask[T]"]:
         tasks = []

View file

@@ -63,12 +63,12 @@ class DelegatingArrowBlockBuilder(BlockBuilder[T]):
                 self._builder = SimpleBlockBuilder()
         self._builder.add(item)

-    def add_block(self, block: Block[T]) -> None:
+    def add_block(self, block: Block) -> None:
         if self._builder is None:
             self._builder = BlockAccessor.for_block(block).builder()
         self._builder.add_block(block)

-    def build(self) -> Block[T]:
+    def build(self) -> Block:
         if self._builder is None:
             self._builder = ArrowBlockBuilder()
         return self._builder.build()
@@ -101,7 +101,7 @@ class ArrowBlockBuilder(BlockBuilder[T]):
         self._tables.append(block)
         self._num_rows += block.num_rows

-    def build(self) -> Block[T]:
+    def build(self) -> Block:
         if self._columns:
             tables = [pyarrow.Table.from_pydict(self._columns)]
         else:

View file

@@ -15,11 +15,11 @@ class BlockBuilder(Generic[T]):
         """Append a single row to the block being built."""
         raise NotImplementedError

-    def add_block(self, block: "Block[T]") -> None:
+    def add_block(self, block: Block) -> None:
         """Append an entire block to the block being built."""
         raise NotImplementedError

-    def build(self) -> "Block[T]":
+    def build(self) -> Block:
         """Build the block."""
         raise NotImplementedError
@@ -35,7 +35,7 @@ class SimpleBlockBuilder(BlockBuilder[T]):
         assert isinstance(block, list), block
         self._items.extend(block)

-    def build(self) -> "Block[T]":
+    def build(self) -> Block:
         return list(self._items)
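The builder interface above is the contract every block producer implements. A runnable stand-in that mirrors SimpleBlockBuilder's list behavior as shown in the diff (the real class may differ in details):

    from typing import Any, List

    class SimpleBlockBuilderSketch:
        def __init__(self) -> None:
            self._items: List[Any] = []

        def add(self, item: Any) -> None:
            """Append a single row to the block being built."""
            self._items.append(item)

        def add_block(self, block: List[Any]) -> None:
            """Append an entire block to the block being built."""
            assert isinstance(block, list), block
            self._items.extend(block)

        def build(self) -> List[Any]:
            """Build the block."""
            return list(self._items)

    b = SimpleBlockBuilderSketch()
    b.add(1)
    b.add_block([2, 3])
    print(b.build())  # [1, 2, 3]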

View file

@@ -1,11 +1,11 @@
 from typing import Iterable, List

 from ray.types import ObjectRef
-from ray.experimental.data.block import Block, BlockMetadata, T
+from ray.experimental.data.block import Block, BlockMetadata


-class BlockList(Iterable[ObjectRef[Block[T]]]):
-    def __init__(self, blocks: List[ObjectRef[Block[T]]],
+class BlockList(Iterable[ObjectRef[Block]]):
+    def __init__(self, blocks: List[ObjectRef[Block]],
                  metadata: List[BlockMetadata]):
         assert len(blocks) == len(metadata), (blocks, metadata)
         self._blocks = blocks

View file

@@ -15,7 +15,7 @@ CallableClass = type

 class ComputeStrategy:
     def apply(self, fn: Any,
-              blocks: Iterable[Block[T]]) -> Iterable[ObjectRef[Block]]:
+              blocks: Iterable[Block]) -> Iterable[ObjectRef[Block]]:
         raise NotImplementedError
@@ -51,7 +51,7 @@ class TaskPool(ComputeStrategy):

 class ActorPool(ComputeStrategy):
     def apply(self, fn: Any, remote_args: dict,
-              blocks: Iterable[Block[T]]) -> Iterable[ObjectRef[Block]]:
+              blocks: Iterable[Block]) -> Iterable[ObjectRef[Block]]:
         map_bar = ProgressBar("Map Progress", total=len(blocks))
@@ -60,8 +60,8 @@ class ActorPool(ComputeStrategy):
                 return "ok"

             @ray.method(num_returns=2)
-            def process_block(self, block: Block[T], meta: BlockMetadata
-                              ) -> (Block[U], BlockMetadata):
+            def process_block(self, block: Block,
+                              meta: BlockMetadata) -> (Block, BlockMetadata):
                 new_block = fn(block)
                 accessor = BlockAccessor.for_block(new_block)
                 new_metadata = BlockMetadata(
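ComputeStrategy.apply maps a block transform across all blocks and returns refs to the outputs; TaskPool does it with tasks, ActorPool with actor methods. A runnable sketch of the task-based strategy (simplified: no progress bar and no metadata handling):

    import ray

    ray.init()

    def task_pool_apply(fn, blocks):
        # One remote task per block; refs come back immediately.
        remote_fn = ray.remote(fn)
        return [remote_fn.remote(block) for block in blocks]

    refs = task_pool_apply(lambda block: [x * 2 for x in block],
                           [[1, 2], [3, 4]])
    print(ray.get(refs))  # [[2, 4], [6, 8]]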

View file

@@ -15,7 +15,7 @@ def simple_shuffle(input_blocks: BlockList[T],
     input_num_blocks = len(input_blocks)

     @ray.remote(num_returns=output_num_blocks)
-    def shuffle_map(block: Block[T]) -> List[Block[T]]:
+    def shuffle_map(block: Block) -> List[Block]:
         block = BlockAccessor.for_block(block)
         slice_sz = max(1, math.ceil(block.num_rows() / output_num_blocks))
         slices = []
@@ -31,8 +31,7 @@ def simple_shuffle(input_blocks: BlockList[T],
         return slices

     @ray.remote(num_returns=2)
-    def shuffle_reduce(
-            *mapper_outputs: List[Block[T]]) -> (Block[T], BlockMetadata):
+    def shuffle_reduce(*mapper_outputs: List[Block]) -> (Block, BlockMetadata):
         builder = DelegatingArrowBlockBuilder()
         assert len(mapper_outputs) == input_num_blocks
         for block in mapper_outputs:
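simple_shuffle is a two-phase all-to-all: each map task slices its block into output_num_blocks pieces, and each reduce task concatenates the i-th slice from every mapper. A runnable sketch of that pattern with plain lists standing in for blocks:

    import math

    import ray

    ray.init()

    output_num_blocks = 2

    @ray.remote(num_returns=output_num_blocks)
    def shuffle_map(block):
        # Split one input block into one slice per output block.
        slice_sz = max(1, math.ceil(len(block) / output_num_blocks))
        return tuple(block[i * slice_sz:(i + 1) * slice_sz]
                     for i in range(output_num_blocks))

    @ray.remote
    def shuffle_reduce(*mapper_outputs):
        # Concatenate the i-th slice from every mapper.
        out = []
        for part in mapper_outputs:
            out.extend(part)
        return out

    input_blocks = [[1, 2, 3, 4], [5, 6, 7, 8]]
    map_refs = [shuffle_map.remote(b) for b in input_blocks]
    reduce_refs = [
        shuffle_reduce.remote(*[refs[i] for refs in map_refs])
        for i in range(output_num_blocks)
    ]
    print(ray.get(reduce_refs))  # [[1, 2, 5, 6], [3, 4, 7, 8]]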

View file

@@ -118,10 +118,10 @@ def read_datasource(datasource: Datasource[T],
     read_tasks = datasource.prepare_read(parallelism, **read_args)

     @ray.remote
-    def remote_read(task: ReadTask) -> Block[T]:
+    def remote_read(task: ReadTask) -> Block:
         return task()

-    calls: List[Callable[[], ObjectRef[Block[T]]]] = []
+    calls: List[Callable[[], ObjectRef[Block]]] = []
     metadata: List[BlockMetadata] = []

     for task in read_tasks:
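read_datasource builds a list of zero-argument callables rather than launching the reads immediately, so the reads start only when each entry is invoked (the deferral is visible in the ``calls`` type; the surrounding scheduling logic is not shown in this hunk). A runnable sketch of that pattern, using default arguments to avoid the lambda late-binding pitfall:

    import ray

    ray.init()

    @ray.remote
    def remote_read(task):
        return task()

    calls = []
    for start in (0, 4):
        read_fn = lambda s=start: list(range(s, s + 4))
        calls.append(lambda t=read_fn: remote_read.remote(t))

    refs = [call() for call in calls]  # reads launch here, not above
    print(ray.get(refs))  # [[0, 1, 2, 3], [4, 5, 6, 7]]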