[Datasets] to_torch implementation (#17113)

This commit is contained in:
Amog Kamsetty 2021-07-15 13:02:07 -07:00 committed by GitHub
parent bdaa96bf43
commit 6ff4d1ddb1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 177 additions and 3 deletions

View file

@ -859,18 +859,86 @@ class Dataset(Generic[T]):
if batcher.has_any() and not drop_last:
yield format_batch(batcher.next_batch(), batch_format)
def to_torch(self,
             label_column: str,
             feature_columns: Optional[List[str]] = None,
             label_column_dtype: Optional["torch.dtype"] = None,
             feature_column_dtypes: Optional[List["torch.dtype"]] = None,
             prefetch_blocks: int = 0) -> \
        "torch.utils.data.IterableDataset":
    """Return a Torch IterableDataset over this dataset.

    The returned dataset yields one ``(features, label)`` pair per row,
    where ``features`` is a list holding one tensor per feature column
    and ``label`` is the label tensor. When consumed through a
    ``torch.utils.data.DataLoader``, each batch is a 2-element list: the
    list of feature tensors followed by the label tensor, every tensor
    having shape (N, 1) where N is the ``batch_size`` of the DataLoader.

    Note that you probably want to call ``.split()`` on this dataset if
    there are to be multiple Torch workers consuming the data.

    Time complexity: O(1)

    Args:
        label_column (str): The name of the column used as the label
            (second element of the output pair).
        feature_columns (Optional[List[str]]): The names of the columns
            to use as the features. If None, then use all columns
            except the label column as the features.
        label_column_dtype (Optional[torch.dtype]): The torch dtype to
            use for the label column. If None, then automatically infer
            the dtype.
        feature_column_dtypes (Optional[List[torch.dtype]]): The dtypes
            to use for the feature columns. The len of this list must
            be equal to the number of feature columns. If None,
            then automatically infer the dtype.
        prefetch_blocks (int): The number of blocks to prefetch ahead of
            the current block during the scan.

    Returns:
        A torch IterableDataset.

    Raises:
        ValueError: Immediately, if both ``feature_columns`` and
            ``feature_column_dtypes`` are given with different lengths;
            or lazily while iterating, if ``feature_column_dtypes`` does
            not have one entry per feature column of a batch.
    """
    # Imported lazily so that torch is only required when this method is
    # actually used.
    import torch

    from ray.experimental.data.impl.torch_iterable_dataset import \
        TorchIterableDataset

    if feature_columns and feature_column_dtypes:
        if len(feature_columns) != len(feature_column_dtypes):
            raise ValueError("The lengths of `feature_columns` "
                             f"({len(feature_columns)}) and "
                             f"`feature_column_dtypes` ("
                             f"{len(feature_column_dtypes)}) do not "
                             "match!")

    def make_generator():
        # NOTE(review): the code below relies on each batch behaving like
        # a pandas DataFrame (``pop``, ``.values``, ``.columns``,
        # ``.shape``) -- confirm against ``iter_batches`` defaults.
        for batch in self.iter_batches(prefetch_blocks=prefetch_blocks):
            # Remove the label first so that "all remaining columns" can
            # serve as the features when none are named explicitly.
            label_vals = batch.pop(label_column).values
            label_tensor = torch.as_tensor(
                label_vals, dtype=label_column_dtype)
            label_tensor = label_tensor.view(-1, 1)

            if feature_columns:
                batch = batch[feature_columns]

            if feature_column_dtypes:
                dtypes = feature_column_dtypes
                # Without this check ``zip`` below would silently drop
                # feature columns (or dtypes) on a length mismatch, which
                # can only be detected once the batch columns are known.
                if len(dtypes) != len(batch.columns):
                    raise ValueError(
                        "The number of `feature_column_dtypes` "
                        f"({len(dtypes)}) does not match the number of "
                        f"feature columns ({len(batch.columns)})!")
            else:
                dtypes = [None] * len(batch.columns)

            # One (num_rows, 1) tensor per feature column.
            feature_tensor = []
            for col, dtype in zip(batch.columns, dtypes):
                col_vals = batch[col].values
                t = torch.as_tensor(col_vals, dtype=dtype)
                feature_tensor.append(t.view(-1, 1))

            # Emit row by row; batching is left to the DataLoader.
            num_rows = batch.shape[0]
            for i in range(num_rows):
                features = [tensor[i] for tensor in feature_tensor]
                label = label_tensor[i]
                yield (features, label)

    return TorchIterableDataset(make_generator)
def to_tf(self,
label_column: str,

View file

@ -0,0 +1,21 @@
import torch
from torch.utils.data import IterableDataset
class TorchIterableDataset(IterableDataset):
    """Torch ``IterableDataset`` backed by a generator factory.

    ``generator_func`` is a zero-argument callable that returns a fresh
    iterator; it is invoked once for every pass over the dataset.
    """

    def __init__(self, generator_func):
        self.generator_func = generator_func

    def __iter__(self):
        """Yield the items of a fresh iterator, sharded across workers.

        Outside of a DataLoader worker every item is yielded. Inside a
        worker, only every ``num_workers``-th item starting at the
        worker's own id is yielded.
        """
        stream = self.generator_func()
        info = torch.utils.data.get_worker_info()
        if info is not None:
            # Multiple workers are doing dataloading and each one holds
            # its own copy of the data, so stride through the stream to
            # avoid emitting duplicates.
            from itertools import islice
            stream = islice(stream, info.id, None, info.num_workers)
        yield from stream

View file

@ -11,6 +11,9 @@ import pyarrow as pa
import pyarrow.parquet as pq
import pytest
import tensorflow as tf
import torch
from torch.utils.data import DataLoader
import ray
from ray.util.dask import ray_dask_get
@ -755,6 +758,88 @@ def test_to_tf_feature_columns(ray_start_regular_shared):
assert np.array_equal(df.values, combined_iterations)
def test_to_torch(ray_start_regular_shared):
    """Every epoch over the Torch dataset reproduces all source rows."""
    frames = [
        pd.DataFrame({
            "one": [1, 2, 3],
            "two": [1.0, 2.0, 3.0],
            "label": [1.0, 2.0, 3.0]
        }),
        pd.DataFrame({
            "one": [4, 5, 6],
            "two": [4.0, 5.0, 6.0],
            "label": [4.0, 5.0, 6.0]
        }),
        pd.DataFrame({
            "one": [7, 8],
            "two": [7.0, 8.0],
            "label": [7.0, 8.0]
        }),
    ]
    expected = pd.concat(frames)
    ds = ray.experimental.data.from_pandas(
        [ray.put(frame) for frame in frames])

    loader = DataLoader(ds.to_torch(label_column="label"), batch_size=3)

    # Iterate twice to make sure the dataset can be consumed repeatedly.
    for _ in range(2):
        # Each batch is [list of feature tensors, label tensor]; glue
        # them back together column-wise to rebuild the original rows.
        per_batch = [
            torch.cat((*features, label), axis=1).numpy()
            for features, label in loader
        ]
        seen = np.concatenate(per_batch)
        assert np.array_equal(np.sort(expected.values), np.sort(seen))
def test_to_torch_multiple_workers(ray_start_regular_shared):
    """With two DataLoader workers each row is produced exactly once."""
    frames = [
        pd.DataFrame({
            "one": [1, 2, 3],
            "two": [1.0, 2.0, 3.0],
            "label": [1.0, 2.0, 3.0]
        }),
        pd.DataFrame({
            "one": [4, 5, 6],
            "two": [4.0, 5.0, 6.0],
            "label": [4.0, 5.0, 6.0]
        }),
        pd.DataFrame({
            "one": [7, 8],
            "two": [7.0, 8.0],
            "label": [7.0, 8.0]
        }),
    ]
    expected = pd.concat(frames)
    ds = ray.experimental.data.from_pandas(
        [ray.put(frame) for frame in frames])

    loader = DataLoader(
        ds.to_torch(label_column="label"), batch_size=1, num_workers=2)

    collected = []
    for features, label in loader:
        row = torch.cat((*features, label), axis=1).numpy()
        # Worker interleaving makes the order unpredictable, so only
        # check that the values come from the source data.
        assert np.all(np.isin(row, expected.values))
        collected.append(row)

    # batch_size=1, so the batch count equals the row count when no row
    # is duplicated or dropped.
    assert len(collected) == len(expected.values)
def test_to_torch_feature_columns(ray_start_regular_shared):
    """Only the requested feature column (plus the label) is emitted."""
    frames = [
        pd.DataFrame({
            "one": [1, 2, 3],
            "two": [1.0, 2.0, 3.0],
            "label": [1.0, 2.0, 3.0]
        }),
        pd.DataFrame({
            "one": [4, 5, 6],
            "two": [4.0, 5.0, 6.0],
            "label": [4.0, 5.0, 6.0]
        }),
        pd.DataFrame({
            "one": [7, 8],
            "two": [7.0, 8.0],
            "label": [7.0, 8.0]
        }),
    ]
    # Column "two" is not requested below, so it must not appear.
    expected = pd.concat(frames).drop("two", axis=1)
    ds = ray.experimental.data.from_pandas(
        [ray.put(frame) for frame in frames])

    loader = DataLoader(
        ds.to_torch("label", feature_columns=["one"]), batch_size=3)

    per_batch = [
        torch.cat((*features, label), axis=1).numpy()
        for features, label in loader
    ]
    seen = np.concatenate(per_batch)
    assert np.array_equal(expected.values, seen)
def test_json_read(ray_start_regular_shared, tmp_path):
# Single file.
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})