diff --git a/python/ray/data/context.py b/python/ray/data/context.py
index 5ab6c3968..2eb33a9f6 100644
--- a/python/ray/data/context.py
+++ b/python/ray/data/context.py
@@ -5,7 +5,6 @@
 import os
 import ray
 from ray.util.annotations import DeveloperAPI
-
 # The context singleton on this process.
 _default_context: "Optional[DatasetContext]" = None
 _context_lock = threading.Lock()
@@ -38,9 +37,6 @@
 DEFAULT_USE_PUSH_BASED_SHUFFLE = bool(
     os.environ.get("RAY_DATASET_PUSH_BASED_SHUFFLE", None)
 )
-# Whether to use Polars for tabular dataset sorts, groupbys, and aggregations.
-DEFAULT_USE_POLARS = False
-
 
 @DeveloperAPI
 class DatasetContext:
@@ -61,7 +57,6 @@ class DatasetContext:
         optimize_fuse_shuffle_stages: bool,
         actor_prefetcher_enabled: bool,
         use_push_based_shuffle: bool,
-        use_polars: bool,
     ):
         """Private constructor (use get_current() instead)."""
         self.block_owner = block_owner
@@ -73,7 +68,6 @@ class DatasetContext:
         self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages
         self.actor_prefetcher_enabled = actor_prefetcher_enabled
         self.use_push_based_shuffle = use_push_based_shuffle
-        self.use_polars = use_polars
 
     @staticmethod
     def get_current() -> "DatasetContext":
@@ -97,7 +91,6 @@ class DatasetContext:
                     optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
                     actor_prefetcher_enabled=DEFAULT_ACTOR_PREFETCHER_ENABLED,
                     use_push_based_shuffle=DEFAULT_USE_PUSH_BASED_SHUFFLE,
-                    use_polars=DEFAULT_USE_POLARS,
                 )
 
             if (
diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index a81fb39db..76b986bcc 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -32,8 +32,6 @@ from ray.data.block import (
 from ray.data.row import TableRow
 from ray.data.impl.table_block import TableBlockAccessor, TableBlockBuilder
 from ray.data.aggregate import AggregateFn
-from ray.data.context import DatasetContext
-from ray.data.impl.arrow_ops import transform_polars, transform_pyarrow
 
 if TYPE_CHECKING:
     import pandas
@@ -42,21 +40,6 @@
 T = TypeVar("T")
 
 
-# We offload some transformations to polars for performance.
-def get_sort_transform(context: DatasetContext) -> Callable:
-    if context.use_polars:
-        return transform_polars.sort
-    else:
-        return transform_pyarrow.sort
-
-
-def get_concat_and_sort_transform(context: DatasetContext) -> Callable:
-    if context.use_polars:
-        return transform_polars.concat_and_sort
-    else:
-        return transform_pyarrow.concat_and_sort
-
-
 class ArrowRow(TableRow):
     """
     Row of a tabular Dataset backed by a Arrow Table block.
@@ -282,35 +265,45 @@
             # so calling sort_indices() will raise an error.
             return [self._empty_table() for _ in range(len(boundaries) + 1)]
 
-        context = DatasetContext.get_current()
-        sort = get_sort_transform(context)
-        col, _ = key[0]
-        table = sort(self._table, key, descending)
+        import pyarrow.compute as pac
+
+        indices = pac.sort_indices(self._table, sort_keys=key)
+        table = self._table.take(indices)
         if len(boundaries) == 0:
             return [table]
 
-        partitions = []
         # For each boundary value, count the number of items that are less
         # than it. Since the block is sorted, these counts partition the items
         # such that boundaries[i] <= x < boundaries[i + 1] for each x in
         # partition[i]. If `descending` is true, `boundaries` would also be
         # in descending order and we only need to count the number of items
         # *greater than* the boundary value instead.
-        if descending:
-            num_rows = len(table[col])
-            bounds = num_rows - np.searchsorted(
-                table[col], boundaries, sorter=np.arange(num_rows - 1, -1, -1)
-            )
-        else:
-            bounds = np.searchsorted(table[col], boundaries)
-        last_idx = 0
-        for idx in bounds:
+        col, _ = key[0]
+        comp_fn = pac.greater if descending else pac.less
+
+        # TODO(ekl) this is O(n^2) but in practice it's much faster than the
+        # O(n) algorithm, could be optimized.
+        boundary_indices = [pac.sum(comp_fn(table[col], b)).as_py() for b in boundaries]
+        ### Compute the boundary indices in O(n) time via scan.  # noqa
+        # boundary_indices = []
+        # remaining = boundaries.copy()
+        # values = table[col]
+        # for i, x in enumerate(values):
+        #     while remaining and not comp_fn(x, remaining[0]).as_py():
+        #         remaining.pop(0)
+        #         boundary_indices.append(i)
+        # for _ in remaining:
+        #     boundary_indices.append(len(values))
+
+        ret = []
+        prev_i = 0
+        for i in boundary_indices:
             # Slices need to be copied to avoid including the base table
             # during serialization.
-            partitions.append(_copy_table(table.slice(last_idx, idx - last_idx)))
-            last_idx = idx
-        partitions.append(_copy_table(table.slice(last_idx)))
-        return partitions
+            ret.append(_copy_table(table.slice(prev_i, i - prev_i)))
+            prev_i = i
+        ret.append(_copy_table(table.slice(prev_i)))
+        return ret
 
     def combine(self, key: KeyFn, aggs: Tuple[AggregateFn]) -> Block[ArrowRow]:
         """Combine rows with the same key into an accumulator.
@@ -398,10 +391,9 @@ class ArrowBlockAccessor(TableBlockAccessor):
         if len(blocks) == 0:
             ret = ArrowBlockAccessor._empty_table()
         else:
-            concat_and_sort = get_concat_and_sort_transform(
-                DatasetContext.get_current()
-            )
-            ret = concat_and_sort(blocks, key, _descending)
+            ret = pyarrow.concat_tables(blocks, promote=True)
+            indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
+            ret = ret.take(indices)
         return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build())
 
     @staticmethod
diff --git a/python/ray/data/impl/arrow_ops/__init__.py b/python/ray/data/impl/arrow_ops/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/python/ray/data/impl/arrow_ops/transform_polars.py b/python/ray/data/impl/arrow_ops/transform_polars.py
deleted file mode 100644
index 32213e8e4..000000000
--- a/python/ray/data/impl/arrow_ops/transform_polars.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from typing import List, TYPE_CHECKING
-
-try:
-    import pyarrow
-except ImportError:
-    pyarrow = None
-
-try:
-    import polars as pl
-except ImportError:
-    pl = None
-
-
-if TYPE_CHECKING:
-    from ray.data.impl.sort import SortKeyT
-
-
-def check_polars_installed():
-    if pl is None:
-        raise ImportError(
-            "polars not installed. Install with `pip install polars` or set "
-            "`DatasetContext.use_polars = False` to fall back to pyarrow"
-        )
-
-
-def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.Table":
-    check_polars_installed()
-    col, _ = key[0]
-    df = pl.from_arrow(table)
-    return df.sort(col, reverse=descending).to_arrow()
-
-
-def concat_and_sort(
-    blocks: List["pyarrow.Table"], key: "SortKeyT", descending: bool
-) -> "pyarrow.Table":
-    check_polars_installed()
-    col, _ = key[0]
-    blocks = [pl.from_arrow(block) for block in blocks]
-    df = pl.concat(blocks).sort(col, reverse=descending)
-    return df.to_arrow()
diff --git a/python/ray/data/impl/arrow_ops/transform_pyarrow.py b/python/ray/data/impl/arrow_ops/transform_pyarrow.py
deleted file mode 100644
index f4314eadc..000000000
--- a/python/ray/data/impl/arrow_ops/transform_pyarrow.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import List, TYPE_CHECKING
-
-try:
-    import pyarrow
-except ImportError:
-    pyarrow = None
-
-if TYPE_CHECKING:
-    from ray.data.impl.sort import SortKeyT
-
-
-def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.Table":
-    import pyarrow.compute as pac
-
-    indices = pac.sort_indices(table, sort_keys=key)
-    return table.take(indices)
-
-
-def concat_and_sort(
-    blocks: List["pyarrow.Table"], key: "SortKeyT", descending: bool
-) -> "pyarrow.Table":
-    ret = pyarrow.concat_tables(blocks, promote=True)
-    indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
-    return ret.take(indices)
diff --git a/python/ray/data/tests/test_optimize.py b/python/ray/data/tests/test_optimize.py
index a759c49f5..9bb31a4fc 100644
--- a/python/ray/data/tests/test_optimize.py
+++ b/python/ray/data/tests/test_optimize.py
@@ -81,7 +81,6 @@ class OnesSource(Datasource):
         return read_tasks
 
 
-@pytest.mark.skip(reason="failing after #24523")
 @pytest.mark.parametrize("lazy_input", [True, False])
 def test_memory_release_pipeline(shutdown_only, lazy_input):
     context = DatasetContext.get_current()
diff --git a/python/ray/data/tests/test_sort.py b/python/ray/data/tests/test_sort.py
index eab5773d5..d61486651 100644
--- a/python/ray/data/tests/test_sort.py
+++ b/python/ray/data/tests/test_sort.py
@@ -76,17 +76,12 @@ def test_sort_partition_same_key_to_same_block(
 
 @pytest.mark.parametrize("num_items,parallelism", [(100, 1), (1000, 4)])
 @pytest.mark.parametrize("use_push_based_shuffle", [False, True])
-@pytest.mark.parametrize("use_polars", [False, True])
-def test_sort_arrow(
-    ray_start_regular, num_items, parallelism, use_push_based_shuffle, use_polars
-):
+def test_sort_arrow(ray_start_regular, num_items, parallelism, use_push_based_shuffle):
     ctx = ray.data.context.DatasetContext.get_current()
 
     try:
-        original_push_based_shuffle = ctx.use_push_based_shuffle
+        original = ctx.use_push_based_shuffle
         ctx.use_push_based_shuffle = use_push_based_shuffle
-        original_use_polars = ctx.use_polars
-        ctx.use_polars = use_polars
 
         a = list(reversed(range(num_items)))
         b = [f"{x:03}" for x in range(num_items)]
@@ -117,22 +112,16 @@
         assert_sorted(ds.sort(key="b"), zip(a, b))
         assert_sorted(ds.sort(key="a", descending=True), zip(a, b))
     finally:
-        ctx.use_push_based_shuffle = original_push_based_shuffle
-        ctx.use_polars = original_use_polars
+        ctx.use_push_based_shuffle = original
 
 
 @pytest.mark.parametrize("use_push_based_shuffle", [False, True])
-@pytest.mark.parametrize("use_polars", [False, True])
-def test_sort_arrow_with_empty_blocks(
-    ray_start_regular, use_push_based_shuffle, use_polars
-):
+def test_sort_arrow_with_empty_blocks(ray_start_regular, use_push_based_shuffle):
     ctx = ray.data.context.DatasetContext.get_current()
 
     try:
-        original_push_based_shuffle = ctx.use_push_based_shuffle
+        original = ctx.use_push_based_shuffle
         ctx.use_push_based_shuffle = use_push_based_shuffle
-        original_use_polars = ctx.use_polars
-        ctx.use_polars = use_polars
 
         assert (
             BlockAccessor.for_block(pa.Table.from_pydict({})).sample(10, "A").num_rows
@@ -173,8 +162,7 @@
         )
         assert ds.sort("value").count() == 0
     finally:
-        ctx.use_push_based_shuffle = original_push_based_shuffle
-        ctx.use_polars = original_use_polars
+        ctx.use_push_based_shuffle = original
 
 
 def test_push_based_shuffle_schedule():
diff --git a/python/requirements.txt b/python/requirements.txt
index 61c6f7df5..efe903abe 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -45,8 +45,6 @@ aiorwlock
 
 # Requirements for running tests
 pyarrow >= 6.0.1, < 7.0.0
-# Used for Dataset tests.
-polars
 azure-cli-core==2.29.1
 azure-identity==1.7.0
 azure-mgmt-compute==23.1.0
diff --git a/release/nightly_tests/dataset/sort.py b/release/nightly_tests/dataset/sort.py
index 3b12401b6..4dfd460ff 100644
--- a/release/nightly_tests/dataset/sort.py
+++ b/release/nightly_tests/dataset/sort.py
@@ -10,7 +10,6 @@ from typing import List
 from ray.data.impl.arrow_block import ArrowRow
 from ray.data.impl.util import _check_pyarrow_version
 from ray.data.block import Block, BlockMetadata
-from ray.data.context import DatasetContext
 from ray.data.datasource import Datasource, ReadTask
 from ray.internal.internal_api import memory_summary
 
@@ -86,15 +85,9 @@
     parser.add_argument(
         "--shuffle", help="shuffle instead of sort", action="store_true"
    )
-    parser.add_argument("--use-polars", action="store_true")
 
     args = parser.parse_args()
 
-    if args.use_polars and not args.shuffle:
-        print("Using polars for sort")
-        ctx = DatasetContext.get_current()
-        ctx.use_polars = True
-
     num_partitions = int(args.num_partitions)
     partition_size = int(float(args.partition_size))
     print(