Revert "[dataset] Use polars for sorting (#24523)" (#24781)

This reverts commit c62e00e.

See if reverting this resolves the linux://python/ray/tests:test_actor_advanced failure.
Chen Shen 2022-05-13 12:09:12 -07:00 committed by GitHub
parent cc21979998
commit 2be45fed5e
9 changed files with 37 additions and 138 deletions

View file

@@ -5,7 +5,6 @@ import os
import ray
from ray.util.annotations import DeveloperAPI
# The context singleton on this process.
_default_context: "Optional[DatasetContext]" = None
_context_lock = threading.Lock()
@@ -38,9 +37,6 @@ DEFAULT_USE_PUSH_BASED_SHUFFLE = bool(
os.environ.get("RAY_DATASET_PUSH_BASED_SHUFFLE", None)
)
# Whether to use Polars for tabular dataset sorts, groupbys, and aggregations.
DEFAULT_USE_POLARS = False
@DeveloperAPI
class DatasetContext:
@@ -61,7 +57,6 @@ class DatasetContext:
optimize_fuse_shuffle_stages: bool,
actor_prefetcher_enabled: bool,
use_push_based_shuffle: bool,
use_polars: bool,
):
"""Private constructor (use get_current() instead)."""
self.block_owner = block_owner
@@ -73,7 +68,6 @@ class DatasetContext:
self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages
self.actor_prefetcher_enabled = actor_prefetcher_enabled
self.use_push_based_shuffle = use_push_based_shuffle
self.use_polars = use_polars
@staticmethod
def get_current() -> "DatasetContext":
@@ -97,7 +91,6 @@ class DatasetContext:
optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
actor_prefetcher_enabled=DEFAULT_ACTOR_PREFETCHER_ENABLED,
use_push_based_shuffle=DEFAULT_USE_PUSH_BASED_SHUFFLE,
use_polars=DEFAULT_USE_POLARS,
)
if (

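For context, a minimal sketch of how the removed use_polars flag was toggled before this revert, based on the DatasetContext fields above and the benchmark-script change at the end of this diff; after the revert, tabular sorts always take the pyarrow path:

from ray.data.context import DatasetContext

# Pre-revert usage sketch: opt in to the polars sort path for tabular datasets.
# After this revert the flag no longer exists and pyarrow is always used.
ctx = DatasetContext.get_current()
ctx.use_polars = True  # required `pip install polars` (see the test requirements change below)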
View file

@@ -32,8 +32,6 @@ from ray.data.block import (
from ray.data.row import TableRow
from ray.data.impl.table_block import TableBlockAccessor, TableBlockBuilder
from ray.data.aggregate import AggregateFn
from ray.data.context import DatasetContext
from ray.data.impl.arrow_ops import transform_polars, transform_pyarrow
if TYPE_CHECKING:
import pandas
@@ -42,21 +40,6 @@ if TYPE_CHECKING:
T = TypeVar("T")
# We offload some transformations to polars for performance.
def get_sort_transform(context: DatasetContext) -> Callable:
if context.use_polars:
return transform_polars.sort
else:
return transform_pyarrow.sort
def get_concat_and_sort_transform(context: DatasetContext) -> Callable:
if context.use_polars:
return transform_polars.concat_and_sort
else:
return transform_pyarrow.concat_and_sort
class ArrowRow(TableRow):
"""
Row of a tabular Dataset backed by an Arrow Table block.
@@ -282,35 +265,45 @@ class ArrowBlockAccessor(TableBlockAccessor):
# so calling sort_indices() will raise an error.
return [self._empty_table() for _ in range(len(boundaries) + 1)]
context = DatasetContext.get_current()
sort = get_sort_transform(context)
col, _ = key[0]
table = sort(self._table, key, descending)
import pyarrow.compute as pac
indices = pac.sort_indices(self._table, sort_keys=key)
table = self._table.take(indices)
if len(boundaries) == 0:
return [table]
partitions = []
# For each boundary value, count the number of items that are less
# than it. Since the block is sorted, these counts partition the items
# such that boundaries[i] <= x < boundaries[i + 1] for each x in
# partition[i]. If `descending` is true, `boundaries` would also be
# in descending order and we only need to count the number of items
# *greater than* the boundary value instead.
if descending:
num_rows = len(table[col])
bounds = num_rows - np.searchsorted(
table[col], boundaries, sorter=np.arange(num_rows - 1, -1, -1)
)
else:
bounds = np.searchsorted(table[col], boundaries)
last_idx = 0
for idx in bounds:
col, _ = key[0]
comp_fn = pac.greater if descending else pac.less
# TODO(ekl) this is O(n^2) but in practice it's much faster than the
# O(n) algorithm, could be optimized.
boundary_indices = [pac.sum(comp_fn(table[col], b)).as_py() for b in boundaries]
### Compute the boundary indices in O(n) time via scan. # noqa
# boundary_indices = []
# remaining = boundaries.copy()
# values = table[col]
# for i, x in enumerate(values):
# while remaining and not comp_fn(x, remaining[0]).as_py():
# remaining.pop(0)
# boundary_indices.append(i)
# for _ in remaining:
# boundary_indices.append(len(values))
ret = []
prev_i = 0
for i in boundary_indices:
# Slices need to be copied to avoid including the base table
# during serialization.
partitions.append(_copy_table(table.slice(last_idx, idx - last_idx)))
last_idx = idx
partitions.append(_copy_table(table.slice(last_idx)))
return partitions
ret.append(_copy_table(table.slice(prev_i, i - prev_i)))
prev_i = i
ret.append(_copy_table(table.slice(prev_i)))
return ret
def combine(self, key: KeyFn, aggs: Tuple[AggregateFn]) -> Block[ArrowRow]:
"""Combine rows with the same key into an accumulator.
@@ -398,10 +391,9 @@ class ArrowBlockAccessor(TableBlockAccessor):
if len(blocks) == 0:
ret = ArrowBlockAccessor._empty_table()
else:
concat_and_sort = get_concat_and_sort_transform(
DatasetContext.get_current()
)
ret = concat_and_sort(blocks, key, _descending)
ret = pyarrow.concat_tables(blocks, promote=True)
indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
ret = ret.take(indices)
return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build())
@staticmethod

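The sort_and_partition hunk above shows two boundary-splitting strategies: the reverted polars-era code computed split points with np.searchsorted on the sorted column, while the restored code counts, for each boundary, how many values compare before it via pyarrow.compute. A standalone toy sketch (illustrative only, not Ray code) showing that both yield the same split points on ascending data:

import numpy as np
import pyarrow as pa
import pyarrow.compute as pac

values = np.array([1, 3, 3, 5, 8])  # a column already sorted ascending
boundaries = [3, 6]

# searchsorted approach (reverted code): split points such that
# partition[i] holds boundaries[i - 1] <= x < boundaries[i].
bounds = np.searchsorted(values, boundaries)  # -> array([1, 4])

# comparator-count approach (restored code): per boundary, count values less than it.
col = pa.array(values)
boundary_indices = [pac.sum(pac.less(col, b)).as_py() for b in boundaries]  # -> [1, 4]

# For descending blocks the comparator flips to pac.greater, and the searchsorted
# variant needs a reversing sorter plus num_rows - result, as in the reverted code.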
View file

@@ -1,40 +0,0 @@
from typing import List, TYPE_CHECKING
try:
import pyarrow
except ImportError:
pyarrow = None
try:
import polars as pl
except ImportError:
pl = None
if TYPE_CHECKING:
from ray.data.impl.sort import SortKeyT
def check_polars_installed():
if pl is None:
raise ImportError(
"polars not installed. Install with `pip install polars` or set "
"`DatasetContext.use_polars = False` to fall back to pyarrow"
)
def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.Table":
check_polars_installed()
col, _ = key[0]
df = pl.from_arrow(table)
return df.sort(col, reverse=descending).to_arrow()
def concat_and_sort(
blocks: List["pyarrow.Table"], key: "SortKeyT", descending: bool
) -> "pyarrow.Table":
check_polars_installed()
col, _ = key[0]
blocks = [pl.from_arrow(block) for block in blocks]
df = pl.concat(blocks).sort(col, reverse=descending)
return df.to_arrow()

View file

@@ -1,24 +0,0 @@
from typing import List, TYPE_CHECKING
try:
import pyarrow
except ImportError:
pyarrow = None
if TYPE_CHECKING:
from ray.data.impl.sort import SortKeyT
def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.Table":
import pyarrow.compute as pac
indices = pac.sort_indices(table, sort_keys=key)
return table.take(indices)
def concat_and_sort(
blocks: List["pyarrow.Table"], key: "SortKeyT", descending: bool
) -> "pyarrow.Table":
ret = pyarrow.concat_tables(blocks, promote=True)
indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
return ret.take(indices)

View file

@@ -81,7 +81,6 @@ class OnesSource(Datasource):
return read_tasks
@pytest.mark.skip(reason="failing after #24523")
@pytest.mark.parametrize("lazy_input", [True, False])
def test_memory_release_pipeline(shutdown_only, lazy_input):
context = DatasetContext.get_current()

View file

@@ -76,17 +76,12 @@ def test_sort_partition_same_key_to_same_block(
@pytest.mark.parametrize("num_items,parallelism", [(100, 1), (1000, 4)])
@pytest.mark.parametrize("use_push_based_shuffle", [False, True])
@pytest.mark.parametrize("use_polars", [False, True])
def test_sort_arrow(
ray_start_regular, num_items, parallelism, use_push_based_shuffle, use_polars
):
def test_sort_arrow(ray_start_regular, num_items, parallelism, use_push_based_shuffle):
ctx = ray.data.context.DatasetContext.get_current()
try:
original_push_based_shuffle = ctx.use_push_based_shuffle
original = ctx.use_push_based_shuffle
ctx.use_push_based_shuffle = use_push_based_shuffle
original_use_polars = ctx.use_polars
ctx.use_polars = use_polars
a = list(reversed(range(num_items)))
b = [f"{x:03}" for x in range(num_items)]
@@ -117,22 +112,16 @@ def test_sort_arrow(
assert_sorted(ds.sort(key="b"), zip(a, b))
assert_sorted(ds.sort(key="a", descending=True), zip(a, b))
finally:
ctx.use_push_based_shuffle = original_push_based_shuffle
ctx.use_polars = original_use_polars
ctx.use_push_based_shuffle = original
@pytest.mark.parametrize("use_push_based_shuffle", [False, True])
@pytest.mark.parametrize("use_polars", [False, True])
def test_sort_arrow_with_empty_blocks(
ray_start_regular, use_push_based_shuffle, use_polars
):
def test_sort_arrow_with_empty_blocks(ray_start_regular, use_push_based_shuffle):
ctx = ray.data.context.DatasetContext.get_current()
try:
original_push_based_shuffle = ctx.use_push_based_shuffle
original = ctx.use_push_based_shuffle
ctx.use_push_based_shuffle = use_push_based_shuffle
original_use_polars = ctx.use_polars
ctx.use_polars = use_polars
assert (
BlockAccessor.for_block(pa.Table.from_pydict({})).sample(10, "A").num_rows
@@ -173,8 +162,7 @@ def test_sort_arrow_with_empty_blocks(
)
assert ds.sort("value").count() == 0
finally:
ctx.use_push_based_shuffle = original_push_based_shuffle
ctx.use_polars = original_use_polars
ctx.use_push_based_shuffle = original
def test_push_based_shuffle_schedule():

View file

@@ -45,8 +45,6 @@ aiorwlock
# Requirements for running tests
pyarrow >= 6.0.1, < 7.0.0
# Used for Dataset tests.
polars
azure-cli-core==2.29.1
azure-identity==1.7.0
azure-mgmt-compute==23.1.0

View file

@@ -10,7 +10,6 @@ from typing import List
from ray.data.impl.arrow_block import ArrowRow
from ray.data.impl.util import _check_pyarrow_version
from ray.data.block import Block, BlockMetadata
from ray.data.context import DatasetContext
from ray.data.datasource import Datasource, ReadTask
from ray.internal.internal_api import memory_summary
@@ -86,15 +85,9 @@ if __name__ == "__main__":
parser.add_argument(
"--shuffle", help="shuffle instead of sort", action="store_true"
)
parser.add_argument("--use-polars", action="store_true")
args = parser.parse_args()
if args.use_polars and not args.shuffle:
print("Using polars for sort")
ctx = DatasetContext.get_current()
ctx.use_polars = True
num_partitions = int(args.num_partitions)
partition_size = int(float(args.partition_size))
print(