[data] Add take_all() and raise error if to_pandas() drops records (#19619)

This commit is contained in:
Eric Liang 2021-10-21 22:23:50 -07:00 committed by GitHub
parent 59b2f1f3f2
commit 50e305e799
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 12 deletions

View file

@ -944,6 +944,26 @@ class Dataset(Generic[T]):
                break
        return output
def take_all(self, limit: int = 100000) -> List[T]:
    """Take all the records in the dataset.

    Time complexity: O(dataset size)

    Args:
        limit: Raise an error if the size exceeds the specified limit.

    Returns:
        A list of all the records in the dataset.
    """
    # Stream the rows and bail out as soon as the cap is exceeded, so we
    # never buffer more than ``limit + 1`` records in memory.
    records = []
    for count, record in enumerate(self.iter_rows()):
        if count >= limit:
            raise ValueError(
                "The dataset has more than the given limit of {} records.".
                format(limit))
        records.append(record)
    return records
    def show(self, limit: int = 20) -> None:
        """Print up to the given number of records from the dataset.
@ -1624,16 +1644,19 @@ class Dataset(Generic[T]):
        return raydp.spark.ray_dataset_to_spark_dataframe(
            spark, self.schema(), self.get_internal_block_refs(), locations)
-    def to_pandas(self, limit: int = 1000) -> "pandas.DataFrame":
+    def to_pandas(self, limit: int = 100000) -> "pandas.DataFrame":
         """Convert this dataset into a single Pandas DataFrame.

-        This is only supported for datasets convertible to Arrow records. This
-        limits the number of records returned to the provided limit.
+        This is only supported for datasets convertible to Arrow records. An
+        error is raised if the number of records exceeds the provided limit.
+        Note that you can use ``.limit()`` on the dataset beforehand to
+        truncate the dataset manually.

-        Time complexity: O(limit)
+        Time complexity: O(dataset size)

         Args:
-            limit: The maximum number of records to return.
+            limit: The maximum number of records to return. An error will be
+                raised if the limit is exceeded.

         Returns:
             A Pandas DataFrame created from this dataset, containing a limited
@ -1641,10 +1664,10 @@ class Dataset(Generic[T]):
""" """
if self.count() > limit: if self.count() > limit:
logger.warning(f"Only returning the first {limit} records from " raise ValueError(
"to_pandas()") "The dataset has more than the given limit of {} records.".
limited_ds = self.limit(limit) format(limit))
blocks = limited_ds.get_internal_block_refs() blocks = self.get_internal_block_refs()
output = DelegatingArrowBlockBuilder() output = DelegatingArrowBlockBuilder()
for block in ray.get(blocks): for block in ray.get(blocks):
output.add_block(block) output.add_block(block)

View file

@ -25,7 +25,9 @@ PER_DATASET_OUTPUT_OPS = [
 ]

 # Operations that operate over the stream of output batches from the pipeline.
-OUTPUT_ITER_OPS = ["take", "show", "iter_rows", "to_tf", "to_torch"]
+OUTPUT_ITER_OPS = [
+    "take", "take_all", "show", "iter_rows", "to_tf", "to_torch"
+]

 @PublicAPI(stability="beta")

View file

@ -1138,14 +1138,21 @@ def test_to_pandas(ray_start_regular_shared):
     assert df.equals(dfds)

     # Test limit.
-    dfds = ds.to_pandas(limit=3)
-    assert df[:3].equals(dfds)
+    with pytest.raises(ValueError):
+        dfds = ds.to_pandas(limit=3)

     # Test limit greater than number of rows.
     dfds = ds.to_pandas(limit=6)
     assert df.equals(dfds)
def test_take_all(ray_start_regular_shared):
    # All records come back, in order.
    expected = list(range(5))
    assert ray.data.range(5).take_all() == expected
    # Exceeding the limit raises instead of silently truncating.
    with pytest.raises(ValueError):
        assert ray.data.range(5).take_all(4)
def test_to_pandas_refs(ray_start_regular_shared):
    n = 5
    df = pd.DataFrame({"value": list(range(n))})