Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader (#26386)" (#26483)

This reverts commit 36229d1234.
This commit is contained in:
Amog Kamsetty 2022-07-12 17:18:46 -07:00 committed by GitHub
parent 980a59477d
commit e6c04031fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 36 deletions

View file

@ -42,13 +42,6 @@ def ray_start_1_cpu_1_gpu():
ray.shutdown()
class LinearDatasetDict(LinearDataset):
"""Modifies the LinearDataset to return a Dict instead of a Tuple."""
def __getitem__(self, index):
return {"x": self.x[index, None], "y": self.y[index, None]}
# TODO: Refactor as a backend test.
@pytest.mark.parametrize("num_gpus_per_worker", [0.5, 1])
def test_torch_get_device(ray_start_4_cpus_2_gpus, num_gpus_per_worker):
@ -99,9 +92,8 @@ def test_torch_prepare_model(ray_start_4_cpus_2_gpus):
# TODO: Refactor as a backend test.
@pytest.mark.parametrize("dataset", (LinearDataset, LinearDatasetDict))
def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus, dataset):
data_loader = DataLoader(dataset(a=1, b=2, size=10))
def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus):
data_loader = DataLoader(LinearDataset(a=1, b=2, size=10))
def train_fn():
wrapped_data_loader = train.torch.prepare_data_loader(data_loader)
@ -110,20 +102,12 @@ def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus, dataset):
assert isinstance(wrapped_data_loader.sampler, DistributedSampler)
# Make sure you can properly iterate through the DataLoader.
# Case where the dataset returns a tuple or list from __getitem__.
if isinstance(wrapped_data_loader.dataset[0], (tuple, list)):
for batch in wrapped_data_loader:
x = batch[0]
y = batch[1]
for batch in wrapped_data_loader:
X = batch[0]
y = batch[1]
# Make sure the data is on the correct device.
assert x.is_cuda and y.is_cuda
# Case where the dataset returns a dict from __getitem__.
elif isinstance(wrapped_data_loader.dataset[0], dict):
for batch in wrapped_data_loader:
for x, y in zip(batch["x"], batch["y"]):
# Make sure the data is on the correct device.
assert x.is_cuda and y.is_cuda
# Make sure the data is on the correct device.
assert X.is_cuda and y.is_cuda
trainer = Trainer("torch", num_workers=2, use_gpu=True)
trainer.start()

View file

@ -4,7 +4,6 @@ import os
import random
import types
import warnings
import collections
from pathlib import Path
from typing import Any, Dict, Optional
@ -549,18 +548,7 @@ class _WrappedDataLoader(DataLoader):
return i
with torch.cuda.stream(self._memcpy_stream):
if isinstance(item, collections.abc.Mapping):
item_on_device = {k: self._move_to_device(v) for k, v in item.items()}
elif isinstance(item, (tuple, list)):
item_on_device = type(item)(self._move_to_device(i) for i in item)
elif isinstance(item, torch.Tensor):
item_on_device = try_move_device(item)
else:
logger.info(
f"Data type {type(item)} doesn't support being moved to device."
)
return item_on_device
return tuple(try_move_device(i) for i in item)
def _wait_for_batch(self, item):
if self._memcpy_stream is None: