[data] Add take_all() and raise error if to_pandas() drops records (#19619)

Eric Liang 2021-10-21 22:23:50 -07:00 committed by GitHub
parent 59b2f1f3f2
commit 50e305e799
3 changed files with 44 additions and 12 deletions

@@ -944,6 +944,26 @@ class Dataset(Generic[T]):
                 break
         return output
 
+    def take_all(self, limit: int = 100000) -> List[T]:
+        """Take all the records in the dataset.
+
+        Time complexity: O(dataset size)
+
+        Args:
+            limit: Raise an error if the size exceeds the specified limit.
+
+        Returns:
+            A list of all the records in the dataset.
+        """
+        output = []
+        for row in self.iter_rows():
+            output.append(row)
+            if len(output) > limit:
+                raise ValueError(
+                    "The dataset has more than the given limit of {} records.".
+                    format(limit))
+        return output
+
     def show(self, limit: int = 20) -> None:
         """Print up to the given number of records from the dataset.
@@ -1624,16 +1644,19 @@ class Dataset(Generic[T]):
         return raydp.spark.ray_dataset_to_spark_dataframe(
             spark, self.schema(), self.get_internal_block_refs(), locations)
 
-    def to_pandas(self, limit: int = 1000) -> "pandas.DataFrame":
+    def to_pandas(self, limit: int = 100000) -> "pandas.DataFrame":
         """Convert this dataset into a single Pandas DataFrame.
 
-        This is only supported for datasets convertible to Arrow records. This
-        limits the number of records returned to the provided limit.
+        This is only supported for datasets convertible to Arrow records. An
+        error is raised if the number of records exceeds the provided limit.
+        Note that you can use ``.limit()`` on the dataset beforehand to
+        truncate the dataset manually.
 
-        Time complexity: O(limit)
+        Time complexity: O(dataset size)
 
         Args:
-            limit: The maximum number of records to return.
+            limit: The maximum number of records to return. An error will be
+                raised if the limit is exceeded.
 
         Returns:
             A Pandas DataFrame created from this dataset, containing a limited
@@ -1641,10 +1664,10 @@ class Dataset(Generic[T]):
         """
         if self.count() > limit:
-            logger.warning(f"Only returning the first {limit} records from "
-                           "to_pandas()")
-        limited_ds = self.limit(limit)
-        blocks = limited_ds.get_internal_block_refs()
+            raise ValueError(
+                "The dataset has more than the given limit of {} records.".
+                format(limit))
+        blocks = self.get_internal_block_refs()
         output = DelegatingArrowBlockBuilder()
         for block in ray.get(blocks):
             output.add_block(block)
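
The behavioral change above is easiest to see side by side. A short sketch, assuming from_pandas() accepts a DataFrame directly, as it does in Ray releases of this era:

    import pandas as pd
    import ray

    df = pd.DataFrame({"value": list(range(6))})
    ds = ray.data.from_pandas(df)

    # Limit at or above the row count: the full DataFrame comes back.
    assert df.equals(ds.to_pandas(limit=6))

    # Limit below the row count: previously a warning plus silent truncation,
    # now a ValueError. Truncate explicitly if only a prefix is wanted.
    try:
        ds.to_pandas(limit=3)
    except ValueError:
        head = ds.limit(3).to_pandas()
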

@@ -25,7 +25,9 @@ PER_DATASET_OUTPUT_OPS = [
 ]
 
 # Operations that operate over the stream of output batches from the pipeline.
-OUTPUT_ITER_OPS = ["take", "show", "iter_rows", "to_tf", "to_torch"]
+OUTPUT_ITER_OPS = [
+    "take", "take_all", "show", "iter_rows", "to_tf", "to_torch"
+]
 
 @PublicAPI(stability="beta")
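
Registering "take_all" in OUTPUT_ITER_OPS also exposes it on DatasetPipeline, where it consumes the pipeline's output stream. A hedged sketch, assuming the window() API available at the time:

    import ray

    # take_all() on a pipeline drains every output window into one list.
    pipe = ray.data.range(10).window(blocks_per_window=2)
    assert pipe.take_all() == list(range(10))
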

@@ -1138,14 +1138,21 @@ def test_to_pandas(ray_start_regular_shared):
     assert df.equals(dfds)
 
     # Test limit.
-    dfds = ds.to_pandas(limit=3)
-    assert df[:3].equals(dfds)
+    with pytest.raises(ValueError):
+        dfds = ds.to_pandas(limit=3)
 
     # Test limit greater than number of rows.
     dfds = ds.to_pandas(limit=6)
     assert df.equals(dfds)
 
+
+def test_take_all(ray_start_regular_shared):
+    assert ray.data.range(5).take_all() == [0, 1, 2, 3, 4]
+
+    with pytest.raises(ValueError):
+        assert ray.data.range(5).take_all(4)
+
 
 def test_to_pandas_refs(ray_start_regular_shared):
     n = 5
     df = pd.DataFrame({"value": list(range(n))})