mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[data] Add take_all() and raise error if to_pandas() drops records (#19619)
This commit is contained in:
parent
59b2f1f3f2
commit
50e305e799
3 changed files with 44 additions and 12 deletions
|
@ -944,6 +944,26 @@ class Dataset(Generic[T]):
|
||||||
break
|
break
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def take_all(self, limit: int = 100000) -> List[T]:
    """Materialize every record of the dataset into a single list.

    Time complexity: O(dataset size)

    Args:
        limit: Upper bound on the number of records. Iteration stops
            with an error as soon as more rows than this have been read.

    Returns:
        A list holding all the records in the dataset.

    Raises:
        ValueError: If the dataset has more than ``limit`` records.
    """
    rows = []
    for seen, record in enumerate(self.iter_rows(), start=1):
        rows.append(record)
        # ``seen`` always equals len(rows), so this is the size check.
        if seen > limit:
            raise ValueError(
                "The dataset has more than the given limit of {} records.".
                format(limit))
    return rows
|
||||||
|
|
||||||
def show(self, limit: int = 20) -> None:
|
def show(self, limit: int = 20) -> None:
|
||||||
"""Print up to the given number of records from the dataset.
|
"""Print up to the given number of records from the dataset.
|
||||||
|
|
||||||
|
@ -1624,16 +1644,19 @@ class Dataset(Generic[T]):
|
||||||
return raydp.spark.ray_dataset_to_spark_dataframe(
|
return raydp.spark.ray_dataset_to_spark_dataframe(
|
||||||
spark, self.schema(), self.get_internal_block_refs(), locations)
|
spark, self.schema(), self.get_internal_block_refs(), locations)
|
||||||
|
|
||||||
def to_pandas(self, limit: int = 1000) -> "pandas.DataFrame":
|
def to_pandas(self, limit: int = 100000) -> "pandas.DataFrame":
|
||||||
"""Convert this dataset into a single Pandas DataFrame.
|
"""Convert this dataset into a single Pandas DataFrame.
|
||||||
|
|
||||||
This is only supported for datasets convertible to Arrow records. This
|
This is only supported for datasets convertible to Arrow records. An
|
||||||
limits the number of records returned to the provided limit.
|
error is raised if the number of records exceeds the provided limit.
|
||||||
|
Note that you can use ``.limit()`` on the dataset beforehand to
|
||||||
|
truncate the dataset manually.
|
||||||
|
|
||||||
Time complexity: O(limit)
|
Time complexity: O(dataset size)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
limit: The maximum number of records to return.
|
limit: The maximum number of records to return. An error will be
|
||||||
|
raised if the limit is exceeded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Pandas DataFrame created from this dataset, containing a limited
|
A Pandas DataFrame created from this dataset, containing a limited
|
||||||
|
@ -1641,10 +1664,10 @@ class Dataset(Generic[T]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.count() > limit:
|
if self.count() > limit:
|
||||||
logger.warning(f"Only returning the first {limit} records from "
|
raise ValueError(
|
||||||
"to_pandas()")
|
"The dataset has more than the given limit of {} records.".
|
||||||
limited_ds = self.limit(limit)
|
format(limit))
|
||||||
blocks = limited_ds.get_internal_block_refs()
|
blocks = self.get_internal_block_refs()
|
||||||
output = DelegatingArrowBlockBuilder()
|
output = DelegatingArrowBlockBuilder()
|
||||||
for block in ray.get(blocks):
|
for block in ray.get(blocks):
|
||||||
output.add_block(block)
|
output.add_block(block)
|
||||||
|
|
|
@ -25,7 +25,9 @@ PER_DATASET_OUTPUT_OPS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations that operate over the stream of output batches from the pipeline.
|
# Operations that operate over the stream of output batches from the pipeline.
|
||||||
OUTPUT_ITER_OPS = ["take", "show", "iter_rows", "to_tf", "to_torch"]
|
OUTPUT_ITER_OPS = [
|
||||||
|
"take", "take_all", "show", "iter_rows", "to_tf", "to_torch"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@PublicAPI(stability="beta")
|
@PublicAPI(stability="beta")
|
||||||
|
|
|
@ -1138,14 +1138,21 @@ def test_to_pandas(ray_start_regular_shared):
|
||||||
assert df.equals(dfds)
|
assert df.equals(dfds)
|
||||||
|
|
||||||
# Test limit.
|
# Test limit.
|
||||||
dfds = ds.to_pandas(limit=3)
|
with pytest.raises(ValueError):
|
||||||
assert df[:3].equals(dfds)
|
dfds = ds.to_pandas(limit=3)
|
||||||
|
|
||||||
# Test limit greater than number of rows.
|
# Test limit greater than number of rows.
|
||||||
dfds = ds.to_pandas(limit=6)
|
dfds = ds.to_pandas(limit=6)
|
||||||
assert df.equals(dfds)
|
assert df.equals(dfds)
|
||||||
|
|
||||||
|
|
||||||
|
def test_take_all(ray_start_regular_shared):
    """Check that take_all() returns every row and enforces its limit."""
    assert ray.data.range(5).take_all() == [0, 1, 2, 3, 4]

    # A limit equal to the dataset size is not exceeded, so no error.
    assert ray.data.range(5).take_all(5) == [0, 1, 2, 3, 4]

    # One row over the limit must raise instead of silently truncating.
    # No assert around the call: it is expected to raise before any value
    # is produced.
    with pytest.raises(ValueError):
        ray.data.range(5).take_all(4)
|
||||||
|
|
||||||
|
|
||||||
def test_to_pandas_refs(ray_start_regular_shared):
|
def test_to_pandas_refs(ray_start_regular_shared):
|
||||||
n = 5
|
n = 5
|
||||||
df = pd.DataFrame({"value": list(range(n))})
|
df = pd.DataFrame({"value": list(range(n))})
|
||||||
|
|
Loading…
Add table
Reference in a new issue