This reverts commit c62e00e.

See if reverting this resolves the linux://python/ray/tests:test_actor_advanced failure.
parent cc21979998
commit 2be45fed5e

9 changed files with 37 additions and 138 deletions
@@ -5,7 +5,6 @@ import os
 import ray
 from ray.util.annotations import DeveloperAPI
 
-
 # The context singleton on this process.
 _default_context: "Optional[DatasetContext]" = None
 _context_lock = threading.Lock()
@@ -38,9 +37,6 @@ DEFAULT_USE_PUSH_BASED_SHUFFLE = bool(
     os.environ.get("RAY_DATASET_PUSH_BASED_SHUFFLE", None)
 )
 
-# Whether to use Polars for tabular dataset sorts, groupbys, and aggregations.
-DEFAULT_USE_POLARS = False
-
 
 @DeveloperAPI
 class DatasetContext:
@@ -61,7 +57,6 @@ class DatasetContext:
         optimize_fuse_shuffle_stages: bool,
         actor_prefetcher_enabled: bool,
         use_push_based_shuffle: bool,
-        use_polars: bool,
     ):
         """Private constructor (use get_current() instead)."""
         self.block_owner = block_owner
@@ -73,7 +68,6 @@ class DatasetContext:
         self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages
         self.actor_prefetcher_enabled = actor_prefetcher_enabled
         self.use_push_based_shuffle = use_push_based_shuffle
-        self.use_polars = use_polars
 
     @staticmethod
     def get_current() -> "DatasetContext":
@@ -97,7 +91,6 @@ class DatasetContext:
                    optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
                    actor_prefetcher_enabled=DEFAULT_ACTOR_PREFETCHER_ENABLED,
                    use_push_based_shuffle=DEFAULT_USE_PUSH_BASED_SHUFFLE,
-                    use_polars=DEFAULT_USE_POLARS,
                 )
 
             if (
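For orientation, everything this commit strips from `DatasetContext` was the opt-in switch for the Polars sort path. A minimal sketch of how that switch was flipped on the parent commit, assuming a Ray build from around this change; the toy dataset is purely illustrative:

```python
import pyarrow as pa

import ray
from ray.data.context import DatasetContext

ray.init()

# On the parent commit this enables the polars-backed sort; after this
# revert the flag is gone and the assignment has no effect.
ctx = DatasetContext.get_current()
ctx.use_polars = True

# Any tabular sort issued afterwards on this process would have dispatched
# to the polars transforms instead of the pyarrow ones.
ds = ray.data.from_arrow(pa.table({"value": list(range(1000))}))
print(ds.sort("value", descending=True).take(3))
```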
@@ -32,8 +32,6 @@ from ray.data.block import (
 from ray.data.row import TableRow
 from ray.data.impl.table_block import TableBlockAccessor, TableBlockBuilder
 from ray.data.aggregate import AggregateFn
-from ray.data.context import DatasetContext
-from ray.data.impl.arrow_ops import transform_polars, transform_pyarrow
 
 if TYPE_CHECKING:
     import pandas
@@ -42,21 +40,6 @@ if TYPE_CHECKING:
 T = TypeVar("T")
 
 
-# We offload some transformations to polars for performance.
-def get_sort_transform(context: DatasetContext) -> Callable:
-    if context.use_polars:
-        return transform_polars.sort
-    else:
-        return transform_pyarrow.sort
-
-
-def get_concat_and_sort_transform(context: DatasetContext) -> Callable:
-    if context.use_polars:
-        return transform_polars.concat_and_sort
-    else:
-        return transform_pyarrow.concat_and_sort
-
-
 class ArrowRow(TableRow):
     """
     Row of a tabular Dataset backed by a Arrow Table block.
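The two deleted helpers were a small dispatch layer: the block accessor asked for a sort implementation and got either the polars- or the pyarrow-backed one depending on the context flag. A self-contained sketch of that pattern; the plain `use_polars` bool and the local helper names stand in for `DatasetContext` and the deleted module functions:

```python
from typing import Callable, List, Tuple

import pyarrow as pa
import pyarrow.compute as pac

SortKey = List[Tuple[str, str]]  # e.g. [("a", "ascending")]


def _pyarrow_sort(table: pa.Table, key: SortKey, descending: bool) -> pa.Table:
    # Same idea as the deleted transform_pyarrow.sort: direction is carried
    # inside the key, so the boolean is not consulted here.
    return table.take(pac.sort_indices(table, sort_keys=key))


def _polars_sort(table: pa.Table, key: SortKey, descending: bool) -> pa.Table:
    # Same idea as the deleted transform_polars.sort; requires polars.
    import polars as pl

    col, _ = key[0]
    return pl.from_arrow(table).sort(col, reverse=descending).to_arrow()


def get_sort_transform(use_polars: bool) -> Callable:
    # The deleted helper consulted DatasetContext.use_polars; a plain bool
    # keeps this sketch self-contained.
    return _polars_sort if use_polars else _pyarrow_sort


table = pa.table({"a": [3, 1, 2]})
sort = get_sort_transform(use_polars=False)
print(sort(table, [("a", "ascending")], False)["a"].to_pylist())  # [1, 2, 3]
```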
@@ -282,35 +265,45 @@ class ArrowBlockAccessor(TableBlockAccessor):
             # so calling sort_indices() will raise an error.
             return [self._empty_table() for _ in range(len(boundaries) + 1)]
 
-        context = DatasetContext.get_current()
-        sort = get_sort_transform(context)
-        col, _ = key[0]
-        table = sort(self._table, key, descending)
+        import pyarrow.compute as pac
+
+        indices = pac.sort_indices(self._table, sort_keys=key)
+        table = self._table.take(indices)
         if len(boundaries) == 0:
             return [table]
 
-        partitions = []
         # For each boundary value, count the number of items that are less
         # than it. Since the block is sorted, these counts partition the items
         # such that boundaries[i] <= x < boundaries[i + 1] for each x in
         # partition[i]. If `descending` is true, `boundaries` would also be
         # in descending order and we only need to count the number of items
         # *greater than* the boundary value instead.
-        if descending:
-            num_rows = len(table[col])
-            bounds = num_rows - np.searchsorted(
-                table[col], boundaries, sorter=np.arange(num_rows - 1, -1, -1)
-            )
-        else:
-            bounds = np.searchsorted(table[col], boundaries)
-        last_idx = 0
-        for idx in bounds:
+        col, _ = key[0]
+        comp_fn = pac.greater if descending else pac.less
+
+        # TODO(ekl) this is O(n^2) but in practice it's much faster than the
+        # O(n) algorithm, could be optimized.
+        boundary_indices = [pac.sum(comp_fn(table[col], b)).as_py() for b in boundaries]
+        ### Compute the boundary indices in O(n) time via scan. # noqa
+        # boundary_indices = []
+        # remaining = boundaries.copy()
+        # values = table[col]
+        # for i, x in enumerate(values):
+        #     while remaining and not comp_fn(x, remaining[0]).as_py():
+        #         remaining.pop(0)
+        #     boundary_indices.append(i)
+        # for _ in remaining:
+        #     boundary_indices.append(len(values))
+
+        ret = []
+        prev_i = 0
+        for i in boundary_indices:
             # Slices need to be copied to avoid including the base table
             # during serialization.
-            partitions.append(_copy_table(table.slice(last_idx, idx - last_idx)))
-            last_idx = idx
-        partitions.append(_copy_table(table.slice(last_idx)))
-        return partitions
+            ret.append(_copy_table(table.slice(prev_i, i - prev_i)))
+            prev_i = i
+        ret.append(_copy_table(table.slice(prev_i)))
+        return ret
 
     def combine(self, key: KeyFn, aggs: Tuple[AggregateFn]) -> Block[ArrowRow]:
         """Combine rows with the same key into an accumulator.
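The restored code path sorts the block with pyarrow and then turns the sort boundaries into row offsets by counting, for each boundary value, how many rows fall before it. A small standalone sketch of that counting-and-slicing idea on toy data (no Ray imports; the `pac.sum`/`pac.less` usage mirrors the restored lines):

```python
import pyarrow as pa
import pyarrow.compute as pac

table = pa.table({"a": [5, 1, 4, 2, 3]})
boundaries = [2, 4]  # partition into (-inf, 2), [2, 4), [4, +inf)
descending = False

# Sort the block, as the restored sort_and_partition does.
table = table.take(pac.sort_indices(table, sort_keys=[("a", "ascending")]))

# Count how many rows fall strictly before each boundary. With descending
# sorts the comparison flips to "greater than", exactly as in the diff.
comp_fn = pac.greater if descending else pac.less
boundary_indices = [pac.sum(comp_fn(table["a"], b)).as_py() for b in boundaries]

# Slice the sorted table at those offsets; each slice is one partition.
partitions, prev = [], 0
for i in boundary_indices:
    partitions.append(table.slice(prev, i - prev))
    prev = i
partitions.append(table.slice(prev))

print([p["a"].to_pylist() for p in partitions])  # [[1], [2, 3], [4, 5]]
```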
@@ -398,10 +391,9 @@ class ArrowBlockAccessor(TableBlockAccessor):
         if len(blocks) == 0:
             ret = ArrowBlockAccessor._empty_table()
         else:
-            concat_and_sort = get_concat_and_sort_transform(
-                DatasetContext.get_current()
-            )
-            ret = concat_and_sort(blocks, key, _descending)
+            ret = pyarrow.concat_tables(blocks, promote=True)
+            indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
+            ret = ret.take(indices)
         return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build())
 
     @staticmethod
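The merge step restored here is plain pyarrow: concatenate the already-sorted input blocks (with schema promotion) and re-sort the result. A minimal sketch of the same calls on toy tables; `promote=True` matches the pyarrow 6.x API pinned at the time:

```python
import pyarrow as pa
import pyarrow.compute as pac

blocks = [
    pa.table({"a": [1, 4, 7]}),
    pa.table({"a": [2, 5, 8]}),
    pa.table({"a": [3, 6, 9]}),
]

# Concatenate the sorted input blocks into one table, promoting schemas that
# differ only in nullability or missing columns, then sort the result.
merged = pa.concat_tables(blocks, promote=True)
merged = merged.take(pac.sort_indices(merged, sort_keys=[("a", "ascending")]))
print(merged["a"].to_pylist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```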
@@ -1,40 +0,0 @@
-from typing import List, TYPE_CHECKING
-
-try:
-    import pyarrow
-except ImportError:
-    pyarrow = None
-
-try:
-    import polars as pl
-except ImportError:
-    pl = None
-
-
-if TYPE_CHECKING:
-    from ray.data.impl.sort import SortKeyT
-
-
-def check_polars_installed():
-    if pl is None:
-        raise ImportError(
-            "polars not installed. Install with `pip install polars` or set "
-            "`DatasetContext.use_polars = False` to fall back to pyarrow"
-        )
-
-
-def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.Table":
-    check_polars_installed()
-    col, _ = key[0]
-    df = pl.from_arrow(table)
-    return df.sort(col, reverse=descending).to_arrow()
-
-
-def concat_and_sort(
-    blocks: List["pyarrow.Table"], key: "SortKeyT", descending: bool
-) -> "pyarrow.Table":
-    check_polars_installed()
-    col, _ = key[0]
-    blocks = [pl.from_arrow(block) for block in blocks]
-    df = pl.concat(blocks).sort(col, reverse=descending)
-    return df.to_arrow()
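The deleted module's merge path is easy to reproduce outside Ray: convert each Arrow block to a polars DataFrame, concatenate, sort, and convert back. A hedged sketch on toy blocks; `reverse=` was the polars keyword of that era, later renamed `descending=`:

```python
import polars as pl
import pyarrow as pa

blocks = [pa.table({"a": [3, 6]}), pa.table({"a": [1, 4]}), pa.table({"a": [2, 5]})]

# Arrow -> polars -> concat + sort -> Arrow, mirroring the deleted
# transform_polars.concat_and_sort. `reverse=` matches the polars API at the
# time of this commit; newer releases call the argument `descending=`.
df = pl.concat([pl.from_arrow(b) for b in blocks]).sort("a", reverse=True)
print(df.to_arrow()["a"].to_pylist())  # [6, 5, 4, 3, 2, 1]
```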
@@ -1,24 +0,0 @@
-from typing import List, TYPE_CHECKING
-
-try:
-    import pyarrow
-except ImportError:
-    pyarrow = None
-
-if TYPE_CHECKING:
-    from ray.data.impl.sort import SortKeyT
-
-
-def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.Table":
-    import pyarrow.compute as pac
-
-    indices = pac.sort_indices(table, sort_keys=key)
-    return table.take(indices)
-
-
-def concat_and_sort(
-    blocks: List["pyarrow.Table"], key: "SortKeyT", descending: bool
-) -> "pyarrow.Table":
-    ret = pyarrow.concat_tables(blocks, promote=True)
-    indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
-    return ret.take(indices)
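Worth noting in both deleted modules: the `key` argument arrives here as a list of `(column, order)` pairs, which pyarrow accepts directly as `sort_keys`, so the pyarrow implementation never consults the separate `descending` flag, while the polars one takes only the column name from the key and the direction from the flag. A tiny sketch of that key format on an assumed toy table:

```python
import pyarrow as pa
import pyarrow.compute as pac

# Key in the (column name, "ascending" | "descending") pair form used above.
key = [("a", "descending")]

table = pa.table({"a": [1, 3, 2]})
# pyarrow consumes the key directly, so the sort direction lives inside
# `key` rather than in a separate boolean.
out = table.take(pac.sort_indices(table, sort_keys=key))
print(out["a"].to_pylist())  # [3, 2, 1]
```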
@@ -81,7 +81,6 @@ class OnesSource(Datasource):
         return read_tasks
 
 
-@pytest.mark.skip(reason="failing after #24523")
 @pytest.mark.parametrize("lazy_input", [True, False])
 def test_memory_release_pipeline(shutdown_only, lazy_input):
     context = DatasetContext.get_current()
@@ -76,17 +76,12 @@ def test_sort_partition_same_key_to_same_block(
 
 @pytest.mark.parametrize("num_items,parallelism", [(100, 1), (1000, 4)])
 @pytest.mark.parametrize("use_push_based_shuffle", [False, True])
-@pytest.mark.parametrize("use_polars", [False, True])
-def test_sort_arrow(
-    ray_start_regular, num_items, parallelism, use_push_based_shuffle, use_polars
-):
+def test_sort_arrow(ray_start_regular, num_items, parallelism, use_push_based_shuffle):
     ctx = ray.data.context.DatasetContext.get_current()
 
     try:
-        original_push_based_shuffle = ctx.use_push_based_shuffle
+        original = ctx.use_push_based_shuffle
         ctx.use_push_based_shuffle = use_push_based_shuffle
-        original_use_polars = ctx.use_polars
-        ctx.use_polars = use_polars
 
         a = list(reversed(range(num_items)))
         b = [f"{x:03}" for x in range(num_items)]
@@ -117,22 +112,16 @@ def test_sort_arrow(
         assert_sorted(ds.sort(key="b"), zip(a, b))
         assert_sorted(ds.sort(key="a", descending=True), zip(a, b))
     finally:
-        ctx.use_push_based_shuffle = original_push_based_shuffle
-        ctx.use_polars = original_use_polars
+        ctx.use_push_based_shuffle = original
 
 
 @pytest.mark.parametrize("use_push_based_shuffle", [False, True])
-@pytest.mark.parametrize("use_polars", [False, True])
-def test_sort_arrow_with_empty_blocks(
-    ray_start_regular, use_push_based_shuffle, use_polars
-):
+def test_sort_arrow_with_empty_blocks(ray_start_regular, use_push_based_shuffle):
     ctx = ray.data.context.DatasetContext.get_current()
 
     try:
-        original_push_based_shuffle = ctx.use_push_based_shuffle
+        original = ctx.use_push_based_shuffle
         ctx.use_push_based_shuffle = use_push_based_shuffle
-        original_use_polars = ctx.use_polars
-        ctx.use_polars = use_polars
 
         assert (
             BlockAccessor.for_block(pa.Table.from_pydict({})).sample(10, "A").num_rows
@@ -173,8 +162,7 @@ def test_sort_arrow_with_empty_blocks(
         )
         assert ds.sort("value").count() == 0
     finally:
-        ctx.use_push_based_shuffle = original_push_based_shuffle
-        ctx.use_polars = original_use_polars
+        ctx.use_push_based_shuffle = original
 
 
 def test_push_based_shuffle_schedule():
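These tests keep the save-and-restore pattern around the process-wide context flag: read the current value, override it for the parametrized run, and put it back in `finally` so one test's setting cannot leak into the next. A generic sketch of that pattern with a hypothetical `FakeContext` and `flag_value` parameter, no Ray dependency:

```python
import pytest


class FakeContext:
    # Stand-in for the process-wide DatasetContext singleton.
    use_push_based_shuffle = False


@pytest.mark.parametrize("flag_value", [False, True])
def test_with_context_flag(flag_value):
    ctx = FakeContext
    try:
        original = ctx.use_push_based_shuffle    # save the global setting
        ctx.use_push_based_shuffle = flag_value  # override for this run
        assert ctx.use_push_based_shuffle is flag_value
    finally:
        ctx.use_push_based_shuffle = original    # restore, even on failure
```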
@@ -45,8 +45,6 @@ aiorwlock
 
 # Requirements for running tests
 pyarrow >= 6.0.1, < 7.0.0
-# Used for Dataset tests.
-polars
 azure-cli-core==2.29.1
 azure-identity==1.7.0
 azure-mgmt-compute==23.1.0
@@ -10,7 +10,6 @@ from typing import List
 from ray.data.impl.arrow_block import ArrowRow
 from ray.data.impl.util import _check_pyarrow_version
 from ray.data.block import Block, BlockMetadata
-from ray.data.context import DatasetContext
 
 from ray.data.datasource import Datasource, ReadTask
 from ray.internal.internal_api import memory_summary
@@ -86,15 +85,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--shuffle", help="shuffle instead of sort", action="store_true"
     )
-    parser.add_argument("--use-polars", action="store_true")
 
     args = parser.parse_args()
 
-    if args.use_polars and not args.shuffle:
-        print("Using polars for sort")
-        ctx = DatasetContext.get_current()
-        ctx.use_polars = True
-
     num_partitions = int(args.num_partitions)
     partition_size = int(float(args.partition_size))
     print(