Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[data] Add take_all() and raise error if to_pandas() drops records (#19619)
parent 59b2f1f3f2
commit 50e305e799
3 changed files with 44 additions and 12 deletions
@@ -944,6 +944,26 @@ class Dataset(Generic[T]):
                 break
         return output
 
+    def take_all(self, limit: int = 100000) -> List[T]:
+        """Take all the records in the dataset.
+
+        Time complexity: O(dataset size)
+
+        Args:
+            limit: Raise an error if the size exceeds the specified limit.
+
+        Returns:
+            A list of all the records in the dataset.
+        """
+        output = []
+        for row in self.iter_rows():
+            output.append(row)
+            if len(output) > limit:
+                raise ValueError(
+                    "The dataset has more than the given limit of {} records.".
+                    format(limit))
+        return output
+
     def show(self, limit: int = 20) -> None:
         """Print up to the given number of records from the dataset.
 
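
The new method complements the existing take(): take() silently truncates at its limit, while take_all() returns every record and raises if the dataset is larger than the given limit. A minimal usage sketch (the tiny range dataset and local Ray session below are illustrative only, not part of this change):

import ray

ds = ray.data.range(5)      # small example dataset, for illustration only
print(ds.take(3))           # take() truncates silently: [0, 1, 2]
print(ds.take_all())        # take_all() returns every record: [0, 1, 2, 3, 4]

try:
    ds.take_all(limit=4)    # the dataset has more records than the limit
except ValueError as err:
    print(err)              # "The dataset has more than the given limit of 4 records."
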
@@ -1624,16 +1644,19 @@ class Dataset(Generic[T]):
         return raydp.spark.ray_dataset_to_spark_dataframe(
             spark, self.schema(), self.get_internal_block_refs(), locations)
 
-    def to_pandas(self, limit: int = 1000) -> "pandas.DataFrame":
+    def to_pandas(self, limit: int = 100000) -> "pandas.DataFrame":
         """Convert this dataset into a single Pandas DataFrame.
 
-        This is only supported for datasets convertible to Arrow records. This
-        limits the number of records returned to the provided limit.
+        This is only supported for datasets convertible to Arrow records. An
+        error is raised if the number of records exceeds the provided limit.
+        Note that you can use ``.limit()`` on the dataset beforehand to
+        truncate the dataset manually.
 
-        Time complexity: O(limit)
+        Time complexity: O(dataset size)
 
         Args:
-            limit: The maximum number of records to return.
+            limit: The maximum number of records to return. An error will be
+                raised if the limit is exceeded.
 
         Returns:
             A Pandas DataFrame created from this dataset, containing a limited
@@ -1641,10 +1664,10 @@ class Dataset(Generic[T]):
         """
 
         if self.count() > limit:
-            logger.warning(f"Only returning the first {limit} records from "
-                           "to_pandas()")
-        limited_ds = self.limit(limit)
-        blocks = limited_ds.get_internal_block_refs()
+            raise ValueError(
+                "The dataset has more than the given limit of {} records.".
+                format(limit))
+        blocks = self.get_internal_block_refs()
         output = DelegatingArrowBlockBuilder()
         for block in ray.get(blocks):
             output.add_block(block)
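
From the caller's side, to_pandas() no longer logs a warning and returns a truncated frame; it raises a ValueError and leaves truncation to an explicit .limit() call. A rough sketch of that behavior (range_arrow() is just one convenient way to build an Arrow-convertible dataset and is not part of this diff):

import ray

ds = ray.data.range_arrow(10)   # illustrative Arrow-backed dataset

try:
    ds.to_pandas(limit=5)       # 10 records exceed the limit of 5 -> ValueError
except ValueError as err:
    print(err)

df = ds.limit(5).to_pandas()    # explicit truncation, as the docstring now suggests
print(len(df))                  # 5
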
@@ -25,7 +25,9 @@ PER_DATASET_OUTPUT_OPS = [
 ]
 
 # Operations that operate over the stream of output batches from the pipeline.
-OUTPUT_ITER_OPS = ["take", "show", "iter_rows", "to_tf", "to_torch"]
+OUTPUT_ITER_OPS = [
+    "take", "take_all", "show", "iter_rows", "to_tf", "to_torch"
+]
 
 
 @PublicAPI(stability="beta")
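
Because OUTPUT_ITER_OPS lists the Dataset methods that are re-exposed on DatasetPipeline to operate over the pipeline's output stream, registering "take_all" here makes it callable on pipelines as well. A rough sketch of that usage (building the pipeline via repeat() is an assumption for illustration, not something this diff touches):

import ray

pipe = ray.data.range(3).repeat(2)   # small DatasetPipeline, assumed construction
print(pipe.take_all())               # drains the whole output stream, e.g. [0, 1, 2, 0, 1, 2]
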
@@ -1138,14 +1138,21 @@ def test_to_pandas(ray_start_regular_shared):
     assert df.equals(dfds)
 
     # Test limit.
-    dfds = ds.to_pandas(limit=3)
-    assert df[:3].equals(dfds)
+    with pytest.raises(ValueError):
+        dfds = ds.to_pandas(limit=3)
 
     # Test limit greater than number of rows.
     dfds = ds.to_pandas(limit=6)
     assert df.equals(dfds)
 
 
+def test_take_all(ray_start_regular_shared):
+    assert ray.data.range(5).take_all() == [0, 1, 2, 3, 4]
+
+    with pytest.raises(ValueError):
+        assert ray.data.range(5).take_all(4)
+
+
 def test_to_pandas_refs(ray_start_regular_shared):
     n = 5
     df = pd.DataFrame({"value": list(range(n))})