[AIR/Train] Make Dataset ingest configurable (#24066)

Refactors Dataset splitting to make it less hacky and to address the existing TODO. Also makes Dataset ingest in general configurable for Ray Train. This is an internal-only change for now, but it will set the stage for the proposed ingest API.

Customizable ingest for GBDT Trainers is out of scope for this PR.
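
For context, a minimal sketch of the ingest flow this change enables (hedged: `_RayDatasetSpec` lives in the internal `ray.train.impl.dataset_spec` module added by this PR and may change; `split_all_datasets` is a hypothetical example, not something this PR ships):

```python
from typing import Dict, List

import ray
import ray.data
from ray.actor import ActorHandle
from ray.data import Dataset
from ray.train.impl.dataset_spec import _RayDatasetSpec  # internal, added by this PR


def split_all_datasets(
    dataset_dict: Dict[str, Dataset], worker_handles: List[ActorHandle]
) -> List[Dict[str, Dataset]]:
    """Hypothetical split function: shard every dataset, not just "train"."""
    shards: List[Dict[str, Dataset]] = [{} for _ in worker_handles]
    for name, ds in dataset_dict.items():
        pieces = ds.split(
            len(worker_handles), equal=True, locality_hints=worker_handles
        )
        for i, piece in enumerate(pieces):
            shards[i][name] = piece
    return shards


datasets = {"train": ray.data.range(8), "valid": ray.data.range(4)}

# The trainer wraps its datasets in a spec (instead of mangling dataset keys with a
# "_NO-SHARD" suffix), and the backend executor later resolves the spec into one
# shard dict per training worker:
spec = _RayDatasetSpec(dataset_or_dict=datasets, dataset_split_fn=split_all_datasets)
# dataset_shards = spec.get_dataset_shards(worker_actor_handles)
```

The diffs below show both sides of this handshake: the trainer building the spec, and the backend executor resolving it.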
This commit is contained in:
Amog Kamsetty 2022-04-27 21:41:44 -07:00 committed by GitHub
parent abba263f4e
commit 629424f489
8 changed files with 194 additions and 97 deletions

View file

@@ -1,10 +1,11 @@
import inspect
import logging
from pathlib import Path
from typing import Dict, Callable, Optional, Union
from typing import Dict, Callable, List, Optional, Union, TYPE_CHECKING
import ray
from ray import tune
from ray.actor import ActorHandle
from ray.ml.constants import TRAIN_DATASET_KEY, PREPROCESSOR_KEY
from ray.ml.trainer import Trainer
from ray.ml.config import ScalingConfig, RunConfig
@@ -14,9 +15,13 @@ from ray.ml.checkpoint import Checkpoint
from ray.train import BackendConfig, TrainingIterator
from ray.train.backend import BackendExecutor
from ray.train.checkpoint import TuneCheckpointManager
from ray.train.impl.dataset_spec import _RayDatasetSpec
from ray.train.utils import construct_train_func
from ray.util.annotations import DeveloperAPI
if TYPE_CHECKING:
from ray.data import Dataset
logger = logging.getLogger(__name__)
@@ -292,19 +297,9 @@ class DataParallelTrainer(Trainer):
else:
resume_checkpoint_dict = None
# Tell Ray Train to only shard the train dataset and not the other datasets.
# This is purely an implementation detail and users do not need to know about
# this.
# TODO(amog): Refactor this to remove hack and make this more modular.
# TrainingIterator should accept a generic custom_ingest_func that contains
# the logic for how to split the Datasets.
updated_dataset_dict = {}
for key, value in self.datasets.items():
if key == TRAIN_DATASET_KEY:
updated_dataset_dict[key] = value
else:
# Ray Train will strip out the added string before exposing to users.
updated_dataset_dict[key + "_NO-SHARD"] = value
dataset_spec = _RayDatasetSpec(
dataset_or_dict=self.datasets, dataset_split_fn=_default_dataset_split_fn
)
# TODO(amog): Have TrainingIterator also accept a checkpoint ObjectRef instead
# of just a Dict.
@@ -312,7 +307,7 @@ class DataParallelTrainer(Trainer):
backend_executor=backend_executor,
backend_config=self.backend_config,
train_func=train_loop_per_worker,
dataset=updated_dataset_dict if len(updated_dataset_dict) > 0 else None,
dataset_spec=dataset_spec,
checkpoint_manager=checkpoint_manager,
checkpoint=resume_checkpoint_dict,
checkpoint_strategy=None,
@@ -348,3 +343,39 @@ class _DataParallelCheckpointManager(TuneCheckpointManager):
@property
def latest_checkpoint_dir(self) -> Optional[Path]:
raise NotImplementedError
def _default_dataset_split_fn(
dataset_dict: Dict[str, "Dataset"], training_worker_handles: List[ActorHandle]
) -> List[Dict[str, "Dataset"]]:
"""Defines splitting logic of Datasets passed into ``DataParallelTrainer``.
By default, only the training dataset will be split. All other datasets will not be
split and will instead be passed through directly to the training workers. This is
because validation is often done on just the rank 0 worker.
Args:
dataset_dict: A dictionary of Datasets.
training_worker_handles: The actor handles of the training workers to use for
locality hints.
Returns:
A list of dataset dictionaries for each training worker.
"""
dataset_dict_splits = [{} for _ in range(len(training_worker_handles))]
for key, dataset in dataset_dict.items():
if key == TRAIN_DATASET_KEY:
dataset_splits = dataset.split(
len(training_worker_handles),
equal=True,
locality_hints=training_worker_handles,
)
else:
# Only shard the training dataset.
dataset_splits = [dataset] * len(training_worker_handles)
for i in range(len(dataset_splits)):
dataset_dict_splits[i][key] = dataset_splits[i]
return dataset_dict_splits
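
For illustration, a standalone sketch of the default behavior defined above (hedged: `_Stub` is a hypothetical stand-in for the training worker actors and exists only to supply locality hints):

```python
import ray
import ray.data

ray.init(ignore_reinit_error=True)


@ray.remote
class _Stub:
    """Hypothetical stand-in for a training worker actor (locality hints only)."""


workers = [_Stub.remote() for _ in range(2)]
datasets = {"train": ray.data.range(8), "valid": ray.data.range(4)}

# Mirror the logic above: only the "train" dataset is sharded.
train_shards = datasets["train"].split(
    len(workers), equal=True, locality_hints=workers
)
per_worker = [{"train": shard, "valid": datasets["valid"]} for shard in train_shards]

assert per_worker[0]["train"].count() == 4  # "train" is split equally across workers
assert per_worker[1]["valid"].count() == 4  # "valid" is passed through whole
```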

View file

@@ -257,9 +257,9 @@ class Trainer(abc.ABC):
If the ``Trainer`` has both a datasets dict and
a preprocessor, the datasets dict contains a training dataset (denoted by
the "train" key), and the preprocessor has not yet
been fit, then it will be fit on the train.
been fit, then it will be fit on the train dataset.
Then, the Trainer's datasets will be transformed by the preprocessor.
Then, all of the Trainer's datasets will be transformed by the preprocessor.
The transformed datasets will be set back in the ``self.datasets`` attribute
of the Trainer to be used when overriding ``training_loop``.
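
As a clarifying aside, a hedged sketch of the flow that docstring describes (`_ToyPreprocessor` is purely illustrative and is not Ray's preprocessor API; rows from `ray.data.range()` are plain integers in this era of Ray):

```python
from typing import Dict

import ray
import ray.data
from ray.data import Dataset


class _ToyPreprocessor:
    """Illustrative preprocessor with the fit/transform shape described above."""

    def __init__(self) -> None:
        self.fitted = False
        self.offset = 0

    def fit(self, ds: Dataset) -> "_ToyPreprocessor":
        # Pretend to learn a statistic from the train dataset.
        self.offset = ds.count()
        self.fitted = True
        return self

    def transform(self, ds: Dataset) -> Dataset:
        offset = self.offset
        return ds.map(lambda row: row + offset)


datasets: Dict[str, Dataset] = {"train": ray.data.range(4), "valid": ray.data.range(2)}
preprocessor = _ToyPreprocessor()

# If the preprocessor has not yet been fit, it is fit on the train dataset only...
if not preprocessor.fitted:
    preprocessor.fit(datasets["train"])

# ...and then all of the Trainer's datasets are transformed by the fitted preprocessor.
datasets = {name: preprocessor.transform(ds) for name, ds in datasets.items()}
```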

View file

@@ -1,7 +1,7 @@
import logging
import os
from collections import defaultdict
from typing import Callable, TypeVar, List, Optional, Dict, Union, Type, Tuple
from typing import Callable, TypeVar, List, Optional, Dict, Type, Tuple
import ray
from ray.exceptions import RayActorError
@@ -12,9 +12,10 @@ from ray.train.constants import (
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
TRAIN_ENABLE_WORKER_SPREAD_ENV,
)
from ray.train.impl.dataset_spec import _RayDatasetSpec
from ray.train.session import TrainingResult
from ray.train.session import init_session, get_session, shutdown_session
from ray.train.utils import RayDataset, check_for_failure, Singleton
from ray.train.utils import check_for_failure, Singleton
from ray.train.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI
from ray.util.placement_group import get_current_placement_group, remove_placement_group
@@ -314,42 +315,10 @@ class BackendExecutor:
ip_dict[node_ip] += 1
return rank_mapping
def _get_dataset_shards(self, dataset_or_dict):
if dataset_or_dict is None:
# Return None for each shard.
return [None] * len(self.worker_group)
def split_dataset(dataset_or_pipeline):
actors = [worker.actor for worker in self.worker_group.workers]
return dataset_or_pipeline.split(
len(self.worker_group), equal=True, locality_hints=actors
)
if isinstance(dataset_or_dict, dict):
# Return a smaller dict for each shard.
dataset_shards = [{} for _ in range(len(self.worker_group))]
# TODO(amog): Update Backend to accept a generic function with logic on
# how to split dataset, instead of having to support _NO-SHARD in key.
for key, dataset in dataset_or_dict.items():
if "_NO-SHARD" in key:
# Do not shard this dataset.
split_datasets = [dataset] * len(self.worker_group)
key = key.replace("_NO-SHARD", "")
else:
split_datasets = split_dataset(dataset)
assert len(split_datasets) == len(self.worker_group)
for i in range(len(split_datasets)):
dataset_shards[i][key] = split_datasets[i]
return dataset_shards
else:
# return a smaller RayDataset for each shard.
return split_dataset(dataset_or_dict)
def start_training(
self,
train_func: Callable[[], T],
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
dataset_spec: _RayDatasetSpec,
checkpoint: Optional[Dict] = None,
) -> None:
"""Executes a training function on all workers in a separate thread.
@@ -357,17 +326,11 @@ class BackendExecutor:
``finish_training`` should be called after this.
Args:
train_func (Callable): The training function to run on each worker.
dataset (Optional[Union[Dataset, DatasetPipeline]])
Distributed Ray Dataset or DatasetPipeline to pass into
worker, which can be accessed from the training function via
``train.get_dataset_shard()``. Sharding will automatically be
handled by the Trainer. Multiple Datasets can be passed in as
a ``Dict`` that maps each name key to a Dataset value,
and each Dataset can be accessed from the training function
by passing in a `dataset_name` argument to
``train.get_dataset_shard()``.
checkpoint (Optional[Dict]): The checkpoint data that
train_func: The training function to run on each worker.
dataset_spec: A specification for the Ray Dataset to be
passed to the training workers, along with the logic for how to shard the
Ray Dataset.
checkpoint: The checkpoint data that
should be loaded onto each worker and accessed by the
training function via ``train.load_checkpoint()``. If this
is ``None`` then no checkpoint will be loaded.
@@ -406,7 +369,8 @@
)
if self.dataset_shards is None:
self.dataset_shards = self._get_dataset_shards(dataset)
actors = [worker.actor for worker in self.worker_group.workers]
self.dataset_shards = dataset_spec.get_dataset_shards(actors)
local_rank_map = self._create_local_rank_map()

View file

@@ -0,0 +1,93 @@
from dataclasses import dataclass
from typing import Optional, Union, Dict, Callable, List, TYPE_CHECKING
from ray.actor import ActorHandle
if TYPE_CHECKING:
from ray.data import Dataset, DatasetPipeline
RayDataset = Union["Dataset", "DatasetPipeline"]
@dataclass
class _RayDatasetSpec:
"""Configuration for Ray Datasets to pass to the training workers.
dataset_or_dict: An optional Ray Dataset (or DatasetPipeline) or a dictionary of
datasets to be sharded across all the training workers, which can be accessed
from the training function via ``train.get_dataset_shard()``. Multiple Datasets
can be passed in as a dictionary that maps each name key to a Dataset value,
and each Dataset can be accessed from the training function by passing in a
`dataset_name` argument to ``train.get_dataset_shard()``.
dataset_split_fn: An optional callable to specify how the provided ``dataset``
should be split across the training workers. It is expected to take in two
arguments. The first one is the ``dataset``, just as is passed in to the
``_RayDatasetSpec``. The second argument is a list of the ActorHandles of the
training workers (to use as locality hints). The Callable is expected to
return a list of RayDatasets or a list of dictionaries of RayDatasets,
with the length of the list equal to the length of the list of actor handles.
If None is provided, the provided Ray Dataset(s) will simply be split using
the actor handles as locality hints.
"""
dataset_or_dict: Optional[Union[RayDataset, Dict[str, RayDataset]]]
dataset_split_fn: Optional[
Callable[
[Union[RayDataset, Dict[str, RayDataset]], List[ActorHandle]],
List[Union[RayDataset, Dict[str, RayDataset]]],
]
] = None
def _default_split_fn(
self, training_worker_handles: List[ActorHandle]
) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
def split_dataset(dataset_or_pipeline):
return dataset_or_pipeline.split(
len(training_worker_handles),
equal=True,
locality_hints=training_worker_handles,
)
if isinstance(self.dataset_or_dict, dict):
# Return a smaller dict for each shard.
dataset_shards = [{} for _ in range(len(training_worker_handles))]
for key, dataset in self.dataset_or_dict.items():
split_datasets = split_dataset(dataset)
assert len(split_datasets) == len(training_worker_handles)
for i in range(len(split_datasets)):
dataset_shards[i][key] = split_datasets[i]
return dataset_shards
else:
# return a smaller RayDataset for each shard.
return split_dataset(self.dataset_or_dict)
def get_dataset_shards(
self, training_worker_handles: List[ActorHandle]
) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
"""Returns Dataset splits based off the spec and the given training workers
Args:
training_worker_handles: A list of the training worker actor handles.
Returns:
A list of RayDataset shards or list of dictionaries of RayDataset shards,
one for each training worker.
"""
if not self.dataset_or_dict:
return [None] * len(training_worker_handles)
if self.dataset_split_fn is None:
return self._default_split_fn(training_worker_handles)
else:
splits = self.dataset_split_fn(
self.dataset_or_dict, training_worker_handles
)
if not len(splits) == len(training_worker_handles):
raise RuntimeError(
"The list of Datasets returned by the "
f"`dataset_split_fn`: {len(splits)} does not match "
f"the number of training workers: {len(training_worker_handles)}"
)
return splits
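
To make the contract above concrete, a hedged usage sketch of the spec on its own (internal API, subject to change; `_Stub` is again a hypothetical stand-in actor used only for locality hints):

```python
import ray
import ray.data
from ray.train.impl.dataset_spec import _RayDatasetSpec  # internal, subject to change

ray.init(ignore_reinit_error=True)


@ray.remote
class _Stub:
    """Hypothetical stand-in for a training worker actor (locality hints only)."""


workers = [_Stub.remote() for _ in range(2)]

# No datasets at all: the spec hands back one None per worker.
empty_spec = _RayDatasetSpec(dataset_or_dict=None)
assert empty_spec.get_dataset_shards(workers) == [None, None]

# A dict of datasets with no dataset_split_fn: _default_split_fn splits every
# dataset equally across the workers, using the actors as locality hints.
spec = _RayDatasetSpec(
    dataset_or_dict={"train": ray.data.range(8), "valid": ray.data.range(4)}
)
shards = spec.get_dataset_shards(workers)
assert len(shards) == len(workers)
assert shards[0]["train"].count() == 4 and shards[0]["valid"].count() == 2
```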

View file

@@ -25,7 +25,8 @@ from ray.train.constants import (
RESULT_FETCH_TIMEOUT,
SESSION_MISUSE_LOG_ONCE_KEY,
)
from ray.train.utils import PropagatingThread, RayDataset
from ray.train.utils import PropagatingThread
from ray.train.impl.dataset_spec import RayDataset
from ray.util import PublicAPI, log_once

View file

@@ -14,6 +14,7 @@ from ray.train.backend import (
TrainingWorkerError,
)
from ray.train.backend import BackendConfig, BackendExecutor
from ray.train.impl.dataset_spec import _RayDatasetSpec
from ray.train.tensorflow import TensorflowConfig
from ray.train.torch import TorchConfig
from ray.train.constants import (
@@ -102,11 +103,14 @@ class TestBackend(Backend):
pass
EMPTY_RAY_DATASET_SPEC = _RayDatasetSpec(dataset_or_dict=None)
def test_start(ray_start_2_cpus):
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
with pytest.raises(InactiveWorkerGroupError):
e.start_training(lambda: 1)
e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)
e.start()
assert len(e.worker_group) == 2
@@ -127,7 +131,7 @@ def test_initialization_hook(ray_start_2_cpus):
return os.getenv("TEST", "0")
e.start_training(check)
e.start_training(check, dataset_spec=EMPTY_RAY_DATASET_SPEC)
assert e.finish_training() == ["1", "1"]
@@ -138,7 +142,7 @@ def test_shutdown(ray_start_2_cpus):
assert len(e.worker_group) == 2
e.shutdown()
with pytest.raises(InactiveWorkerGroupError):
e.start_training(lambda: 1)
e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)
def test_train(ray_start_2_cpus):
@@ -146,7 +150,7 @@ def test_train(ray_start_2_cpus):
e = BackendExecutor(config, num_workers=2)
e.start()
e.start_training(lambda: 1)
e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)
assert e.finish_training() == [1, 1]
@@ -158,7 +162,7 @@ def test_local_ranks(ray_start_2_cpus):
def train_func():
return train.local_rank()
e.start_training(train_func)
e.start_training(train_func, dataset_spec=EMPTY_RAY_DATASET_SPEC)
assert set(e.finish_training()) == {0, 1}
@@ -176,10 +180,10 @@ def test_train_failure(ray_start_2_cpus):
with pytest.raises(TrainBackendError):
e.finish_training()
e.start_training(lambda: 1)
e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)
with pytest.raises(TrainBackendError):
e.start_training(lambda: 2)
e.start_training(lambda: 2, dataset_spec=EMPTY_RAY_DATASET_SPEC)
assert e.finish_training() == [1, 1]
@@ -195,7 +199,7 @@ def test_worker_failure(ray_start_2_cpus):
new_execute_func = gen_execute_special(train_fail)
with patch.object(WorkerGroup, "execute_async", new_execute_func):
with pytest.raises(TrainingWorkerError):
e.start_training(lambda: 1)
e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)
e.finish_training()
@@ -209,7 +213,7 @@ def test_mismatch_checkpoint_report(ray_start_2_cpus):
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()
e.start_training(train_func)
e.start_training(train_func, dataset_spec=EMPTY_RAY_DATASET_SPEC)
with pytest.raises(RuntimeError):
e.get_next_results()
@@ -226,7 +230,7 @@ def test_tensorflow_start(ray_start_2_cpus):
return json.loads(os.environ["TF_CONFIG"])
e.start_training(get_tf_config)
e.start_training(get_tf_config, dataset_spec=EMPTY_RAY_DATASET_SPEC)
results = e.finish_training()
assert len(results) == num_workers
@@ -251,12 +255,12 @@ def test_torch_start_shutdown(ray_start_2_cpus, init_method):
and torch.distributed.get_world_size() == 2
)
e.start_training(check_process_group)
e.start_training(check_process_group, dataset_spec=EMPTY_RAY_DATASET_SPEC)
assert all(e.finish_training())
e._backend.on_shutdown(e.worker_group, e._backend_config)
e.start_training(check_process_group)
e.start_training(check_process_group, dataset_spec=EMPTY_RAY_DATASET_SPEC)
assert not any(e.finish_training())
@@ -282,7 +286,7 @@ def test_cuda_visible_devices(ray_2_node_2_gpu, worker_results):
config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=1
)
e.start()
e.start_training(get_resources)
e.start_training(get_resources, dataset_spec=EMPTY_RAY_DATASET_SPEC)
results = e.finish_training()
results.sort()
assert results == expected_results
@@ -314,7 +318,7 @@ def test_cuda_visible_devices_fractional(ray_2_node_2_gpu, worker_results):
config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=0.5
)
e.start()
e.start_training(get_resources)
e.start_training(get_resources, dataset_spec=EMPTY_RAY_DATASET_SPEC)
results = e.finish_training()
results.sort()
assert results == expected_results
@@ -342,7 +346,7 @@ def test_cuda_visible_devices_multiple(ray_2_node_4_gpu, worker_results):
config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=2
)
e.start()
e.start_training(get_resources)
e.start_training(get_resources, dataset_spec=EMPTY_RAY_DATASET_SPEC)
results = e.finish_training()
results.sort()
assert results == expected_results
@@ -393,7 +397,7 @@ def test_placement_group_parent(ray_4_node_4_cpu, placement_group_capture_child_
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()
e.start_training(train_func)
e.start_training(train_func, dataset_spec=EMPTY_RAY_DATASET_SPEC)
return e.finish_training()
results_future = test.options(

View file

@@ -16,8 +16,12 @@ from ray.train.backend import (
TrainingWorkerError,
)
from ray.train.callbacks.callback import TrainingCallback
from ray.train.impl.dataset_spec import RayDataset, _RayDatasetSpec
from ray.train.session import TrainingResultType
from ray.train.utils import RayDataset, construct_train_func, ActorWrapper
from ray.train.utils import (
construct_train_func,
ActorWrapper,
)
from ray.train.checkpoint import (
CheckpointStrategy,
TuneCheckpointManager,
@@ -336,12 +340,14 @@ class Trainer:
train_func = construct_train_func(train_func, config)
dataset_spec = _RayDatasetSpec(dataset_or_dict=dataset)
try:
iterator = TrainingIterator(
backend_executor=self._backend_executor,
backend_config=self._backend_config,
train_func=train_func,
dataset=dataset,
dataset_spec=dataset_spec,
checkpoint_manager=self.checkpoint_manager,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy,
@@ -413,12 +419,14 @@ class Trainer:
train_func = construct_train_func(train_func, config)
dataset_spec = _RayDatasetSpec(dataset_or_dict=dataset)
return TrainingIterator(
backend_executor=self._backend_executor,
backend_config=self._backend_config,
train_func=train_func,
run_dir=self.latest_run_dir,
dataset=dataset,
dataset_spec=dataset_spec,
checkpoint_manager=self.checkpoint_manager,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy,
@@ -513,7 +521,7 @@ class Trainer:
"""Creates a Tune ``Trainable`` from the input training function.
Args:
func (Callable): The function that should be executed on each
train_func (Callable): The function that should be executed on each
training worker.
dataset (Optional[Union[RayDataset, Dict[str, RayDataset]]]):
Distributed Ray :ref:`Dataset <dataset-api>` or
@@ -650,7 +658,7 @@ class TrainingIterator:
backend_executor: Union[BackendExecutor, ActorWrapper],
backend_config: BackendConfig,
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]],
dataset_spec: _RayDatasetSpec,
checkpoint_manager: CheckpointManager,
checkpoint: Optional[Union[Dict, str, Path]],
checkpoint_strategy: Optional[CheckpointStrategy],
@@ -659,14 +667,14 @@ class TrainingIterator:
self._backend_executor = backend_executor
self._backend = backend_config.backend_cls()
self._train_func = train_func
self._dataset = dataset
self._dataset_spec = dataset_spec
self._run_dir = run_dir
self._checkpoint_manager = checkpoint_manager
self._checkpoint_strategy = checkpoint_strategy
self._start_training(
train_func=train_func,
run_dir=run_dir,
dataset=dataset,
dataset_spec=self._dataset_spec,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy,
)
@@ -681,7 +689,7 @@ class TrainingIterator:
self,
train_func,
run_dir,
dataset,
dataset_spec,
checkpoint,
checkpoint_strategy,
latest_checkpoint_id=None,
@@ -694,7 +702,7 @@ class TrainingIterator:
checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint)
self._run_with_error_handling(
lambda: self._backend_executor.start_training(
train_func=train_func, dataset=dataset, checkpoint=checkpoint_dict
train_func=train_func,
dataset_spec=dataset_spec,
checkpoint=checkpoint_dict,
)
)
@@ -713,7 +723,7 @@ class TrainingIterator:
self._start_training(
self._train_func,
self._run_dir,
self._dataset,
self._dataset_spec,
self._checkpoint_manager.latest_checkpoint,
self._checkpoint_strategy,
latest_checkpoint_id=self._checkpoint_manager.latest_checkpoint_id,

View file

@@ -10,7 +10,6 @@ from typing import (
Dict,
List,
Any,
TYPE_CHECKING,
Union,
Callable,
TypeVar,
@@ -23,11 +22,6 @@ from ray.exceptions import RayActorError
from ray.types import ObjectRef
from ray.util.ml_utils.util import find_free_port
if TYPE_CHECKING:
from ray.data import Dataset
from ray.data.dataset_pipeline import DatasetPipeline
RayDataset = Union["Dataset", "DatasetPipeline"]
T = TypeVar("T")
logger = logging.getLogger(__name__)