From 50e305e7990c11166d93402dc5d49fe86180a71b Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 21 Oct 2021 22:23:50 -0700
Subject: [PATCH] [data] Add take_all() and raise error if to_pandas() drops
 records (#19619)

---
 python/ray/data/dataset.py            | 41 +++++++++++++++++++++------
 python/ray/data/dataset_pipeline.py   |  4 ++-
 python/ray/data/tests/test_dataset.py | 11 +++++--
 3 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 040671d50..9ead94809 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -944,6 +944,26 @@ class Dataset(Generic[T]):
                 break
         return output
 
+    def take_all(self, limit: int = 100000) -> List[T]:
+        """Take all the records in the dataset.
+
+        Time complexity: O(dataset size)
+
+        Args:
+            limit: Raise an error if the size exceeds the specified limit.
+
+        Returns:
+            A list of all the records in the dataset.
+        """
+        output = []
+        for row in self.iter_rows():
+            output.append(row)
+            if len(output) > limit:
+                raise ValueError(
+                    "The dataset has more than the given limit of {} records.".
+                    format(limit))
+        return output
+
     def show(self, limit: int = 20) -> None:
         """Print up to the given number of records from the dataset.
 
@@ -1624,26 +1644,29 @@ class Dataset(Generic[T]):
         return raydp.spark.ray_dataset_to_spark_dataframe(
             spark, self.schema(), self.get_internal_block_refs(), locations)
 
-    def to_pandas(self, limit: int = 1000) -> "pandas.DataFrame":
+    def to_pandas(self, limit: int = 100000) -> "pandas.DataFrame":
         """Convert this dataset into a single Pandas DataFrame.
 
-        This is only supported for datasets convertible to Arrow records. This
-        limits the number of records returned to the provided limit.
+        This is only supported for datasets convertible to Arrow records. An
+        error is raised if the number of records exceeds the provided limit.
+        Note that you can use ``.limit()`` on the dataset beforehand to
+        truncate the dataset manually.
 
-        Time complexity: O(limit)
+        Time complexity: O(dataset size)
 
         Args:
-            limit: The maximum number of records to return.
+            limit: The maximum number of records to return. An error will be
+                raised if the limit is exceeded.
 
         Returns:
             A Pandas DataFrame created from this dataset, containing a limited
             number of records.
         """
         if self.count() > limit:
-            logger.warning(f"Only returning the first {limit} records from "
-                           "to_pandas()")
-        limited_ds = self.limit(limit)
-        blocks = limited_ds.get_internal_block_refs()
+            raise ValueError(
+                "The dataset has more than the given limit of {} records.".
+                format(limit))
+        blocks = self.get_internal_block_refs()
         output = DelegatingArrowBlockBuilder()
         for block in ray.get(blocks):
             output.add_block(block)
diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index ecd31e519..d397f5e07 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -25,7 +25,9 @@ PER_DATASET_OUTPUT_OPS = [
 ]
 
 # Operations that operate over the stream of output batches from the pipeline.
-OUTPUT_ITER_OPS = ["take", "show", "iter_rows", "to_tf", "to_torch"]
+OUTPUT_ITER_OPS = [
+    "take", "take_all", "show", "iter_rows", "to_tf", "to_torch"
+]
 
 
 @PublicAPI(stability="beta")
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index f46decf34..c3e0238c6 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -1138,14 +1138,21 @@ def test_to_pandas(ray_start_regular_shared):
     assert df.equals(dfds)
 
     # Test limit.
-    dfds = ds.to_pandas(limit=3)
-    assert df[:3].equals(dfds)
+    with pytest.raises(ValueError):
+        dfds = ds.to_pandas(limit=3)
 
     # Test limit greater than number of rows.
     dfds = ds.to_pandas(limit=6)
     assert df.equals(dfds)
 
 
+def test_take_all(ray_start_regular_shared):
+    assert ray.data.range(5).take_all() == [0, 1, 2, 3, 4]
+
+    with pytest.raises(ValueError):
+        assert ray.data.range(5).take_all(4)
+
+
 def test_to_pandas_refs(ray_start_regular_shared):
     n = 5
     df = pd.DataFrame({"value": list(range(n))})
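
For reference, a minimal usage sketch of the semantics this patch introduces,
assuming a Ray build that includes this commit: take_all() collects every
record and raises past its safety limit, and to_pandas() now raises instead
of silently truncating.

    import ray

    ray.init()
    ds = ray.data.range(5)

    # take_all() returns every record. Unlike take(), which simply stops
    # once its limit is reached, take_all() raises a ValueError as soon as
    # more than `limit` records have been collected.
    assert ds.take_all() == [0, 1, 2, 3, 4]
    try:
        ds.take_all(4)  # 5 records > limit of 4 -> ValueError
    except ValueError as e:
        print(e)

    # to_pandas() raises rather than dropping records beyond `limit`;
    # truncate explicitly with .limit() if a prefix is all that's needed.
    df = ds.limit(3).to_pandas()  # OK: 3 records, under the default limit
    try:
        ds.to_pandas(limit=3)  # 5 records > limit of 3 -> ValueError
    except ValueError as e:
        print(e)

    ray.shutdown()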