[Datasets] Support drop_columns API (#26200)

Cheng Su 2022-07-03 14:41:54 -07:00 committed by GitHub
parent 7360452d2a
commit 11a24d6ef1
4 changed files with 70 additions and 1 deletion


@@ -532,6 +532,42 @@ class Dataset(Generic[T]):
            process_batch, batch_format="pandas", compute=compute, **ray_remote_args
        )

    def drop_columns(
        self,
        cols: List[str],
        *,
        compute: Optional[str] = None,
        **ray_remote_args,
    ) -> "Dataset[U]":
        """Drop one or more columns from the dataset.

        This is a blocking operation.

        Examples:
            >>> import ray
            >>> ds = ray.data.range_table(100) # doctest: +SKIP
            >>> # Add a new column equal to value * 2.
            >>> ds = ds.add_column( # doctest: +SKIP
            ...     "new_col", lambda df: df["value"] * 2)
            >>> # Drop the existing "value" column.
            >>> ds = ds.drop_columns(["value"]) # doctest: +SKIP

        Time complexity: O(dataset size / parallelism)

        Args:
            cols: Names of the columns to drop. If any name does not exist,
                an exception will be raised.
            compute: The compute strategy, either "tasks" (default) to use Ray
                tasks, or ActorPoolStrategy(min, max) to use an autoscaling
                actor pool.
            ray_remote_args: Additional resource requirements to request from
                ray (e.g., num_gpus=1 to request GPUs for the map tasks).
        """
        return self.map_batches(
            lambda batch: batch.drop(columns=cols), compute=compute, **ray_remote_args
        )

    def flat_map(
        self,
        fn: RowUDF,
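For context, a minimal end-to-end usage sketch of the new API, mirroring the docstring example above, plus the roughly equivalent hand-written map_batches call that drop_columns reduces to (the "new_col" name and the explicit batch_format="pandas" are illustrative choices, not part of this change):

import ray

# Small tabular dataset with a single "value" column.
ds = ray.data.range_table(10)

# Add a derived column, then drop the original one via the new API.
ds = ds.add_column("new_col", lambda df: df["value"] * 2)
ds = ds.drop_columns(["value"])
print(ds.take(2))  # two rows containing only the derived column

# Roughly what drop_columns does under the hood, written out by hand:
ds2 = ray.data.range_table(10).add_column("new_col", lambda df: df["value"] * 2)
ds2 = ds2.map_batches(
    lambda batch: batch.drop(columns=["value"]), batch_format="pandas"
)
assert ds2.take(2) == ds.take(2)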


@@ -37,7 +37,14 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

# Operations that can be naively applied per dataset row in the pipeline.
_PER_DATASET_OPS = ["map", "map_batches", "add_column", "flat_map", "filter"]
_PER_DATASET_OPS = [
    "map",
    "map_batches",
    "add_column",
    "drop_columns",
    "flat_map",
    "filter",
]

# Operations that apply to each dataset holistically in the pipeline.
_HOLISTIC_PER_DATASET_OPS = [
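Adding "drop_columns" to _PER_DATASET_OPS is what exposes the method on DatasetPipeline, where it is applied to each window independently. A rough sketch of the behavior this buys, assuming DatasetPipeline.foreach_window is available; the actual method generation in dataset_pipeline.py is more involved than this:

import ray

pipe = ray.data.range_table(8).repeat(2)

# The generated per-dataset op...
dropped = pipe.drop_columns(["value"])

# ...behaves like applying the Dataset-level method to every window.
same_thing = ray.data.range_table(8).repeat(2).foreach_window(
    lambda ds: ds.drop_columns(["value"])
)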


@@ -1739,6 +1739,25 @@ def test_add_column(ray_start_regular_shared):
        ds = ray.data.range(5).add_column("value", 0)


def test_drop_columns(ray_start_regular_shared, tmp_path):
    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [2, 3, 4], "col3": [3, 4, 5]})
    ds1 = ray.data.from_pandas(df)
    ds1.write_parquet(str(tmp_path))
    ds2 = ray.data.read_parquet(str(tmp_path))

    for ds in [ds1, ds2]:
        assert ds.drop_columns(["col2"]).take(1) == [{"col1": 1, "col3": 3}]
        assert ds.drop_columns(["col1", "col3"]).take(1) == [{"col2": 2}]
        assert ds.drop_columns([]).take(1) == [{"col1": 1, "col2": 2, "col3": 3}]
        assert ds.drop_columns(["col1", "col2", "col3"]).take(1) == [{}]
        assert ds.drop_columns(["col1", "col1", "col2", "col1"]).take(1) == [
            {"col3": 3}
        ]
        # Test dropping non-existent column
        with pytest.raises(KeyError):
            ds.drop_columns(["dummy_col", "col1", "col2"])


def test_map_batches_basic(ray_start_regular_shared, tmp_path):
    # Test input validation
    ds = ray.data.range(5)
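The KeyError expectation in test_drop_columns comes straight from pandas: DataFrame.drop raises KeyError for labels that are not present (its default is errors="raise"), and drop_columns simply forwards the column list to each batch's drop call. A standalone illustration of that underlying behavior:

import pandas as pd
import pytest

df = pd.DataFrame({"col1": [1, 2, 3], "col2": [2, 3, 4]})

# Dropping an existing column succeeds.
assert list(df.drop(columns=["col2"]).columns) == ["col1"]

# Dropping a missing column raises KeyError under the default errors="raise".
with pytest.raises(KeyError):
    df.drop(columns=["dummy_col"])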


@@ -507,6 +507,13 @@ def test_preserve_whether_base_datasets_can_be_cleared(ray_start_regular_shared)
    assert not p2._base_datasets_can_be_cleared


def test_drop_columns(ray_start_regular_shared):
    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [2, 3, 4], "col3": [3, 4, 5]})
    ds = ray.data.from_pandas(df)
    pipe = ds.repeat()
    assert pipe.drop_columns(["col2"]).take(1) == [{"col1": 1, "col3": 3}]


if __name__ == "__main__":
    import sys