Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader (#26386)" (#26483)

This reverts commit 36229d1234.
Author: Amog Kamsetty
Date:   2022-07-12 17:18:46 -07:00 (committed by GitHub)
Parent: 980a59477d
Commit: e6c04031fd
2 changed files with 8 additions and 36 deletions


@@ -42,13 +42,6 @@ def ray_start_1_cpu_1_gpu():
     ray.shutdown()
 
 
-class LinearDatasetDict(LinearDataset):
-    """Modifies the LinearDataset to return a Dict instead of a Tuple."""
-
-    def __getitem__(self, index):
-        return {"x": self.x[index, None], "y": self.y[index, None]}
-
-
 # TODO: Refactor as a backend test.
 @pytest.mark.parametrize("num_gpus_per_worker", [0.5, 1])
 def test_torch_get_device(ray_start_4_cpus_2_gpus, num_gpus_per_worker):
@@ -99,9 +92,8 @@ def test_torch_prepare_model(ray_start_4_cpus_2_gpus):
 
 
 # TODO: Refactor as a backend test.
-@pytest.mark.parametrize("dataset", (LinearDataset, LinearDatasetDict))
-def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus, dataset):
-    data_loader = DataLoader(dataset(a=1, b=2, size=10))
+def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus):
+    data_loader = DataLoader(LinearDataset(a=1, b=2, size=10))
 
     def train_fn():
         wrapped_data_loader = train.torch.prepare_data_loader(data_loader)
@@ -110,20 +102,12 @@ def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus, dataset):
         assert isinstance(wrapped_data_loader.sampler, DistributedSampler)
 
         # Make sure you can properly iterate through the DataLoader.
-        # Case where the dataset returns a tuple or list from __getitem__.
-        if isinstance(wrapped_data_loader.dataset[0], (tuple, list)):
-            for batch in wrapped_data_loader:
-                x = batch[0]
-                y = batch[1]
-
-                # Make sure the data is on the correct device.
-                assert x.is_cuda and y.is_cuda
-        # Case where the dataset returns a dict from __getitem__.
-        elif isinstance(wrapped_data_loader.dataset[0], dict):
-            for batch in wrapped_data_loader:
-                for x, y in zip(batch["x"], batch["y"]):
-                    # Make sure the data is on the correct device.
-                    assert x.is_cuda and y.is_cuda
+        for batch in wrapped_data_loader:
+            X = batch[0]
+            y = batch[1]
+
+            # Make sure the data is on the correct device.
+            assert X.is_cuda and y.is_cuda
 
     trainer = Trainer("torch", num_workers=2, use_gpu=True)
     trainer.start()
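For context, the sketch below (plain PyTorch, not Ray code; ToyTupleDataset and ToyDictDataset are illustrative stand-ins for the test fixtures) contrasts the tuple-style __getitem__ that the restored test keeps with the dict style that the removed LinearDatasetDict fixture exercised:

import torch
from torch.utils.data import DataLoader, Dataset


class ToyTupleDataset(Dataset):
    """Returns (x, y) tuples, the only case the restored test covers."""

    def __init__(self, size=10):
        self.x = torch.arange(size, dtype=torch.float32)
        self.y = 2 * self.x + 1

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index, None], self.y[index, None]


class ToyDictDataset(ToyTupleDataset):
    """Returns dicts, mirroring the removed LinearDatasetDict fixture."""

    def __getitem__(self, index):
        return {"x": self.x[index, None], "y": self.y[index, None]}


# The default collate_fn turns tuple samples into a list of batched tensors
# and dict samples into a dict of batched tensors, which is why the wrapped
# loader needed per-type handling to move dict batches to the device.
tuple_batch = next(iter(DataLoader(ToyTupleDataset(), batch_size=4)))
dict_batch = next(iter(DataLoader(ToyDictDataset(), batch_size=4)))
print(type(tuple_batch), type(dict_batch))  # <class 'list'> <class 'dict'>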


@@ -4,7 +4,6 @@ import os
 import random
 import types
 import warnings
-import collections
 
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -549,18 +548,7 @@ class _WrappedDataLoader(DataLoader):
            return i
 
        with torch.cuda.stream(self._memcpy_stream):
-            if isinstance(item, collections.abc.Mapping):
-                item_on_device = {k: self._move_to_device(v) for k, v in item.items()}
-            elif isinstance(item, (tuple, list)):
-                item_on_device = type(item)(self._move_to_device(i) for i in item)
-            elif isinstance(item, torch.Tensor):
-                item_on_device = try_move_device(item)
-            else:
-                logger.info(
-                    f"Data type {type(item)} doesn't support being moved to device."
-                )
-
-            return item_on_device
+            return tuple(try_move_device(i) for i in item)
 
    def _wait_for_batch(self, item):
        if self._memcpy_stream is None:
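For reference, a standalone sketch of the device-move behavior this revert changes (simplified: no CUDA memcpy stream, and move_after_revert / move_before_revert are illustrative names, not Ray APIs):

import collections.abc

import torch


def _try_move(i, device):
    # Mirrors the try_move_device helper: move objects that have .to(),
    # pass everything else through unchanged.
    try:
        return i.to(device)
    except AttributeError:
        return i


def move_after_revert(item, device):
    # Restored behavior: the batch is assumed to be an iterable of tensors,
    # e.g. the list that the default collate_fn builds for (x, y) samples.
    return tuple(_try_move(i, device) for i in item)


def move_before_revert(item, device):
    # Removed behavior: dispatch on the container type so dict, tuple/list,
    # and bare-tensor batches are all moved to the target device.
    if isinstance(item, collections.abc.Mapping):
        return {k: move_before_revert(v, device) for k, v in item.items()}
    elif isinstance(item, (tuple, list)):
        return type(item)(move_before_revert(i, device) for i in item)
    elif isinstance(item, torch.Tensor):
        return _try_move(item, device)
    return item


batch = {"x": torch.zeros(4, 1), "y": torch.ones(4, 1)}
print(move_before_revert(batch, "cpu")["x"].shape)  # dict batches are moved
print(move_after_revert(batch, "cpu"))              # iterating a dict yields only its keys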