mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Datasets] to_torch
implementation (#17113)
This commit is contained in:
parent
bdaa96bf43
commit
6ff4d1ddb1
3 changed files with 177 additions and 3 deletions
|
@ -859,18 +859,86 @@ class Dataset(Generic[T]):
|
|||
if batcher.has_any() and not drop_last:
|
||||
yield format_batch(batcher.next_batch(), batch_format)
|
||||
|
||||
def to_torch(self, **todo) -> "torch.utils.data.IterableDataset":
|
||||
"""Return a Torch data iterator over this dataset.
|
||||
def to_torch(self,
|
||||
label_column: str,
|
||||
feature_columns: Optional[List[str]] = None,
|
||||
label_column_dtype: Optional["torch.dtype"] = None,
|
||||
feature_column_dtypes: Optional[List["torch.dtype"]] = None,
|
||||
prefetch_blocks: int = 0) -> \
|
||||
"torch.utils.data.IterableDataset":
|
||||
"""Return a Torch IterableDataset over this dataset.
|
||||
|
||||
Each element in IterableDataset will be a list consisting of 2
|
||||
elements. The first item is a list of the feature tensors. The
|
||||
second item is the label tensor. Each tensor will be of shape (N,
|
||||
1), where N is the ``batch_size`` used by the DataLoader.
|
||||
|
||||
Note that you probably want to call ``.split()`` on this dataset if
|
||||
there are to be multiple Torch workers consuming the data.
|
||||
|
||||
Time complexity: O(1)
|
||||
|
||||
Args:
|
||||
label_column (str): The name of the column used as the label
|
||||
(second element of the output list).
|
||||
feature_columns (Optional[List[str]]): The names of the columns
|
||||
to use as the features. If None, then use all columns
|
||||
except the label columns as the features.
|
||||
label_column_dtype (Optional[torch.dtype]): The torch dtype to
|
||||
use for the label column. If None, then automatically infer
|
||||
the dtype.
|
||||
feature_column_dtypes (Optional[List[torch.dtype]]): The dtypes
|
||||
to use for the feature columns. The len of this list must
|
||||
be equal to the len of ``feature_columns``. If None,
|
||||
then automatically infer the dtype.
|
||||
prefetch_blocks (int): The number of blocks to prefetch ahead of
|
||||
the current block during the scan.
|
||||
|
||||
Returns:
|
||||
A torch IterableDataset.
|
||||
"""
|
||||
raise NotImplementedError # P1
|
||||
import torch
|
||||
|
||||
from ray.experimental.data.impl.torch_iterable_dataset import \
|
||||
TorchIterableDataset
|
||||
|
||||
if feature_columns and feature_column_dtypes:
|
||||
if len(feature_columns) != len(feature_column_dtypes):
|
||||
raise ValueError("The lengths of `feature_columns` "
|
||||
f"({len(feature_columns)}) and "
|
||||
f"`feature_column_dtypes` ("
|
||||
f"{len(feature_column_dtypes)}) do not "
|
||||
"match!")
|
||||
|
||||
def make_generator():
|
||||
for batch in self.iter_batches(prefetch_blocks=prefetch_blocks):
|
||||
label_vals = batch.pop(label_column).values
|
||||
label_tensor = torch.as_tensor(
|
||||
label_vals, dtype=label_column_dtype)
|
||||
label_tensor = label_tensor.view(-1, 1)
|
||||
|
||||
feature_tensor = []
|
||||
if feature_columns:
|
||||
batch = batch[feature_columns]
|
||||
|
||||
if feature_column_dtypes:
|
||||
dtypes = feature_column_dtypes
|
||||
else:
|
||||
dtypes = [None] * len(batch.columns)
|
||||
|
||||
for col, dtype in zip(batch.columns, dtypes):
|
||||
col_vals = batch[col].values
|
||||
t = torch.as_tensor(col_vals, dtype=dtype)
|
||||
t = t.view(-1, 1)
|
||||
feature_tensor.append(t)
|
||||
|
||||
num_rows = batch.shape[0]
|
||||
for i in range(num_rows):
|
||||
features = [tensor[i] for tensor in feature_tensor]
|
||||
label = label_tensor[i]
|
||||
yield (features, label)
|
||||
|
||||
return TorchIterableDataset(make_generator)
|
||||
|
||||
def to_tf(self,
|
||||
label_column: str,
|
||||
|
|
21
python/ray/experimental/data/impl/torch_iterable_dataset.py
Normal file
21
python/ray/experimental/data/impl/torch_iterable_dataset.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
import torch
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
class TorchIterableDataset(IterableDataset):
|
||||
def __init__(self, generator_func):
|
||||
self.generator_func = generator_func
|
||||
|
||||
def __iter__(self):
|
||||
it = self.generator_func()
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is None:
|
||||
yield from it
|
||||
else:
|
||||
# Multiple workers are doing dataloading.
|
||||
# Each worker has a copy of the data.
|
||||
# Avoid duplicates.
|
||||
import itertools
|
||||
it = itertools.islice(it, worker_info.id, None,
|
||||
worker_info.num_workers)
|
||||
yield from it
|
|
@ -11,6 +11,9 @@ import pyarrow as pa
|
|||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ray
|
||||
|
||||
from ray.util.dask import ray_dask_get
|
||||
|
@ -755,6 +758,88 @@ def test_to_tf_feature_columns(ray_start_regular_shared):
|
|||
assert np.array_equal(df.values, combined_iterations)
|
||||
|
||||
|
||||
def test_to_torch(ray_start_regular_shared):
|
||||
df1 = pd.DataFrame({
|
||||
"one": [1, 2, 3],
|
||||
"two": [1.0, 2.0, 3.0],
|
||||
"label": [1.0, 2.0, 3.0]
|
||||
})
|
||||
df2 = pd.DataFrame({
|
||||
"one": [4, 5, 6],
|
||||
"two": [4.0, 5.0, 6.0],
|
||||
"label": [4.0, 5.0, 6.0]
|
||||
})
|
||||
df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]})
|
||||
df = pd.concat([df1, df2, df3])
|
||||
ds = ray.experimental.data.from_pandas(
|
||||
[ray.put(df1), ray.put(df2), ray.put(df3)])
|
||||
torchd = ds.to_torch(label_column="label")
|
||||
|
||||
dataloader = DataLoader(torchd, batch_size=3)
|
||||
|
||||
num_epochs = 2
|
||||
for _ in range(num_epochs):
|
||||
iterations = []
|
||||
for batch in iter(dataloader):
|
||||
iterations.append(torch.cat((*batch[0], batch[1]), axis=1).numpy())
|
||||
combined_iterations = np.concatenate(iterations)
|
||||
assert np.array_equal(np.sort(df.values), np.sort(combined_iterations))
|
||||
|
||||
|
||||
def test_to_torch_multiple_workers(ray_start_regular_shared):
|
||||
df1 = pd.DataFrame({
|
||||
"one": [1, 2, 3],
|
||||
"two": [1.0, 2.0, 3.0],
|
||||
"label": [1.0, 2.0, 3.0]
|
||||
})
|
||||
df2 = pd.DataFrame({
|
||||
"one": [4, 5, 6],
|
||||
"two": [4.0, 5.0, 6.0],
|
||||
"label": [4.0, 5.0, 6.0]
|
||||
})
|
||||
df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]})
|
||||
df = pd.concat([df1, df2, df3])
|
||||
ds = ray.experimental.data.from_pandas(
|
||||
[ray.put(df1), ray.put(df2), ray.put(df3)])
|
||||
torchd = ds.to_torch(label_column="label")
|
||||
|
||||
dataloader = DataLoader(torchd, batch_size=1, num_workers=2)
|
||||
|
||||
iterations = []
|
||||
for batch in iter(dataloader):
|
||||
numpy_batch = torch.cat((*batch[0], batch[1]), axis=1).numpy()
|
||||
assert np.all(np.isin(numpy_batch, df.values))
|
||||
iterations.append(numpy_batch)
|
||||
|
||||
assert len(iterations) == len(df.values)
|
||||
|
||||
|
||||
def test_to_torch_feature_columns(ray_start_regular_shared):
|
||||
df1 = pd.DataFrame({
|
||||
"one": [1, 2, 3],
|
||||
"two": [1.0, 2.0, 3.0],
|
||||
"label": [1.0, 2.0, 3.0]
|
||||
})
|
||||
df2 = pd.DataFrame({
|
||||
"one": [4, 5, 6],
|
||||
"two": [4.0, 5.0, 6.0],
|
||||
"label": [4.0, 5.0, 6.0]
|
||||
})
|
||||
df3 = pd.DataFrame({"one": [7, 8], "two": [7.0, 8.0], "label": [7.0, 8.0]})
|
||||
df = pd.concat([df1, df2, df3]).drop("two", axis=1)
|
||||
ds = ray.experimental.data.from_pandas(
|
||||
[ray.put(df1), ray.put(df2), ray.put(df3)])
|
||||
torchd = ds.to_torch("label", feature_columns=["one"])
|
||||
iterations = []
|
||||
|
||||
dataloader = DataLoader(torchd, batch_size=3)
|
||||
|
||||
for batch in iter(dataloader):
|
||||
iterations.append(torch.cat((*batch[0], batch[1]), axis=1).numpy())
|
||||
combined_iterations = np.concatenate(iterations)
|
||||
assert np.array_equal(df.values, combined_iterations)
|
||||
|
||||
|
||||
def test_json_read(ray_start_regular_shared, tmp_path):
|
||||
# Single file.
|
||||
df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
|
||||
|
|
Loading…
Add table
Reference in a new issue