Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
Revert "Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader"" (#26491)
Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
* Revert "Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader (#26386)" (#26483)"
This reverts commit e6c04031fd
.
Parent: 5bcaf4ffcb
Commit: 68670e375d
2 changed files with 58 additions and 12 deletions
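For context, the restored change lets `ray.train.torch.prepare_data_loader` move dict and list batches to the worker's device, not just tuples of tensors. Below is a minimal usage sketch mirroring the test in this diff; the `DictDataset` and `train_fn` names are illustrative, and the legacy `Trainer` orchestration shown here matches this era of Ray Train but differs in later releases.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from ray import train
from ray.train import Trainer


class DictDataset(Dataset):
    """Toy dataset whose __getitem__ returns a dict instead of a tuple."""

    def __init__(self, size=8):
        self.x = torch.arange(size, dtype=torch.float32).reshape(-1, 1)
        self.y = 2 * self.x

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return {"x": self.x[index], "y": self.y[index]}


def train_fn():
    # With this change, dict (and list) batches are moved to the worker's
    # device as well, not just tuple batches.
    loader = train.torch.prepare_data_loader(DataLoader(DictDataset(), batch_size=4))
    for batch in loader:
        assert batch["x"].is_cuda and batch["y"].is_cuda


trainer = Trainer("torch", num_workers=2, use_gpu=True)
trainer.start()
trainer.run(train_fn)
trainer.shutdown()
```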
@@ -6,10 +6,6 @@ from unittest.mock import patch
 import pytest
 import torch
 import torchvision
-from test_tune import (
-    torch_fashion_mnist,
-    tune_tensorflow_mnist,
-)
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -31,6 +27,10 @@ from ray.train.examples.torch_fashion_mnist_example import (
 )
 from ray.train.examples.torch_linear_example import LinearDataset
 from ray.train.horovod.horovod_trainer import HorovodTrainer
+from ray.train.tests.test_tune import (
+    torch_fashion_mnist,
+    tune_tensorflow_mnist,
+)
 from ray.train.tensorflow.tensorflow_trainer import TensorflowTrainer
 from ray.train.torch import TorchConfig
 from ray.train.torch.torch_trainer import TorchTrainer
@@ -65,6 +65,20 @@ def ray_2_node_4_gpu():
     cluster.shutdown()
 
 
+class LinearDatasetDict(LinearDataset):
+    """Modifies the LinearDataset to return a Dict instead of a Tuple."""
+
+    def __getitem__(self, index):
+        return {"x": self.x[index, None], "y": self.y[index, None]}
+
+
+class NonTensorDataset(LinearDataset):
+    """Modifies the LinearDataset to also return non-tensor objects."""
+
+    def __getitem__(self, index):
+        return {"x": self.x[index, None], "y": 2}
+
+
 # TODO: Refactor as a backend test.
 @pytest.mark.parametrize("num_gpus_per_worker", [0.5, 1])
 def test_torch_get_device(ray_start_4_cpus_2_gpus, num_gpus_per_worker):
@@ -149,8 +163,11 @@ def test_torch_prepare_model(ray_start_4_cpus_2_gpus):
 
 
 # TODO: Refactor as a backend test.
-def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus):
-    data_loader = DataLoader(LinearDataset(a=1, b=2, size=10))
+@pytest.mark.parametrize(
+    "dataset", (LinearDataset, LinearDatasetDict, NonTensorDataset)
+)
+def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus, dataset):
+    data_loader = DataLoader(dataset(a=1, b=2, size=10))
 
     def train_fn():
         wrapped_data_loader = train.torch.prepare_data_loader(data_loader)
@@ -159,12 +176,26 @@ def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus):
         assert isinstance(wrapped_data_loader.sampler, DistributedSampler)
 
-        # Make sure you can properly iterate through the DataLoader.
-        for batch in wrapped_data_loader:
-            X = batch[0]
-            y = batch[1]
+        # Case where the dataset returns a tuple or list from __getitem__.
+        if isinstance(dataset, LinearDataset):
+            for batch in wrapped_data_loader:
+                x = batch[0]
+                y = batch[1]
 
-            # Make sure the data is on the correct device.
-            assert X.is_cuda and y.is_cuda
+                # Make sure the data is on the correct device.
+                assert x.is_cuda and y.is_cuda
+        # Case where the dataset returns a dict from __getitem__.
+        elif isinstance(dataset, LinearDatasetDict):
+            for batch in wrapped_data_loader:
+                for x, y in zip(batch["x"], batch["y"]):
+                    # Make sure the data is on the correct device.
+                    assert x.is_cuda and y.is_cuda
+
+        elif isinstance(dataset, NonTensorDataset):
+            for batch in wrapped_data_loader:
+                for x, y in zip(batch["x"], batch["y"]):
+                    # Make sure the data is on the correct device.
+                    assert x.is_cuda and y == 2
 
     trainer = Trainer("torch", num_workers=2, use_gpu=True)
     trainer.start()
@@ -4,6 +4,7 @@ import os
 import random
 import types
 import warnings
+import collections
 
 from pathlib import Path
 from typing import Any, Dict, Optional
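As an aside (my own illustration, not part of the diff): the new dispatch in the hunk below tests against `collections.abc.Mapping` rather than `dict`, so any mapping type a collate function might produce is handled, not only plain dicts.

```python
import collections.abc
from collections import OrderedDict

# Both plain dicts and other mapping types satisfy the Mapping ABC,
# so the Mapping branch in _move_to_device covers them uniformly.
for sample in ({"x": 1}, OrderedDict(x=1)):
    assert isinstance(sample, collections.abc.Mapping)
```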
@@ -585,7 +586,21 @@ class _WrappedDataLoader(DataLoader):
             return i
 
         with torch.cuda.stream(self._memcpy_stream):
-            return tuple(try_move_device(i) for i in item)
+            if isinstance(item, collections.abc.Mapping):
+                item_on_device = {k: self._move_to_device(v) for k, v in item.items()}
+            elif isinstance(item, tuple):
+                item_on_device = tuple(self._move_to_device(i) for i in item)
+            elif isinstance(item, list):
+                item_on_device = [self._move_to_device(i) for i in item]
+            elif isinstance(item, torch.Tensor):
+                item_on_device = try_move_device(item)
+            else:
+                logger.info(
+                    f"Data type {type(item)} doesn't support being moved to device."
+                )
+                item_on_device = item
+
+            return item_on_device
 
     def _wait_for_batch(self, item):
         if self._memcpy_stream is None:
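For readers who want to see the new behavior in isolation, here is a minimal standalone sketch (my own code, not part of the diff) of the same dispatch `_move_to_device` now performs: recurse into mappings, tuples, and lists, move tensors with `Tensor.to`, and pass other values through unchanged. It omits the CUDA memcpy stream and logging used in the real implementation.

```python
import collections.abc

import torch


def move_to_device(item, device):
    """Recursively move tensors nested inside dicts, tuples, and lists."""
    if isinstance(item, collections.abc.Mapping):
        return {k: move_to_device(v, device) for k, v in item.items()}
    elif isinstance(item, tuple):
        return tuple(move_to_device(i, device) for i in item)
    elif isinstance(item, list):
        return [move_to_device(i, device) for i in item]
    elif isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    else:
        # Non-tensor leaves (ints, strings, ...) are returned unchanged.
        return item


batch = {"x": torch.ones(4, 1), "y": 2}
device = "cuda" if torch.cuda.is_available() else "cpu"
moved = move_to_device(batch, device)
assert moved["x"].device.type == device and moved["y"] == 2
```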