From 68670e375dacccdfa88950f06e40ecebee88daef Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Tue, 26 Jul 2022 11:59:41 -0700 Subject: [PATCH] Revert "Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader"" (#26491) Signed-off-by: Amog Kamsetty * Revert "Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader (#26386)" (#26483)" This reverts commit e6c04031fd0b495ce88f261495fe957a8164358e. --- python/ray/train/tests/test_gpu.py | 53 +++++++++++++++++----- python/ray/train/torch/train_loop_utils.py | 17 ++++++- 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/python/ray/train/tests/test_gpu.py b/python/ray/train/tests/test_gpu.py index 15c52917b..766d4e89a 100644 --- a/python/ray/train/tests/test_gpu.py +++ b/python/ray/train/tests/test_gpu.py @@ -6,10 +6,6 @@ from unittest.mock import patch import pytest import torch import torchvision -from test_tune import ( - torch_fashion_mnist, - tune_tensorflow_mnist, -) from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, DistributedSampler @@ -31,6 +27,10 @@ from ray.train.examples.torch_fashion_mnist_example import ( ) from ray.train.examples.torch_linear_example import LinearDataset from ray.train.horovod.horovod_trainer import HorovodTrainer +from ray.train.tests.test_tune import ( + torch_fashion_mnist, + tune_tensorflow_mnist, +) from ray.train.tensorflow.tensorflow_trainer import TensorflowTrainer from ray.train.torch import TorchConfig from ray.train.torch.torch_trainer import TorchTrainer @@ -65,6 +65,20 @@ def ray_2_node_4_gpu(): cluster.shutdown() +class LinearDatasetDict(LinearDataset): + """Modifies the LinearDataset to return a Dict instead of a Tuple.""" + + def __getitem__(self, index): + return {"x": self.x[index, None], "y": self.y[index, None]} + + +class NonTensorDataset(LinearDataset): + """Modifies the LinearDataset to also return non-tensor objects.""" + + def __getitem__(self, index): + return {"x": self.x[index, None], "y": 2} + + # TODO: Refactor as a backend test. @pytest.mark.parametrize("num_gpus_per_worker", [0.5, 1]) def test_torch_get_device(ray_start_4_cpus_2_gpus, num_gpus_per_worker): @@ -149,8 +163,11 @@ def test_torch_prepare_model(ray_start_4_cpus_2_gpus): # TODO: Refactor as a backend test. -def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus): - data_loader = DataLoader(LinearDataset(a=1, b=2, size=10)) +@pytest.mark.parametrize( + "dataset", (LinearDataset, LinearDatasetDict, NonTensorDataset) +) +def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus, dataset): + data_loader = DataLoader(dataset(a=1, b=2, size=10)) def train_fn(): wrapped_data_loader = train.torch.prepare_data_loader(data_loader) @@ -159,12 +176,26 @@ def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus): assert isinstance(wrapped_data_loader.sampler, DistributedSampler) # Make sure you can properly iterate through the DataLoader. - for batch in wrapped_data_loader: - X = batch[0] - y = batch[1] + # Case where the dataset returns a tuple or list from __getitem__. + if isinstance(dataset, LinearDataset): + for batch in wrapped_data_loader: + x = batch[0] + y = batch[1] - # Make sure the data is on the correct device. - assert X.is_cuda and y.is_cuda + # Make sure the data is on the correct device. + assert x.is_cuda and y.is_cuda + # Case where the dataset returns a dict from __getitem__. 
+ elif isinstance(dataset, LinearDatasetDict): + for batch in wrapped_data_loader: + for x, y in zip(batch["x"], batch["y"]): + # Make sure the data is on the correct device. + assert x.is_cuda and y.is_cuda + + elif isinstance(dataset, NonTensorDataset): + for batch in wrapped_data_loader: + for x, y in zip(batch["x"], batch["y"]): + # Make sure the data is on the correct device. + assert x.is_cuda and y == 2 trainer = Trainer("torch", num_workers=2, use_gpu=True) trainer.start() diff --git a/python/ray/train/torch/train_loop_utils.py b/python/ray/train/torch/train_loop_utils.py index b2b5cd611..6025fc2d7 100644 --- a/python/ray/train/torch/train_loop_utils.py +++ b/python/ray/train/torch/train_loop_utils.py @@ -4,6 +4,7 @@ import os import random import types import warnings +import collections from pathlib import Path from typing import Any, Dict, Optional @@ -585,7 +586,21 @@ class _WrappedDataLoader(DataLoader): return i with torch.cuda.stream(self._memcpy_stream): - return tuple(try_move_device(i) for i in item) + if isinstance(item, collections.abc.Mapping): + item_on_device = {k: self._move_to_device(v) for k, v in item.items()} + elif isinstance(item, tuple): + item_on_device = tuple(self._move_to_device(i) for i in item) + elif isinstance(item, list): + item_on_device = [self._move_to_device(i) for i in item] + elif isinstance(item, torch.Tensor): + item_on_device = try_move_device(item) + else: + logger.info( + f"Data type {type(item)} doesn't support being moved to device." + ) + item_on_device = item + + return item_on_device def _wait_for_batch(self, item): if self._memcpy_stream is None:
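---

Usage note (not part of the patch): the hunk in train_loop_utils.py extends _WrappedDataLoader._move_to_device so that dict and list batches, not just tuples, are moved to the worker's device, and the test_gpu.py hunk exercises this with a dict-returning dataset. Below is a minimal sketch of the behavior the new test covers, written against the legacy Trainer / prepare_data_loader API that test_gpu.py itself uses; the DictDataset class, its sizes, and the batch_size are made-up illustration, not code from this PR.

    # Sketch only: mirrors the LinearDatasetDict test case above.
    # DictDataset and its shapes are hypothetical; the Trainer and
    # prepare_data_loader calls follow the API used in test_gpu.py.
    import torch
    from torch.utils.data import DataLoader, Dataset

    from ray import train
    from ray.train import Trainer


    class DictDataset(Dataset):
        """Toy dataset whose __getitem__ returns a dict instead of a tuple."""

        def __init__(self, size=10):
            self.x = torch.arange(size, dtype=torch.float32).reshape(-1, 1)
            self.y = 2 * self.x + 1

        def __len__(self):
            return len(self.x)

        def __getitem__(self, index):
            return {"x": self.x[index], "y": self.y[index]}


    def train_func():
        # With this change, prepare_data_loader also moves dict (and list)
        # batches to the worker's device, not only tuple batches.
        data_loader = train.torch.prepare_data_loader(
            DataLoader(DictDataset(), batch_size=4)
        )
        for batch in data_loader:
            x, y = batch["x"], batch["y"]
            # Both tensors have already been moved to the worker's device
            # (CUDA here, since the trainer below requests GPUs).
            assert x.is_cuda and y.is_cuda


    trainer = Trainer("torch", num_workers=2, use_gpu=True)
    trainer.start()
    trainer.run(train_func)
    trainer.shutdown()

Batch types other than Mapping, tuple, list, or torch.Tensor are left on the host and logged, per the final else branch in the new _move_to_device logic.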