Mirror of https://github.com/vale981/ray
[AIR/Train] Make Dataset ingest configurable (#24066)
Refactors Dataset splitting to make it less hacky and to address the TODO. Also makes Dataset ingest in general configurable for Ray Train. This is an internal-only change for now, but it sets the stage for the proposed ingest API. Customizable ingest for GBDT Trainers is out of scope for this PR.
This commit is contained in:
parent
abba263f4e
commit
629424f489
8 changed files with 194 additions and 97 deletions
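At a high level, the refactor replaces the `_NO-SHARD` key-mangling hack with an internal `_RayDatasetSpec` that pairs the dataset(s) with a split function. A minimal sketch of the new internal wiring, assuming a local Ray cluster; the dataset names and contents are illustrative, not from the diff:

    import ray
    from ray.train.impl.dataset_spec import _RayDatasetSpec

    # Illustrative datasets; with no dataset_split_fn provided, the spec falls
    # back to splitting every dataset equally across the workers, using the
    # worker actor handles as locality hints.
    datasets = {"train": ray.data.range(8), "valid": ray.data.range(4)}
    spec = _RayDatasetSpec(dataset_or_dict=datasets)
    # Ray Train later calls spec.get_dataset_shards(worker_actor_handles) to
    # produce one {"train": ..., "valid": ...} shard dict per worker.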
@@ -1,10 +1,11 @@
 import inspect
 import logging
 from pathlib import Path
-from typing import Dict, Callable, Optional, Union
+from typing import Dict, Callable, List, Optional, Union, TYPE_CHECKING

 import ray
 from ray import tune
+from ray.actor import ActorHandle
 from ray.ml.constants import TRAIN_DATASET_KEY, PREPROCESSOR_KEY
 from ray.ml.trainer import Trainer
 from ray.ml.config import ScalingConfig, RunConfig
@@ -14,9 +15,13 @@ from ray.ml.checkpoint import Checkpoint
 from ray.train import BackendConfig, TrainingIterator
 from ray.train.backend import BackendExecutor
 from ray.train.checkpoint import TuneCheckpointManager
+from ray.train.impl.dataset_spec import _RayDatasetSpec
 from ray.train.utils import construct_train_func
 from ray.util.annotations import DeveloperAPI

+if TYPE_CHECKING:
+    from ray.data import Dataset
+
 logger = logging.getLogger(__name__)


@@ -292,19 +297,9 @@ class DataParallelTrainer(Trainer):
         else:
             resume_checkpoint_dict = None

-        # Tell Ray Train to only shard the train dataset and not the other datasets.
-        # This is purely an implementation detail and users do not need to know about
-        # this.
-        # TODO(amog): Refactor this to remove hack and make this more modular.
-        # TrainingIterator should accept a generic custom_ingest_func that contains
-        # the logic for how to split the Datasets.
-        updated_dataset_dict = {}
-        for key, value in self.datasets.items():
-            if key == TRAIN_DATASET_KEY:
-                updated_dataset_dict[key] = value
-            else:
-                # Ray Train will strip out the added string before exposing to users.
-                updated_dataset_dict[key + "_NO-SHARD"] = value
+        dataset_spec = _RayDatasetSpec(
+            dataset_or_dict=self.datasets, dataset_split_fn=_default_dataset_split_fn
+        )

         # TODO(amog): Have TrainingIterator also accept a checkpoint ObjectRef instead
         # of just a Dict.
@@ -312,7 +307,7 @@ class DataParallelTrainer(Trainer):
            backend_executor=backend_executor,
            backend_config=self.backend_config,
            train_func=train_loop_per_worker,
-           dataset=updated_dataset_dict if len(updated_dataset_dict) > 0 else None,
+           dataset_spec=dataset_spec,
            checkpoint_manager=checkpoint_manager,
            checkpoint=resume_checkpoint_dict,
            checkpoint_strategy=None,
@@ -348,3 +343,39 @@ class _DataParallelCheckpointManager(TuneCheckpointManager):
     @property
     def latest_checkpoint_dir(self) -> Optional[Path]:
         raise NotImplementedError
+
+
+def _default_dataset_split_fn(
+    dataset_dict: Dict[str, "Dataset"], training_worker_handles: List[ActorHandle]
+) -> List[Dict[str, "Dataset"]]:
+    """Defines splitting logic of Datasets passed into ``DataParallelTrainer``.
+
+    By default only training dataset will be split. All other datasets will not be
+    split and passed through directly to the training workers. This is because
+    validation implementation is often done on just the rank 0 worker.
+
+    Args:
+        dataset_dict: A dictionary of Datasets.
+        training_worker_handles: The actor handles of the training workers to use for
+            locality hints.
+
+    Returns:
+        A list of dataset dictionaries for each training worker.
+    """
+    dataset_dict_splits = [{} for _ in range(len(training_worker_handles))]
+
+    for key, dataset in dataset_dict.items():
+        if key == TRAIN_DATASET_KEY:
+            dataset_splits = dataset.split(
+                len(training_worker_handles),
+                equal=True,
+                locality_hints=training_worker_handles,
+            )
+        else:
+            # Only shard the training dataset.
+            dataset_splits = [dataset] * len(training_worker_handles)
+
+        for i in range(len(dataset_splits)):
+            dataset_dict_splits[i][key] = dataset_splits[i]
+
+    return dataset_dict_splits
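For intuition, here is a minimal sketch of what this default split produces, assuming two workers and illustrative dataset names; it mirrors the function above minus the locality hints, which require real worker ActorHandles:

    import ray

    train_ds = ray.data.range(8)
    valid_ds = ray.data.range(4)

    # The "train" dataset is split equally, one shard per worker...
    train_shards = train_ds.split(2, equal=True)
    # ...while every other dataset is passed through to each worker unchanged.
    per_worker = [{"train": train_shards[i], "valid": valid_ds} for i in range(2)]
    assert len(per_worker) == 2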
@@ -257,9 +257,9 @@ class Trainer(abc.ABC):
         If the ``Trainer`` has both a datasets dict and
         a preprocessor, the datasets dict contains a training dataset (denoted by
         the "train" key), and the preprocessor has not yet
-        been fit, then it will be fit on the train.
+        been fit, then it will be fit on the train dataset.

-        Then, the Trainer's datasets will be transformed by the preprocessor.
+        Then, all Trainer's datasets will be transformed by the preprocessor.

         The transformed datasets will be set back in the ``self.datasets`` attribute
         of the Trainer to be used when overriding ``training_loop``.
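To make the documented behavior concrete, a hedged sketch assuming the ray.ml StandardScaler preprocessor with a fit/transform interface and an illustrative column name; the preprocessor is fit only on the "train" dataset and then transforms all datasets:

    import ray
    from ray.ml.preprocessors import StandardScaler

    datasets = {
        "train": ray.data.from_items([{"x": float(i)} for i in range(8)]),
        "valid": ray.data.from_items([{"x": float(i)} for i in range(4)]),
    }
    prep = StandardScaler(columns=["x"])
    prep.fit(datasets["train"])  # fit on the train dataset only
    datasets = {k: prep.transform(v) for k, v in datasets.items()}  # transform all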
@@ -1,7 +1,7 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Callable, TypeVar, List, Optional, Dict, Union, Type, Tuple
+from typing import Callable, TypeVar, List, Optional, Dict, Type, Tuple

 import ray
 from ray.exceptions import RayActorError
@@ -12,9 +12,10 @@ from ray.train.constants import (
     TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
 )
+from ray.train.impl.dataset_spec import _RayDatasetSpec
 from ray.train.session import TrainingResult
 from ray.train.session import init_session, get_session, shutdown_session
-from ray.train.utils import RayDataset, check_for_failure, Singleton
+from ray.train.utils import check_for_failure, Singleton
 from ray.train.worker_group import WorkerGroup
 from ray.util.annotations import DeveloperAPI
 from ray.util.placement_group import get_current_placement_group, remove_placement_group
@@ -314,42 +315,10 @@ class BackendExecutor:
            ip_dict[node_ip] += 1
        return rank_mapping

-    def _get_dataset_shards(self, dataset_or_dict):
-
-        if dataset_or_dict is None:
-            # Return None for each shard.
-            return [None] * len(self.worker_group)
-
-        def split_dataset(dataset_or_pipeline):
-            actors = [worker.actor for worker in self.worker_group.workers]
-            return dataset_or_pipeline.split(
-                len(self.worker_group), equal=True, locality_hints=actors
-            )
-
-        if isinstance(dataset_or_dict, dict):
-            # Return a smaller dict for each shard.
-            dataset_shards = [{} for _ in range(len(self.worker_group))]
-            # TODO(amog): Update Backend to accept a generic function with logic on
-            # how to split dataset, instead of having to support _NO-SHARD in key.
-            for key, dataset in dataset_or_dict.items():
-                if "_NO-SHARD" in key:
-                    # Do not shard this dataset.
-                    split_datasets = [dataset] * len(self.worker_group)
-                    key = key.replace("_NO-SHARD", "")
-                else:
-                    split_datasets = split_dataset(dataset)
-                assert len(split_datasets) == len(self.worker_group)
-                for i in range(len(split_datasets)):
-                    dataset_shards[i][key] = split_datasets[i]
-            return dataset_shards
-        else:
-            # return a smaller RayDataset for each shard.
-            return split_dataset(dataset_or_dict)
-
    def start_training(
        self,
        train_func: Callable[[], T],
-       dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
+       dataset_spec: _RayDatasetSpec,
        checkpoint: Optional[Dict] = None,
    ) -> None:
        """Executes a training function on all workers in a separate thread.
@@ -357,17 +326,11 @@ class BackendExecutor:
        ``finish_training`` should be called after this.

        Args:
-           train_func (Callable): The training function to run on each worker.
-           dataset (Optional[Union[Dataset, DatasetPipeline]])
-               Distributed Ray Dataset or DatasetPipeline to pass into
-               worker, which can be accessed from the training function via
-               ``train.get_dataset_shard()``. Sharding will automatically be
-               handled by the Trainer. Multiple Datasets can be passed in as
-               a ``Dict`` that maps each name key to a Dataset value,
-               and each Dataset can be accessed from the training function
-               by passing in a `dataset_name` argument to
-               ``train.get_dataset_shard()``.
-           checkpoint (Optional[Dict]): The checkpoint data that
+           train_func: The training function to run on each worker.
+           dataset_spec: A specification for the Ray Dataset to be
+               passed to the training workers, and the logic on how to shard the Ray
+               Dataset.
+           checkpoint: The checkpoint data that
                should be loaded onto each worker and accessed by the
                training function via ``train.load_checkpoint()``. If this
                is ``None`` then no checkpoint will be loaded.
@@ -406,7 +369,8 @@ class BackendExecutor:
            )

        if self.dataset_shards is None:
-           self.dataset_shards = self._get_dataset_shards(dataset)
+           actors = [worker.actor for worker in self.worker_group.workers]
+           self.dataset_shards = dataset_spec.get_dataset_shards(actors)

        local_rank_map = self._create_local_rank_map()

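The call pattern for the executor changes accordingly; a small sketch mirroring the updated tests further below (TestConfig is the no-op backend config defined in the test file):

    from ray.train.backend import BackendExecutor
    from ray.train.impl.dataset_spec import _RayDatasetSpec

    e = BackendExecutor(TestConfig(), num_workers=2)
    e.start()
    # An "empty" spec means no dataset shards are created for the workers.
    e.start_training(lambda: 1, dataset_spec=_RayDatasetSpec(dataset_or_dict=None))
    assert e.finish_training() == [1, 1]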
python/ray/train/impl/dataset_spec.py (new file, 93 lines added)
@@ -0,0 +1,93 @@
+from dataclasses import dataclass
+from typing import Optional, Union, Dict, Callable, List, TYPE_CHECKING
+
+from ray.actor import ActorHandle
+
+if TYPE_CHECKING:
+    from ray.data import Dataset, DatasetPipeline
+
+RayDataset = Union["Dataset", "DatasetPipeline"]
+
+
+@dataclass
+class _RayDatasetSpec:
+    """Configuration for Ray Datasets to pass to the training workers.
+
+    dataset_or_dict: An optional Ray Dataset (or DatasetPipeline) or a dictionary of
+        datasets to be sharded across all the training workers, which can be accessed
+        from the training function via ``train.get_dataset_shard()``. Multiple Datasets
+        can be passed in as a dictionary that maps each name key to a Dataset value,
+        and each Dataset can be accessed from the training function by passing in a
+        `dataset_name` argument to ``train.get_dataset_shard()``.
+    dataset_split_fn: An optional callable to specify how the provided ``dataset``
+        should be split across the training workers. It is expected to take in two
+        arguments. The first one is the ``dataset``, just as is passed in to the
+        ``_RayDatasetSpec``. The second argument is a list of the ActorHandles of the
+        training workers (to use as locality hints). The Callable is expected to
+        return a list of RayDatasets or a list of dictionaries of RayDatasets,
+        with the length of the list equal to the length of the list of actor handles.
+        If None is provided, the provided Ray Dataset(s) will be simply be split using
+        the actor handles as locality hints.
+
+    """
+
+    dataset_or_dict: Optional[Union[RayDataset, Dict[str, RayDataset]]]
+    dataset_split_fn: Optional[
+        Callable[
+            [Union[RayDataset, Dict[str, RayDataset]], List[ActorHandle]],
+            List[Union[RayDataset, Dict[str, RayDataset]]],
+        ]
+    ] = None
+
+    def _default_split_fn(
+        self, training_worker_handles: List[ActorHandle]
+    ) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
+        def split_dataset(dataset_or_pipeline):
+            return dataset_or_pipeline.split(
+                len(training_worker_handles),
+                equal=True,
+                locality_hints=training_worker_handles,
+            )
+
+        if isinstance(self.dataset_or_dict, dict):
+            # Return a smaller dict for each shard.
+            dataset_shards = [{} for _ in range(len(training_worker_handles))]
+            for key, dataset in self.dataset_or_dict.items():
+                split_datasets = split_dataset(dataset)
+                assert len(split_datasets) == len(training_worker_handles)
+                for i in range(len(split_datasets)):
+                    dataset_shards[i][key] = split_datasets[i]
+            return dataset_shards
+        else:
+            # return a smaller RayDataset for each shard.
+            return split_dataset(self.dataset_or_dict)
+
+    def get_dataset_shards(
+        self, training_worker_handles: List[ActorHandle]
+    ) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]:
+        """Returns Dataset splits based off the spec and the given training workers
+
+        Args:
+            training_worker_handles: A list of the training worker actor handles.
+
+        Returns:
+            A list of RayDataset shards or list of dictionaries of RayDataset shards,
+                one for each training worker.
+
+        """
+        if not self.dataset_or_dict:
+            return [None] * len(training_worker_handles)
+
+        if self.dataset_split_fn is None:
+            return self._default_split_fn(training_worker_handles)
+        else:
+            splits = self.dataset_split_fn(
+                self.dataset_or_dict, training_worker_handles
+            )
+            if not len(splits) == len(training_worker_handles):
+                raise RuntimeError(
+                    "The list of Datasets returned by the "
+                    f"`dataset_split_fn`: {len(splits)} does not match "
+                    f"the number of training workers: {len(training_worker_handles)}"
+                )
+            return splits
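A short sketch of how a custom split policy plugs in; the replicate-everywhere policy and the placeholder handles are illustrative, only the two-argument signature and the one-split-per-worker requirement come from the docstring above:

    import ray
    from ray.train.impl.dataset_spec import _RayDatasetSpec

    def replicate_to_all_workers(dataset, worker_handles):
        # One entry per training worker, as required by the spec.
        return [dataset] * len(worker_handles)

    spec = _RayDatasetSpec(
        dataset_or_dict=ray.data.range(8),
        dataset_split_fn=replicate_to_all_workers,
    )
    # With a custom split fn the handles are only forwarded to it (and used for
    # the length check), so placeholders stand in for real ActorHandles here.
    shards = spec.get_dataset_shards(["worker-0", "worker-1"])
    assert len(shards) == 2  # a length mismatch would raise RuntimeError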
@@ -25,7 +25,8 @@ from ray.train.constants import (
     RESULT_FETCH_TIMEOUT,
     SESSION_MISUSE_LOG_ONCE_KEY,
 )
-from ray.train.utils import PropagatingThread, RayDataset
+from ray.train.utils import PropagatingThread
+from ray.train.impl.dataset_spec import RayDataset
 from ray.util import PublicAPI, log_once


@@ -14,6 +14,7 @@ from ray.train.backend import (
     TrainingWorkerError,
 )
 from ray.train.backend import BackendConfig, BackendExecutor
+from ray.train.impl.dataset_spec import _RayDatasetSpec
 from ray.train.tensorflow import TensorflowConfig
 from ray.train.torch import TorchConfig
 from ray.train.constants import (
@@ -102,11 +103,14 @@ class TestBackend(Backend):
        pass


+EMPTY_RAY_DATASET_SPEC = _RayDatasetSpec(dataset_or_dict=None)
+
+
 def test_start(ray_start_2_cpus):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    with pytest.raises(InactiveWorkerGroupError):
-       e.start_training(lambda: 1)
+       e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    e.start()
    assert len(e.worker_group) == 2

@@ -127,7 +131,7 @@ def test_initialization_hook(ray_start_2_cpus):

        return os.getenv("TEST", "0")

-   e.start_training(check)
+   e.start_training(check, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    assert e.finish_training() == ["1", "1"]


@@ -138,7 +142,7 @@ def test_shutdown(ray_start_2_cpus):
    assert len(e.worker_group) == 2
    e.shutdown()
    with pytest.raises(InactiveWorkerGroupError):
-       e.start_training(lambda: 1)
+       e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)


 def test_train(ray_start_2_cpus):
@@ -146,7 +150,7 @@ def test_train(ray_start_2_cpus):
    e = BackendExecutor(config, num_workers=2)
    e.start()

-   e.start_training(lambda: 1)
+   e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    assert e.finish_training() == [1, 1]


@@ -158,7 +162,7 @@ def test_local_ranks(ray_start_2_cpus):
    def train_func():
        return train.local_rank()

-   e.start_training(train_func)
+   e.start_training(train_func, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    assert set(e.finish_training()) == {0, 1}


@@ -176,10 +180,10 @@ def test_train_failure(ray_start_2_cpus):
    with pytest.raises(TrainBackendError):
        e.finish_training()

-   e.start_training(lambda: 1)
+   e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)

    with pytest.raises(TrainBackendError):
-       e.start_training(lambda: 2)
+       e.start_training(lambda: 2, dataset_spec=EMPTY_RAY_DATASET_SPEC)

    assert e.finish_training() == [1, 1]

@@ -195,7 +199,7 @@ def test_worker_failure(ray_start_2_cpus):
    new_execute_func = gen_execute_special(train_fail)
    with patch.object(WorkerGroup, "execute_async", new_execute_func):
        with pytest.raises(TrainingWorkerError):
-           e.start_training(lambda: 1)
+           e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)
            e.finish_training()


@@ -209,7 +213,7 @@ def test_mismatch_checkpoint_report(ray_start_2_cpus):
    config = TestConfig()
    e = BackendExecutor(config, num_workers=2)
    e.start()
-   e.start_training(train_func)
+   e.start_training(train_func, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    with pytest.raises(RuntimeError):
        e.get_next_results()

@@ -226,7 +230,7 @@ def test_tensorflow_start(ray_start_2_cpus):

        return json.loads(os.environ["TF_CONFIG"])

-   e.start_training(get_tf_config)
+   e.start_training(get_tf_config, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    results = e.finish_training()
    assert len(results) == num_workers

@@ -251,12 +255,12 @@ def test_torch_start_shutdown(ray_start_2_cpus, init_method):
            and torch.distributed.get_world_size() == 2
        )

-   e.start_training(check_process_group)
+   e.start_training(check_process_group, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    assert all(e.finish_training())

    e._backend.on_shutdown(e.worker_group, e._backend_config)

-   e.start_training(check_process_group)
+   e.start_training(check_process_group, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    assert not any(e.finish_training())


@@ -282,7 +286,7 @@ def test_cuda_visible_devices(ray_2_node_2_gpu, worker_results):
        config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=1
    )
    e.start()
-   e.start_training(get_resources)
+   e.start_training(get_resources, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    results = e.finish_training()
    results.sort()
    assert results == expected_results
@@ -314,7 +318,7 @@ def test_cuda_visible_devices_fractional(ray_2_node_2_gpu, worker_results):
        config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=0.5
    )
    e.start()
-   e.start_training(get_resources)
+   e.start_training(get_resources, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    results = e.finish_training()
    results.sort()
    assert results == expected_results
@@ -342,7 +346,7 @@ def test_cuda_visible_devices_multiple(ray_2_node_4_gpu, worker_results):
        config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=2
    )
    e.start()
-   e.start_training(get_resources)
+   e.start_training(get_resources, dataset_spec=EMPTY_RAY_DATASET_SPEC)
    results = e.finish_training()
    results.sort()
    assert results == expected_results
@@ -393,7 +397,7 @@ def test_placement_group_parent(ray_4_node_4_cpu, placement_group_capture_child_
        config = TestConfig()
        e = BackendExecutor(config, num_workers=2)
        e.start()
-       e.start_training(train_func)
+       e.start_training(train_func, dataset_spec=EMPTY_RAY_DATASET_SPEC)
        return e.finish_training()

    results_future = test.options(
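Any new BackendExecutor test written after this change would presumably follow the same pattern; an illustrative example, not part of the diff:

    def test_noop_round_trip(ray_start_2_cpus):
        e = BackendExecutor(TestConfig(), num_workers=2)
        e.start()
        # Every start_training call now needs a dataset_spec, even without data.
        e.start_training(lambda: "ok", dataset_spec=EMPTY_RAY_DATASET_SPEC)
        assert e.finish_training() == ["ok", "ok"]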
@@ -16,8 +16,12 @@ from ray.train.backend import (
     TrainingWorkerError,
 )
 from ray.train.callbacks.callback import TrainingCallback
+from ray.train.impl.dataset_spec import RayDataset, _RayDatasetSpec
 from ray.train.session import TrainingResultType
-from ray.train.utils import RayDataset, construct_train_func, ActorWrapper
+from ray.train.utils import (
+    construct_train_func,
+    ActorWrapper,
+)
 from ray.train.checkpoint import (
     CheckpointStrategy,
     TuneCheckpointManager,
@@ -336,12 +340,14 @@ class Trainer:

        train_func = construct_train_func(train_func, config)

+       dataset_spec = _RayDatasetSpec(dataset_or_dict=dataset)
+
        try:
            iterator = TrainingIterator(
                backend_executor=self._backend_executor,
                backend_config=self._backend_config,
                train_func=train_func,
-               dataset=dataset,
+               dataset_spec=dataset_spec,
                checkpoint_manager=self.checkpoint_manager,
                checkpoint=checkpoint,
                checkpoint_strategy=checkpoint_strategy,
@@ -413,12 +419,14 @@ class Trainer:

        train_func = construct_train_func(train_func, config)

+       dataset_spec = _RayDatasetSpec(dataset_or_dict=dataset)
+
        return TrainingIterator(
            backend_executor=self._backend_executor,
            backend_config=self._backend_config,
            train_func=train_func,
            run_dir=self.latest_run_dir,
-           dataset=dataset,
+           dataset_spec=dataset_spec,
            checkpoint_manager=self.checkpoint_manager,
            checkpoint=checkpoint,
            checkpoint_strategy=checkpoint_strategy,
@@ -513,7 +521,7 @@ class Trainer:
        """Creates a Tune ``Trainable`` from the input training function.

        Args:
-           func (Callable): The function that should be executed on each
+           train_func (Callable): The function that should be executed on each
                training worker.
            dataset (Optional[Union[RayDataset, Dict[str, RayDataset]]]):
                Distributed Ray p:ref:`Dataset <dataset-api>` or
@@ -650,7 +658,7 @@ class TrainingIterator:
        backend_executor: Union[BackendExecutor, ActorWrapper],
        backend_config: BackendConfig,
        train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
-       dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]],
+       dataset_spec: _RayDatasetSpec,
        checkpoint_manager: CheckpointManager,
        checkpoint: Optional[Union[Dict, str, Path]],
        checkpoint_strategy: Optional[CheckpointStrategy],
@@ -659,14 +667,14 @@ class TrainingIterator:
        self._backend_executor = backend_executor
        self._backend = backend_config.backend_cls()
        self._train_func = train_func
-       self._dataset = dataset
+       self._dataset_spec = dataset_spec
        self._run_dir = run_dir
        self._checkpoint_manager = checkpoint_manager
        self._checkpoint_strategy = checkpoint_strategy
        self._start_training(
            train_func=train_func,
            run_dir=run_dir,
-           dataset=dataset,
+           dataset_spec=self._dataset_spec,
            checkpoint=checkpoint,
            checkpoint_strategy=checkpoint_strategy,
        )
@@ -681,7 +689,7 @@ class TrainingIterator:
        self,
        train_func,
        run_dir,
-       dataset,
+       dataset_spec,
        checkpoint,
        checkpoint_strategy,
        latest_checkpoint_id=None,
@@ -694,7 +702,9 @@ class TrainingIterator:
        checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint)
        self._run_with_error_handling(
            lambda: self._backend_executor.start_training(
-               train_func=train_func, dataset=dataset, checkpoint=checkpoint_dict
+               train_func=train_func,
+               dataset_spec=dataset_spec,
+               checkpoint=checkpoint_dict,
            )
        )

@@ -713,7 +723,7 @@ class TrainingIterator:
            self._start_training(
                self._train_func,
                self._run_dir,
-               self._dataset,
+               self._dataset_spec,
                self._checkpoint_manager.latest_checkpoint,
                self._checkpoint_strategy,
                latest_checkpoint_id=self._checkpoint_manager.latest_checkpoint_id,
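None of this changes the public entry point: callers still pass `dataset=...` to `Trainer.run()`, and the wrapping into a `_RayDatasetSpec` now happens internally. A minimal sketch of that unchanged usage, assuming the torch backend is installed and an illustrative dataset:

    import ray
    from ray import train
    from ray.train import Trainer

    def train_func():
        shard = train.get_dataset_shard()  # each worker sees its own shard
        return shard.count()

    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    results = trainer.run(train_func, dataset=ray.data.range(8))
    trainer.shutdown()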
@@ -10,7 +10,6 @@ from typing import (
     Dict,
     List,
     Any,
-    TYPE_CHECKING,
     Union,
     Callable,
     TypeVar,
@@ -23,11 +22,6 @@ from ray.exceptions import RayActorError
 from ray.types import ObjectRef
 from ray.util.ml_utils.util import find_free_port

-if TYPE_CHECKING:
-    from ray.data import Dataset
-    from ray.data.dataset_pipeline import DatasetPipeline
-
-RayDataset = Union["Dataset", "DatasetPipeline"]
 T = TypeVar("T")

 logger = logging.getLogger(__name__)