Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[train] wrap BackendExecutor in ray.remote() (#20123)
* [train] wrap BackendExecutor in ray.remote()
* wip
* fix trainer tests
* move CheckpointManager to Trainer
* [tune] move force_on_current_node to ml_utils
* fix import
* force on head node
* init ray
* split test files
* update example
* move tests to ray client
* address comments
* move comment
* address comments
This commit is contained in: parent 9e2bd508d7, commit e22632dabc
13 changed files with 422 additions and 452 deletions
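The central change shows up in the trainer.py hunks below: the Trainer no longer holds a BackendExecutor in-process, but creates it as a zero-CPU Ray actor pinned to the node the driver runs on, and drives it through ray.get(actor.method.remote()). A minimal sketch of that pattern, using only names that appear in this diff (my_backend_config is a placeholder for a real BackendConfig instance):

import ray
from ray.train.backend import BackendConfig, BackendExecutor
from ray.util.ml_utils.node import force_on_current_node

if not ray.is_initialized():
    ray.init()

# Run the executor as an actor; num_cpus=0 keeps it from competing with
# training workers for CPU resources.
remote_executor = ray.remote(num_cpus=0)(BackendExecutor)

# Pin the actor to the current (head) node, as Trainer.__init__ does below.
remote_executor = force_on_current_node(remote_executor)

my_backend_config = BackendConfig()  # placeholder backend configuration
executor_actor = remote_executor.remote(
    backend_config=my_backend_config, num_workers=2)

# Former in-process calls become .remote() calls resolved with ray.get().
ray.get(executor_actor.start.remote())
ray.get(executor_actor.shutdown.remote())

Because the executor now lives in its own actor, state the driver needs synchronously (notably the CheckpointManager) moves into the Trainer and TrainingIterator, as the later hunks show.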
|
@@ -373,7 +373,10 @@ def train_func(config):
        # Checkpoint model.
        if is_distributed:
            train.save_checkpoint(model_state_dict=net.module.state_dict())
            import copy
            model_copy = copy.deepcopy(net.module)
            train.save_checkpoint(
                model_state_dict=model_copy.cpu().state_dict())
        else:
            torch.save(net.state_dict(), f"models/model-epoch-{epoch}.torch")

@@ -386,7 +389,7 @@ def train_func(config):
    if is_distributed:
        if train.world_rank() == 0:
            return net.module
            return net.module.cpu()
        else:
            return None
    else:
|
||||
|
|
|
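The documentation example above now deep-copies the DDP-wrapped module and moves the copy to CPU before checkpointing, so the state dict can be serialized off a GPU worker without disturbing the model that is still training. A hedged sketch of the same idea as a helper (the function name is hypothetical; net is assumed to be wrapped in DistributedDataParallel):

import copy

from ray import train


def checkpoint_distributed_model(net):
    # Copy first so .cpu() does not move the live, still-training model.
    model_copy = copy.deepcopy(net.module)
    train.save_checkpoint(model_state_dict=model_copy.cpu().state_dict())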
@@ -93,6 +93,13 @@ py_test(
    deps = [":train_lib"]
)

py_test(
    name = "test_examples",
    size = "large",
    srcs = ["tests/test_examples.py"],
    tags = ["team:ml", "exclusive"],
    deps = [":train_lib"]
)

py_test(
    name = "test_gpu",
|
||||
|
|
|
@ -8,13 +8,10 @@ from typing import Callable, TypeVar, List, Optional, Dict, Union, Type, Tuple
|
|||
import ray
|
||||
from ray.exceptions import RayActorError
|
||||
from ray.ray_constants import env_integer
|
||||
from ray.train.checkpoint import CheckpointManager, CheckpointStrategy, \
|
||||
TuneCheckpointManager
|
||||
from ray.train.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
|
||||
TUNE_INSTALLED, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, \
|
||||
TRAIN_ENABLE_WORKER_SPREAD_ENV, \
|
||||
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV
|
||||
from ray.train.session import TrainingResultType, TrainingResult
|
||||
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, \
|
||||
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, TRAIN_ENABLE_WORKER_SPREAD_ENV
|
||||
from ray.train.session import TrainingResult
|
||||
from ray.train.session import init_session, get_session, shutdown_session
|
||||
from ray.train.utils import RayDataset
|
||||
from ray.train.utils import check_for_failure
|
||||
|
@ -22,11 +19,6 @@ from ray.train.worker_group import WorkerGroup
|
|||
from ray.util.placement_group import get_current_placement_group, \
|
||||
remove_placement_group
|
||||
|
||||
if TUNE_INSTALLED:
|
||||
from ray import tune
|
||||
else:
|
||||
tune = None
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -67,15 +59,6 @@ class BackendExecutor:
|
|||
and ``num_gpus_per_worker``.
|
||||
max_retries (int): Number of retries when Ray actors fail.
|
||||
Defaults to 3. Set to -1 for unlimited retries.
|
||||
|
||||
Attributes:
|
||||
latest_checkpoint_dir (Optional[Path]): Path to the file directory for
|
||||
the checkpoints from the latest run. Configured through
|
||||
``start_training``
|
||||
best_checkpoint_path (Optional[Path]): Path to the best persisted
|
||||
checkpoint from the latest run.
|
||||
latest_checkpoint (Optional[Dict]): The latest saved checkpoint. This
|
||||
checkpoint may not be saved to disk.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -99,16 +82,9 @@ class BackendExecutor:
|
|||
self._initialization_hook = None
|
||||
self._placement_group = None
|
||||
|
||||
if tune is not None and tune.is_session_enabled():
|
||||
self.checkpoint_manager = TuneCheckpointManager()
|
||||
else:
|
||||
self.checkpoint_manager = CheckpointManager()
|
||||
|
||||
self.worker_group = InactiveWorkerGroup()
|
||||
self.dataset_shards = None
|
||||
|
||||
self.checkpoint_manager.on_init()
|
||||
|
||||
def start(self,
|
||||
initialization_hook: Optional[Callable[[], None]] = None,
|
||||
train_cls: Optional[Type] = None,
|
||||
|
@ -304,10 +280,7 @@ class BackendExecutor:
|
|||
train_func: Callable[[], T],
|
||||
run_dir: Path,
|
||||
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
|
||||
checkpoint: Optional[Union[Dict, str, Path]] = None,
|
||||
checkpoint_strategy: Optional[CheckpointStrategy] = None,
|
||||
latest_checkpoint_id: Optional[int] = None,
|
||||
) -> None:
|
||||
checkpoint: Optional[Dict] = None) -> None:
|
||||
"""Executes a training function on all workers in a separate thread.
|
||||
|
||||
``finish_training`` should be called after this.
|
||||
|
@ -324,22 +297,11 @@ class BackendExecutor:
|
|||
and each Dataset can be accessed from the training function
|
||||
by passing in a `dataset_name` argument to
|
||||
``train.get_dataset_shard()``.
|
||||
checkpoint (Optional[Dict|str|Path]): The checkpoint data that
|
||||
checkpoint (Optional[Dict]): The checkpoint data that
|
||||
should be loaded onto each worker and accessed by the
|
||||
training function via ``train.load_checkpoint()``. If this is a
|
||||
``str`` or ``Path`` then the value is expected to be a path
|
||||
to a file that contains a serialized checkpoint dict. If this
|
||||
training function via ``train.load_checkpoint()``. If this
|
||||
is ``None`` then no checkpoint will be loaded.
|
||||
checkpoint_strategy (Optional[CheckpointStrategy]): The
|
||||
configurations for saving checkpoints.
|
||||
latest_checkpoint_id (Optional[int]): The checkpoint id of the
|
||||
most recently saved checkpoint.
|
||||
"""
|
||||
self.checkpoint_manager.on_start_training(
|
||||
checkpoint_strategy=checkpoint_strategy,
|
||||
run_dir=run_dir,
|
||||
latest_checkpoint_id=latest_checkpoint_id)
|
||||
|
||||
use_detailed_autofilled_metrics = env_integer(
|
||||
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)
|
||||
|
||||
|
@ -365,8 +327,6 @@ class BackendExecutor:
|
|||
if self.dataset_shards is None:
|
||||
self.dataset_shards = self._get_dataset_shards(dataset)
|
||||
|
||||
checkpoint_dict = self.checkpoint_manager._load_checkpoint(checkpoint)
|
||||
|
||||
local_rank_map = self._create_local_rank_map()
|
||||
|
||||
futures = []
|
||||
|
@ -379,7 +339,7 @@ class BackendExecutor:
|
|||
local_rank=local_rank_map[index],
|
||||
train_func=train_func,
|
||||
dataset_shard=self.dataset_shards[index],
|
||||
checkpoint=checkpoint_dict))
|
||||
checkpoint=checkpoint))
|
||||
|
||||
self.get_with_failure_handling(futures)
|
||||
|
||||
|
@ -390,7 +350,7 @@ class BackendExecutor:
|
|||
|
||||
self.worker_group.execute_async(train_async)
|
||||
|
||||
def _get_next_results(self) -> Optional[List[TrainingResult]]:
|
||||
def get_next_results(self) -> Optional[List[TrainingResult]]:
|
||||
"""Fetches the next ``TrainingResult`` from each worker.
|
||||
|
||||
Each ``TrainingResult`` is expected to correspond to the same step from
|
||||
|
@ -403,24 +363,15 @@ class BackendExecutor:
|
|||
"""
|
||||
|
||||
def get_next():
|
||||
# Get the session for this worker.
|
||||
try:
|
||||
session = get_session()
|
||||
except ValueError:
|
||||
# Session is not initialized yet.
|
||||
raise TrainBackendError("`fetch_next_result` has been called "
|
||||
"before `start_training`. Please call "
|
||||
"`start_training` before "
|
||||
"`fetch_next_result`.")
|
||||
|
||||
session = _get_session("get_next_results")
|
||||
try:
|
||||
result = session.get_next()
|
||||
except RuntimeError:
|
||||
# Training thread has not been started yet.
|
||||
raise TrainBackendError("`fetch_next_result` has been called "
|
||||
raise TrainBackendError("`get_next_results` has been called "
|
||||
"before `start_training`. Please call "
|
||||
"`start_training` before "
|
||||
"`fetch_next_result`.")
|
||||
"`get_next_results`.")
|
||||
|
||||
return result
|
||||
|
||||
|
@ -451,37 +402,21 @@ class BackendExecutor:
|
|||
"each worker.")
|
||||
return results
|
||||
|
||||
def fetch_next_result(self) -> Optional[List[Dict]]:
|
||||
"""Fetch next results produced by ``train.report()`` from each worker.
|
||||
def pause_reporting(self):
|
||||
""" Disable workers from enqueuing results from `train.report()`.
|
||||
|
||||
Assumes ``start_training`` has already been called.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries of values passed to ``train.report()`` from
|
||||
each worker. Each item corresponds to an intermediate result
|
||||
a single worker. If there are no more items to fetch,
|
||||
returns None.
|
||||
Note: Already reported results may still be enqueued at this point,
|
||||
and should be handled appropriately.
|
||||
"""
|
||||
|
||||
while True:
|
||||
results = self._get_next_results()
|
||||
if results is None:
|
||||
return None
|
||||
first_result = results[0]
|
||||
result_type = first_result.type
|
||||
if result_type is TrainingResultType.REPORT:
|
||||
result_data = [r.data for r in results]
|
||||
return result_data
|
||||
elif result_type is TrainingResultType.CHECKPOINT:
|
||||
self.checkpoint_manager._process_checkpoint(results)
|
||||
# Iterate until next REPORT call or training has finished.
|
||||
else:
|
||||
raise TrainBackendError(f"Unexpected result type: "
|
||||
f"{result_type}. "
|
||||
f"Expected one of "
|
||||
f"{[type in TrainingResultType]}")
|
||||
def pause_session_reporting():
|
||||
session = _get_session("pause_reporting")
|
||||
return session.pause_reporting()
|
||||
|
||||
def finish_training(self) -> List[T]:
|
||||
futures = self.worker_group.execute_async(pause_session_reporting)
|
||||
self.get_with_failure_handling(futures)
|
||||
|
||||
def finish_training(self):
|
||||
"""Finish training and return final results. Propagate any exceptions.
|
||||
|
||||
Blocks until training is finished on all workers.
|
||||
|
@ -493,30 +428,8 @@ class BackendExecutor:
|
|||
Each item corresponds to the return value from a single worker.
|
||||
"""
|
||||
|
||||
def pause_reporting():
|
||||
# Get the session for this worker.
|
||||
try:
|
||||
session = get_session()
|
||||
except ValueError:
|
||||
# Session is not initialized yet.
|
||||
raise TrainBackendError("`finish_training` has been called "
|
||||
"before `start_training`. Please call "
|
||||
"`start_training` before "
|
||||
"`finish_training`.")
|
||||
|
||||
return session.pause_reporting()
|
||||
|
||||
def end_training():
|
||||
# Get the session for this worker.
|
||||
try:
|
||||
session = get_session()
|
||||
except ValueError:
|
||||
# Session is not initialized yet.
|
||||
raise TrainBackendError("`finish_training` has been called "
|
||||
"before `start_training`. Please call "
|
||||
"`start_training` before "
|
||||
"`finish_training`.")
|
||||
|
||||
session = _get_session("finish_training")
|
||||
try:
|
||||
# session.finish raises any Exceptions from training.
|
||||
output = session.finish()
|
||||
|
@ -527,23 +440,6 @@ class BackendExecutor:
|
|||
|
||||
return output
|
||||
|
||||
# Disable workers from enqueuing results from `train.report()`.
|
||||
# Results will not be processed during the execution of `finish`.
|
||||
# Note: Reported results may still be enqueued at this point,
|
||||
# and should be handled appropriately.
|
||||
futures = self.worker_group.execute_async(pause_reporting)
|
||||
self.get_with_failure_handling(futures)
|
||||
|
||||
# Finish up processing checkpoints. Reporting has been disabled.
|
||||
while True:
|
||||
results = self._get_next_results()
|
||||
if results is None:
|
||||
break
|
||||
result_type = results[0].type
|
||||
# Process checkpoints and ignore other result types.
|
||||
if result_type is TrainingResultType.CHECKPOINT:
|
||||
self.checkpoint_manager._process_checkpoint(results)
|
||||
|
||||
futures = self.worker_group.execute_async(end_training)
|
||||
results = self.get_with_failure_handling(futures)
|
||||
return results
|
||||
|
@ -594,37 +490,9 @@ class BackendExecutor:
|
|||
|
||||
self.dataset_shards = None
|
||||
|
||||
@property
|
||||
def is_started(self):
|
||||
return not isinstance(self.worker_group, InactiveWorkerGroup)
|
||||
|
||||
@property
|
||||
def latest_checkpoint_dir(self) -> Optional[Path]:
|
||||
"""Path to the latest checkpoint directory."""
|
||||
return self.checkpoint_manager.latest_checkpoint_dir
|
||||
|
||||
@property
|
||||
def best_checkpoint_path(self) -> Optional[Path]:
|
||||
"""Path to the best persisted checkpoint."""
|
||||
return self.checkpoint_manager.best_checkpoint_path
|
||||
|
||||
@property
|
||||
def latest_checkpoint_id(self) -> Optional[int]:
|
||||
"""The checkpoint id of most recently saved checkpoint.
|
||||
|
||||
If no checkpoint has been saved yet, then return None.
|
||||
"""
|
||||
checkpoint_id = self.checkpoint_manager._latest_checkpoint_id
|
||||
if checkpoint_id == 0:
|
||||
return None
|
||||
else:
|
||||
return checkpoint_id
|
||||
|
||||
@property
|
||||
def latest_checkpoint(self) -> Optional[Dict]:
|
||||
"""Latest checkpoint object."""
|
||||
return self.checkpoint_manager.latest_checkpoint
|
||||
|
||||
def _restart(self):
|
||||
self.worker_group.shutdown()
|
||||
if self._initialization_hook is not None:
|
||||
|
@ -646,6 +514,12 @@ class BackendExecutor:
|
|||
"`max_retries` arg in your `Trainer`.") \
|
||||
from None
|
||||
|
||||
def get_worker_group(self):
|
||||
return self.worker_group
|
||||
|
||||
def _get_num_failures(self):
|
||||
return self._num_failures
|
||||
|
||||
|
||||
class Backend(metaclass=abc.ABCMeta):
|
||||
"""Metaclass for distributed communication backend.
|
||||
|
@ -702,3 +576,15 @@ class InactiveWorkerGroup():
|
|||
|
||||
def __len__(self):
|
||||
raise InactiveWorkerGroupError()
|
||||
|
||||
|
||||
def _get_session(method_name: str):
|
||||
try:
|
||||
# Get the session for this worker.
|
||||
return get_session()
|
||||
except ValueError:
|
||||
# Session is not initialized yet.
|
||||
raise TrainBackendError(f"`{method_name}` has been called "
|
||||
"before `start_training`. Please call "
|
||||
"`start_training` before "
|
||||
f"`{method_name}`.")
|
||||
|
|
|
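All worker-session lookups in BackendExecutor now go through the single _get_session(method_name) helper above, so every public method reports the same "call start_training first" failure. A small sketch of the driver-visible behavior, modeled on test_train_failure further down (TestConfig is the stub backend config defined in those tests):

import pytest

from ray.train.backend import BackendExecutor, TrainBackendError

e = BackendExecutor(TestConfig(), num_workers=2)  # TestConfig from the tests
e.start()

# Any result-fetching call made before start_training() raises
# TrainBackendError via the shared _get_session() helper.
with pytest.raises(TrainBackendError):
    e.get_next_results()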
@@ -274,6 +274,18 @@ class CheckpointManager:
        else:
            return None

    @property
    def latest_checkpoint_id(self) -> Optional[int]:
        """The checkpoint id of most recently saved checkpoint.

        If no checkpoint has been saved yet, then return None.
        """
        checkpoint_id = self._latest_checkpoint_id
        if checkpoint_id == 0:
            return None
        else:
            return checkpoint_id


class TuneCheckpointManager(CheckpointManager):
    def create_logdir(self, log_dir: Optional[Union[str, Path]]):
|
||||
|
|
|
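With the executor running remotely, checkpoint bookkeeping stays on the driver: CheckpointManager gains the latest_checkpoint_id property above, and the Trainer now constructs and owns the manager (see the trainer.py hunks later in this diff). A small usage sketch, assuming the CheckpointManager API shown here:

from ray.train.checkpoint import CheckpointManager

manager = CheckpointManager()
manager.on_init()

# Nothing has been saved yet, so the property reports None rather than 0.
assert manager.latest_checkpoint_id is None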
@@ -49,7 +49,9 @@ def validate_epoch(dataloader, model, loss_fn, device):
            pred = model(X)
            loss += loss_fn(pred, y).item()
    loss /= num_batches
    result = {"model": model.state_dict(), "loss": loss}
    import copy
    model_copy = copy.deepcopy(model)
    result = {"model": model_copy.cpu().state_dict(), "loss": loss}
    return result
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import math
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
@ -160,7 +159,10 @@ def test_train_failure(ray_start_2_cpus, tmp_path):
|
|||
e.start()
|
||||
|
||||
with pytest.raises(TrainBackendError):
|
||||
e.fetch_next_result()
|
||||
e.get_next_results()
|
||||
|
||||
with pytest.raises(TrainBackendError):
|
||||
e.pause_reporting()
|
||||
|
||||
with pytest.raises(TrainBackendError):
|
||||
e.finish_training()
|
||||
|
@ -188,91 +190,6 @@ def test_worker_failure(ray_start_2_cpus, tmp_path):
|
|||
e.finish_training()
|
||||
|
||||
|
||||
def test_no_exhaust(ray_start_2_cpus, tmp_path):
|
||||
"""Tests if training can finish even if queue is not exhausted."""
|
||||
|
||||
def train_func():
|
||||
for _ in range(2):
|
||||
train.report(loss=1)
|
||||
return 2
|
||||
|
||||
config = TestConfig()
|
||||
e = BackendExecutor(config, num_workers=2)
|
||||
e.start()
|
||||
|
||||
e.start_training(train_func, run_dir=tmp_path)
|
||||
output = e.finish_training()
|
||||
|
||||
assert output == [2, 2]
|
||||
|
||||
|
||||
def test_checkpoint(ray_start_2_cpus, tmp_path):
|
||||
def train_func():
|
||||
for i in range(2):
|
||||
train.save_checkpoint(epoch=i)
|
||||
|
||||
config = TestConfig()
|
||||
e = BackendExecutor(config, num_workers=1)
|
||||
e.start()
|
||||
|
||||
e.start_training(train_func, run_dir=tmp_path)
|
||||
e.finish_training()
|
||||
|
||||
assert e.latest_checkpoint is not None
|
||||
assert e.latest_checkpoint["epoch"] == 1
|
||||
|
||||
|
||||
def test_persisted_checkpoint(ray_start_2_cpus, tmp_path):
|
||||
def train_func():
|
||||
for i in range(2):
|
||||
train.save_checkpoint(epoch=i)
|
||||
time.sleep(1)
|
||||
|
||||
config = TestConfig()
|
||||
e = BackendExecutor(config)
|
||||
e.start()
|
||||
e.start_training(train_func, run_dir=tmp_path)
|
||||
e.finish_training()
|
||||
|
||||
assert e.latest_checkpoint_id == 2
|
||||
assert e.latest_checkpoint is not None
|
||||
assert e.latest_checkpoint["epoch"] == 1
|
||||
assert e.best_checkpoint_path is not None
|
||||
|
||||
assert os.path.exists(e.best_checkpoint_path)
|
||||
|
||||
def validate():
|
||||
checkpoint = train.load_checkpoint()
|
||||
assert checkpoint is not None
|
||||
assert checkpoint["epoch"] == 1
|
||||
|
||||
e2 = BackendExecutor(config)
|
||||
e2.start()
|
||||
e2.start_training(
|
||||
validate, checkpoint=e.best_checkpoint_path, run_dir=tmp_path)
|
||||
e2.finish_training()
|
||||
|
||||
|
||||
def test_persisted_checkpoint_id(ray_start_2_cpus, tmp_path):
|
||||
def train_func():
|
||||
for i in range(2):
|
||||
train.save_checkpoint(epoch=i)
|
||||
time.sleep(1)
|
||||
|
||||
config = TestConfig()
|
||||
e = BackendExecutor(config)
|
||||
e.start()
|
||||
e.start_training(train_func, run_dir=tmp_path, latest_checkpoint_id=100)
|
||||
e.finish_training()
|
||||
|
||||
assert e.latest_checkpoint_id == 102
|
||||
assert e.latest_checkpoint is not None
|
||||
assert e.latest_checkpoint["epoch"] == 1
|
||||
assert e.best_checkpoint_path is not None
|
||||
|
||||
assert os.path.exists(e.best_checkpoint_path)
|
||||
|
||||
|
||||
def test_mismatch_checkpoint_report(ray_start_2_cpus, tmp_path):
|
||||
def train_func():
|
||||
if (train.world_rank()) == 0:
|
||||
|
@ -285,7 +202,7 @@ def test_mismatch_checkpoint_report(ray_start_2_cpus, tmp_path):
|
|||
e.start()
|
||||
e.start_training(train_func, run_dir=tmp_path)
|
||||
with pytest.raises(RuntimeError):
|
||||
e.finish_training()
|
||||
e.get_next_results()
|
||||
|
||||
|
||||
def test_tensorflow_start(ray_start_2_cpus, tmp_path):
|
||||
|
|
python/ray/train/tests/test_examples.py (new file, 145 lines)
|
@ -0,0 +1,145 @@
|
|||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.train import Trainer
|
||||
from ray.train.examples.horovod.horovod_example import train_func as \
|
||||
horovod_torch_train_func, HorovodTrainClass
|
||||
from ray.train.examples.tensorflow_mnist_example import train_func as \
|
||||
tensorflow_mnist_train_func
|
||||
from ray.train.examples.tensorflow_quick_start import train_func as \
|
||||
tf_quick_start_train_func
|
||||
from ray.train.examples.torch_quick_start import train_func as \
|
||||
torch_quick_start_train_func
|
||||
from ray.train.examples.train_fashion_mnist_example import train_func \
|
||||
as fashion_mnist_train_func
|
||||
from ray.train.examples.train_linear_example import train_func as \
|
||||
linear_train_func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_2_cpus():
|
||||
address_info = ray.init(num_cpus=2)
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2])
|
||||
def test_tensorflow_mnist(ray_start_2_cpus, num_workers):
|
||||
num_workers = num_workers
|
||||
epochs = 3
|
||||
|
||||
trainer = Trainer("tensorflow", num_workers=num_workers)
|
||||
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
|
||||
trainer.start()
|
||||
results = trainer.run(tensorflow_mnist_train_func, config)
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_workers
|
||||
result = results[0]
|
||||
|
||||
loss = result["loss"]
|
||||
assert len(loss) == epochs
|
||||
assert loss[-1] < loss[0]
|
||||
|
||||
accuracy = result["accuracy"]
|
||||
assert len(accuracy) == epochs
|
||||
assert accuracy[-1] > accuracy[0]
|
||||
|
||||
|
||||
def test_tf_non_distributed(ray_start_2_cpus):
|
||||
"""Make sure Ray Train works without TF MultiWorkerMirroredStrategy."""
|
||||
|
||||
trainer = Trainer(backend="torch", num_workers=1)
|
||||
trainer.start()
|
||||
trainer.run(tf_quick_start_train_func)
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2])
|
||||
def test_torch_linear(ray_start_2_cpus, num_workers):
|
||||
num_workers = num_workers
|
||||
epochs = 3
|
||||
|
||||
trainer = Trainer("torch", num_workers=num_workers)
|
||||
config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs}
|
||||
trainer.start()
|
||||
results = trainer.run(linear_train_func, config)
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_workers
|
||||
|
||||
for result in results:
|
||||
assert len(result) == epochs
|
||||
assert result[-1]["loss"] < result[0]["loss"]
|
||||
|
||||
|
||||
def test_torch_fashion_mnist(ray_start_2_cpus):
|
||||
num_workers = 2
|
||||
epochs = 3
|
||||
|
||||
trainer = Trainer("torch", num_workers=num_workers)
|
||||
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
|
||||
trainer.start()
|
||||
results = trainer.run(fashion_mnist_train_func, config)
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_workers
|
||||
|
||||
for result in results:
|
||||
assert len(result) == epochs
|
||||
assert result[-1] < result[0]
|
||||
|
||||
|
||||
def test_torch_non_distributed(ray_start_2_cpus):
|
||||
"""Make sure Ray Train works without torch DDP."""
|
||||
|
||||
trainer = Trainer(backend="torch", num_workers=1)
|
||||
trainer.start()
|
||||
trainer.run(torch_quick_start_train_func)
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_horovod_torch_mnist(ray_start_2_cpus):
|
||||
num_workers = 2
|
||||
num_epochs = 2
|
||||
trainer = Trainer("horovod", num_workers)
|
||||
trainer.start()
|
||||
results = trainer.run(
|
||||
horovod_torch_train_func,
|
||||
config={
|
||||
"num_epochs": num_epochs,
|
||||
"lr": 1e-3
|
||||
})
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_workers
|
||||
for worker_result in results:
|
||||
assert len(worker_result) == num_epochs
|
||||
assert worker_result[num_epochs - 1] < worker_result[0]
|
||||
|
||||
|
||||
def test_horovod_torch_mnist_stateful(ray_start_2_cpus):
|
||||
num_workers = 2
|
||||
num_epochs = 2
|
||||
trainer = Trainer("horovod", num_workers)
|
||||
workers = trainer.to_worker_group(
|
||||
HorovodTrainClass, config={
|
||||
"num_epochs": num_epochs,
|
||||
"lr": 1e-3
|
||||
})
|
||||
results = []
|
||||
for epoch in range(num_epochs):
|
||||
results.append(ray.get([w.train.remote(epoch=epoch) for w in workers]))
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_epochs
|
||||
for i in range(num_workers):
|
||||
assert results[num_epochs - 1][i] < results[0][i]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", "-x", __file__]))
|
|
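Each example test in the new file exercises the same driver-side workflow, which is the user-facing API this change keeps stable while the executor moves behind an actor. A condensed sketch (configuration values follow test_torch_linear above):

from ray.train import Trainer
from ray.train.examples.train_linear_example import train_func as linear_train_func

trainer = Trainer("torch", num_workers=2)
trainer.start()
# One entry per worker; each entry holds that worker's per-epoch results.
results = trainer.run(
    linear_train_func,
    config={"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": 3})
trainer.shutdown()

assert len(results) == 2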
@ -16,19 +16,6 @@ from ray.train.torch import TorchConfig
|
|||
from ray.train.tensorflow import TensorflowConfig
|
||||
from ray.train.horovod import HorovodConfig
|
||||
from ray.train.callbacks.callback import TrainingCallback
|
||||
from ray.train.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
|
||||
from ray.train.examples.horovod.horovod_example import train_func as \
|
||||
horovod_torch_train_func, HorovodTrainClass
|
||||
from ray.train.examples.tensorflow_mnist_example import train_func as \
|
||||
tensorflow_mnist_train_func
|
||||
from ray.train.examples.train_fashion_mnist_example import train_func \
|
||||
as fashion_mnist_train_func
|
||||
from ray.train.examples.train_linear_example import train_func as \
|
||||
linear_train_func
|
||||
from ray.train.examples.torch_quick_start import train_func as \
|
||||
torch_quick_start_train_func
|
||||
from ray.train.examples.tensorflow_quick_start import train_func as \
|
||||
tf_quick_start_train_func
|
||||
from ray.train.worker_group import WorkerGroup
|
||||
|
||||
|
||||
|
@ -129,10 +116,11 @@ def gen_new_backend_executor(special_f):
|
|||
|
||||
|
||||
class KillCallback(TrainingCallback):
|
||||
def __init__(self, fail_on, worker_group):
|
||||
def __init__(self, fail_on, trainer):
|
||||
self.counter = 0
|
||||
self.fail_on = fail_on
|
||||
self.worker_group = worker_group
|
||||
self.worker_group = ray.get(
|
||||
trainer._backend_executor_actor.get_worker_group.remote())
|
||||
|
||||
def handle_result(self, results):
|
||||
print(results)
|
||||
|
@ -333,6 +321,24 @@ def test_run_iterator_error(ray_start_2_cpus):
|
|||
assert iterator.is_finished()
|
||||
|
||||
|
||||
def test_no_exhaust(ray_start_2_cpus, tmp_path):
|
||||
"""Tests if training can finish even if queue is not exhausted."""
|
||||
|
||||
def train_func():
|
||||
for _ in range(2):
|
||||
train.report(loss=1)
|
||||
return 2
|
||||
|
||||
config = TestConfig()
|
||||
trainer = Trainer(config, num_workers=2)
|
||||
trainer.start()
|
||||
|
||||
iterator = trainer.run_iterator(train_func)
|
||||
output = iterator.get_final_results(force=True)
|
||||
|
||||
assert output == [2, 2]
|
||||
|
||||
|
||||
def test_checkpoint(ray_start_2_cpus):
|
||||
config = TestConfig()
|
||||
|
||||
|
@ -556,82 +562,6 @@ def test_world_rank(ray_start_2_cpus):
|
|||
assert set(results) == {0, 1}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2])
|
||||
def test_tensorflow_mnist(ray_start_2_cpus, num_workers):
|
||||
num_workers = num_workers
|
||||
epochs = 3
|
||||
|
||||
trainer = Trainer("tensorflow", num_workers=num_workers)
|
||||
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
|
||||
trainer.start()
|
||||
results = trainer.run(tensorflow_mnist_train_func, config)
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_workers
|
||||
result = results[0]
|
||||
|
||||
loss = result["loss"]
|
||||
assert len(loss) == epochs
|
||||
assert loss[-1] < loss[0]
|
||||
|
||||
accuracy = result["accuracy"]
|
||||
assert len(accuracy) == epochs
|
||||
assert accuracy[-1] > accuracy[0]
|
||||
|
||||
|
||||
def test_tf_non_distributed(ray_start_2_cpus):
|
||||
"""Make sure Ray Train works without TF MultiWorkerMirroredStrategy."""
|
||||
|
||||
trainer = Trainer(backend="torch", num_workers=1)
|
||||
trainer.start()
|
||||
trainer.run(tf_quick_start_train_func)
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2])
|
||||
def test_torch_linear(ray_start_2_cpus, num_workers):
|
||||
num_workers = num_workers
|
||||
epochs = 3
|
||||
|
||||
trainer = Trainer("torch", num_workers=num_workers)
|
||||
config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs}
|
||||
trainer.start()
|
||||
results = trainer.run(linear_train_func, config)
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_workers
|
||||
|
||||
for result in results:
|
||||
assert len(result) == epochs
|
||||
assert result[-1]["loss"] < result[0]["loss"]
|
||||
|
||||
|
||||
def test_torch_fashion_mnist(ray_start_2_cpus):
|
||||
num_workers = 2
|
||||
epochs = 3
|
||||
|
||||
trainer = Trainer("torch", num_workers=num_workers)
|
||||
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
|
||||
trainer.start()
|
||||
results = trainer.run(fashion_mnist_train_func, config)
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_workers
|
||||
|
||||
for result in results:
|
||||
assert len(result) == epochs
|
||||
assert result[-1] < result[0]
|
||||
|
||||
|
||||
def test_torch_non_distributed(ray_start_2_cpus):
|
||||
"""Make sure Ray Train works without torch DDP."""
|
||||
|
||||
trainer = Trainer(backend="torch", num_workers=1)
|
||||
trainer.start()
|
||||
trainer.run(torch_quick_start_train_func)
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_horovod_simple(ray_start_2_cpus):
|
||||
def simple_fn():
|
||||
hvd_torch.init()
|
||||
|
@ -646,44 +576,6 @@ def test_horovod_simple(ray_start_2_cpus):
|
|||
assert result == list(range(num_workers))
|
||||
|
||||
|
||||
def test_horovod_torch_mnist(ray_start_2_cpus):
|
||||
num_workers = 2
|
||||
num_epochs = 2
|
||||
trainer = Trainer("horovod", num_workers)
|
||||
trainer.start()
|
||||
results = trainer.run(
|
||||
horovod_torch_train_func,
|
||||
config={
|
||||
"num_epochs": num_epochs,
|
||||
"lr": 1e-3
|
||||
})
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_workers
|
||||
for worker_result in results:
|
||||
assert len(worker_result) == num_epochs
|
||||
assert worker_result[num_epochs - 1] < worker_result[0]
|
||||
|
||||
|
||||
def test_horovod_torch_mnist_stateful(ray_start_2_cpus):
|
||||
num_workers = 2
|
||||
num_epochs = 2
|
||||
trainer = Trainer("horovod", num_workers)
|
||||
workers = trainer.to_worker_group(
|
||||
HorovodTrainClass, config={
|
||||
"num_epochs": num_epochs,
|
||||
"lr": 1e-3
|
||||
})
|
||||
results = []
|
||||
for epoch in range(num_epochs):
|
||||
results.append(ray.get([w.train.remote(epoch=epoch) for w in workers]))
|
||||
trainer.shutdown()
|
||||
|
||||
assert len(results) == num_epochs
|
||||
for i in range(num_workers):
|
||||
assert results[num_epochs - 1][i] < results[0][i]
|
||||
|
||||
|
||||
def test_init_failure(ray_start_2_cpus):
|
||||
with pytest.raises(TypeError):
|
||||
Trainer(5, num_workers=2)
|
||||
|
@ -809,23 +701,27 @@ def test_worker_failure_local_rank(ray_start_2_cpus):
|
|||
def test_worker_start_failure(ray_start_2_cpus):
|
||||
test_config = TestConfig()
|
||||
|
||||
trainer = Trainer(test_config, num_workers=2)
|
||||
|
||||
restart = trainer._executor._restart
|
||||
|
||||
def init_hook():
|
||||
pass
|
||||
|
||||
def init_hook_fail():
|
||||
ray.actor.exit_actor()
|
||||
|
||||
def restart_patched(self):
|
||||
self._initialization_hook = init_hook
|
||||
restart()
|
||||
class TestBackendExecutor(BackendExecutor):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
with patch.object(BackendExecutor, "_restart", restart_patched):
|
||||
def _restart(self):
|
||||
self._initialization_hook = init_hook
|
||||
super()._restart()
|
||||
|
||||
with patch.object(ray.train.trainer, "BackendExecutor",
|
||||
TestBackendExecutor):
|
||||
trainer = Trainer(test_config, num_workers=2)
|
||||
trainer.start(initialization_hook=init_hook_fail)
|
||||
assert len(trainer._executor.worker_group) == 2
|
||||
assert len(
|
||||
ray.get(trainer._backend_executor_actor.get_worker_group.remote())
|
||||
) == 2
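Because the executor is now created inside Trainer via ray.remote, tests can no longer monkey-patch a live instance; instead they subclass BackendExecutor and patch the symbol that ray.train.trainer resolves, as test_worker_start_failure above does. A stripped-down sketch of that pattern (FlakyBackendExecutor and its _restart body are illustrative only; test_config stands in for the stub config used by the tests):

from unittest.mock import patch

import ray.train.trainer
from ray.train import Trainer
from ray.train.backend import BackendExecutor


class FlakyBackendExecutor(BackendExecutor):
    def _restart(self):
        # Swap in a no-op initialization hook before deferring to the
        # real restart logic.
        self._initialization_hook = None
        super()._restart()


with patch.object(ray.train.trainer, "BackendExecutor", FlakyBackendExecutor):
    trainer = Trainer(test_config, num_workers=2)
    trainer.start()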
|
||||
|
||||
|
||||
def test_max_failures(ray_start_2_cpus):
|
||||
|
@ -840,7 +736,8 @@ def test_max_failures(ray_start_2_cpus):
|
|||
iterator = trainer.run_iterator(train_func)
|
||||
with pytest.raises(RuntimeError):
|
||||
iterator.get_final_results(force=True)
|
||||
assert iterator._executor._num_failures == 3
|
||||
assert ray.get(
|
||||
iterator._backend_executor_actor._get_num_failures.remote()) == 3
|
||||
|
||||
|
||||
def test_start_max_failures(ray_start_2_cpus):
|
||||
|
@ -874,8 +771,7 @@ def test_worker_kill(ray_start_2_cpus, backend):
|
|||
train.report(loss=1, iter=i)
|
||||
|
||||
trainer.start()
|
||||
kill_callback = KillCallback(
|
||||
fail_on=0, worker_group=trainer._executor.worker_group)
|
||||
kill_callback = KillCallback(fail_on=0, trainer=trainer)
|
||||
trainer.run(train_func, callbacks=[kill_callback])
|
||||
# Run 1: iter=0, counter=1, Successful
|
||||
# Run 2: iter=1, counter=1, Unsuccessful, starts training from beginning
|
||||
|
@ -886,8 +782,7 @@ def test_worker_kill(ray_start_2_cpus, backend):
|
|||
trainer.shutdown()
|
||||
trainer.start()
|
||||
|
||||
kill_callback = KillCallback(
|
||||
fail_on=1, worker_group=trainer._executor.worker_group)
|
||||
kill_callback = KillCallback(fail_on=1, trainer=trainer)
|
||||
trainer.run(train_func, callbacks=[kill_callback])
|
||||
# Run 1: iter=0, counter=1, Successful
|
||||
# Run 2: iter=1, counter=2, Successful
|
||||
|
@ -919,8 +814,7 @@ def test_worker_kill_checkpoint(ray_start_2_cpus):
|
|||
|
||||
trainer = Trainer(test_config, num_workers=2)
|
||||
trainer.start()
|
||||
kill_callback = KillCallback(
|
||||
fail_on=0, worker_group=trainer._executor.worker_group)
|
||||
kill_callback = KillCallback(fail_on=0, trainer=trainer)
|
||||
|
||||
trainer.run(train_func, callbacks=[kill_callback])
|
||||
|
||||
|
@ -936,8 +830,7 @@ def test_worker_kill_checkpoint(ray_start_2_cpus):
|
|||
trainer.shutdown()
|
||||
trainer.start()
|
||||
|
||||
kill_callback = KillCallback(
|
||||
fail_on=1, worker_group=trainer._executor.worker_group)
|
||||
kill_callback = KillCallback(fail_on=1, trainer=trainer)
|
||||
trainer.run(train_func, callbacks=[kill_callback])
|
||||
# Run 1: epoch=0, counter=1, Successful
|
||||
# *Checkpoint saved*
|
||||
|
@ -1131,11 +1024,10 @@ def test_dataset_pipeline_shuffle(ray_start_4_cpus):
|
|||
|
||||
def test_dataset_fault_tolerance(ray_start_4_cpus):
|
||||
dataset = ray.data.range(10)
|
||||
dataset_splits = dataset.split(n=2, equal=True)
|
||||
test_config = TestConfig()
|
||||
|
||||
def train_func():
|
||||
return 1
|
||||
return train.get_dataset_shard()
|
||||
|
||||
def train_actor_failure():
|
||||
import sys
|
||||
|
@ -1143,16 +1035,23 @@ def test_dataset_fault_tolerance(ray_start_4_cpus):
|
|||
|
||||
new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
|
||||
|
||||
class SingleGetDatasetShardsBackendExecutor(new_backend_executor_cls):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._has_called_get_dataset_shards = False
|
||||
|
||||
def _get_dataset_shards(self, dataset_or_dict):
|
||||
if self._has_called_get_dataset_shards:
|
||||
raise Exception
|
||||
self._has_called_get_dataset_shards = True
|
||||
return super()._get_dataset_shards(dataset_or_dict)
|
||||
|
||||
with patch.object(ray.train.trainer, "BackendExecutor",
|
||||
new_backend_executor_cls):
|
||||
with patch.object(
|
||||
new_backend_executor_cls,
|
||||
"_get_dataset_shards",
|
||||
return_value=dataset_splits) as mock_method:
|
||||
trainer = Trainer(test_config, num_workers=2)
|
||||
trainer.start()
|
||||
trainer.run(train_func, dataset=dataset)
|
||||
mock_method.assert_called_once()
|
||||
SingleGetDatasetShardsBackendExecutor):
|
||||
trainer = Trainer(test_config, num_workers=2)
|
||||
trainer.start()
|
||||
trainer.run(train_func, dataset=dataset)
|
||||
# No exception is raised by _get_dataset_shards
|
||||
|
||||
|
||||
@pytest.mark.parametrize("resource", ["CPU", "GPU", "extra"])
|
||||
|
@ -1180,10 +1079,18 @@ def test_resources(ray_start_4_cpus_4_gpus_4_extra, resource, num_requested):
|
|||
|
||||
|
||||
def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
|
||||
class CudaTestBackend(TestBackend):
|
||||
share_cuda_visible_devices = True
|
||||
|
||||
class CudaTestConfig(TestConfig):
|
||||
@property
|
||||
def backend_cls(self):
|
||||
return CudaTestBackend
|
||||
|
||||
# GPUs should not be requested if `use_gpu` is False.
|
||||
with pytest.raises(ValueError):
|
||||
Trainer(
|
||||
TestConfig(),
|
||||
CudaTestConfig(),
|
||||
num_workers=2,
|
||||
use_gpu=False,
|
||||
resources_per_worker={"GPU": 1})
|
||||
|
@ -1191,7 +1098,7 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
|
|||
# GPUs should not be set to 0 if `use_gpu` is True.
|
||||
with pytest.raises(ValueError):
|
||||
Trainer(
|
||||
TestConfig(),
|
||||
CudaTestConfig(),
|
||||
num_workers=2,
|
||||
use_gpu=True,
|
||||
resources_per_worker={"GPU": 0})
|
||||
|
@ -1199,17 +1106,15 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
|
|||
def get_resources():
|
||||
return os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
|
||||
os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1"
|
||||
|
||||
# 0 GPUs will be requested and should not raise an error.
|
||||
trainer = Trainer(TestConfig(), num_workers=2, use_gpu=False)
|
||||
trainer = Trainer(CudaTestConfig(), num_workers=2, use_gpu=False)
|
||||
trainer.start()
|
||||
result = trainer.run(get_resources)
|
||||
assert result == ["", ""]
|
||||
trainer.shutdown()
|
||||
|
||||
# 1 GPU will be requested and should not raise an error.
|
||||
trainer = Trainer(TestConfig(), num_workers=2, use_gpu=True)
|
||||
trainer = Trainer(CudaTestConfig(), num_workers=2, use_gpu=True)
|
||||
trainer.start()
|
||||
result = trainer.run(get_resources)
|
||||
assert result == ["0,1", "0,1"]
|
||||
|
@ -1217,7 +1122,7 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
|
|||
|
||||
# Partial GPUs should not raise an error.
|
||||
trainer = Trainer(
|
||||
TestConfig(),
|
||||
CudaTestConfig(),
|
||||
num_workers=2,
|
||||
use_gpu=True,
|
||||
resources_per_worker={"GPU": 0.1})
|
||||
|
@ -1228,7 +1133,7 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
|
|||
|
||||
# Multiple GPUs should not raise an error.
|
||||
trainer = Trainer(
|
||||
TestConfig(),
|
||||
CudaTestConfig(),
|
||||
num_workers=2,
|
||||
use_gpu=True,
|
||||
resources_per_worker={"GPU": 2})
|
||||
|
|
|
@ -6,19 +6,24 @@ from pathlib import Path
|
|||
from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, \
|
||||
Type
|
||||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.train.backend import BackendConfig, BackendExecutor, \
|
||||
InactiveWorkerGroupError, TrainBackendError, TrainingWorkerError
|
||||
|
||||
from ray.train.callbacks.callback import TrainingCallback
|
||||
from ray.train.session import TrainingResultType
|
||||
from ray.train.utils import RayDataset
|
||||
from ray.train.checkpoint import CheckpointStrategy
|
||||
from ray.train.checkpoint import CheckpointStrategy, TuneCheckpointManager, \
|
||||
CheckpointManager
|
||||
from ray.train.constants import TUNE_INSTALLED, DEFAULT_RESULTS_DIR, \
|
||||
TUNE_CHECKPOINT_FILE_NAME
|
||||
|
||||
# Ray Train should be usable even if Tune is not installed.
|
||||
from ray.train.utils import construct_path
|
||||
from ray.train.worker_group import WorkerGroup
|
||||
from ray.util.ml_utils.node import force_on_current_node, \
|
||||
get_current_node_resource_key
|
||||
|
||||
if TUNE_INSTALLED:
|
||||
from ray import tune
|
||||
|
@ -141,7 +146,14 @@ class Trainer:
|
|||
"request a positive number of `GPU` in "
|
||||
"`resources_per_worker.")
|
||||
|
||||
self._executor = BackendExecutor(
|
||||
remote_executor = ray.remote(num_cpus=0)(BackendExecutor)
|
||||
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
# Assign BackendExecutor to head node.
|
||||
remote_executor = force_on_current_node(remote_executor)
|
||||
|
||||
self._backend_executor_actor = remote_executor.remote(
|
||||
backend_config=backend_config,
|
||||
num_workers=num_workers,
|
||||
num_cpus_per_worker=num_cpus,
|
||||
|
@ -149,6 +161,12 @@ class Trainer:
|
|||
additional_resources_per_worker=resources_per_worker,
|
||||
max_retries=max_retries)
|
||||
|
||||
if tune is not None and tune.is_session_enabled():
|
||||
self.checkpoint_manager = TuneCheckpointManager()
|
||||
else:
|
||||
self.checkpoint_manager = CheckpointManager()
|
||||
self.checkpoint_manager.on_init()
|
||||
|
||||
def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path:
|
||||
"""Create logdir for the Trainer."""
|
||||
# Create directory for logs.
|
||||
|
@ -196,7 +214,7 @@ class Trainer:
|
|||
initialization_hook (Optional[Callable]): The function to call on
|
||||
each worker when it is instantiated.
|
||||
"""
|
||||
self._executor.start(initialization_hook)
|
||||
ray.get(self._backend_executor_actor.start.remote(initialization_hook))
|
||||
|
||||
def run(self,
|
||||
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
|
||||
|
@ -255,9 +273,10 @@ class Trainer:
|
|||
|
||||
try:
|
||||
iterator = TrainingIterator(
|
||||
backend_executor=self._executor,
|
||||
backend_executor_actor=self._backend_executor_actor,
|
||||
train_func=train_func,
|
||||
dataset=dataset,
|
||||
checkpoint_manager=self.checkpoint_manager,
|
||||
checkpoint=checkpoint,
|
||||
checkpoint_strategy=checkpoint_strategy,
|
||||
run_dir=self.latest_run_dir,
|
||||
|
@ -329,10 +348,11 @@ class Trainer:
|
|||
train_func = self._get_train_func(train_func, config)
|
||||
|
||||
return TrainingIterator(
|
||||
backend_executor=self._executor,
|
||||
backend_executor_actor=self._backend_executor_actor,
|
||||
train_func=train_func,
|
||||
run_dir=self.latest_run_dir,
|
||||
dataset=dataset,
|
||||
checkpoint_manager=self.checkpoint_manager,
|
||||
checkpoint=checkpoint,
|
||||
checkpoint_strategy=checkpoint_strategy)
|
||||
|
||||
|
@ -384,7 +404,7 @@ class Trainer:
|
|||
``train.checkpoint()`` has not been called from ``train_func``within
|
||||
the most recent call to ``run``.
|
||||
"""
|
||||
return self._executor.latest_checkpoint_dir
|
||||
return self.checkpoint_manager.latest_checkpoint_dir
|
||||
|
||||
@property
|
||||
def best_checkpoint_path(self) -> Optional[Path]:
|
||||
|
@ -397,7 +417,7 @@ class Trainer:
|
|||
``train.checkpoint()`` has not been called from ``train_func`` within
|
||||
the most recent call to ``run``.
|
||||
"""
|
||||
return self._executor.best_checkpoint_path
|
||||
return self.checkpoint_manager.best_checkpoint_path
|
||||
|
||||
@property
|
||||
def latest_checkpoint(self) -> Optional[Dict]:
|
||||
|
@ -408,11 +428,11 @@ class Trainer:
|
|||
Returns ``None`` if ``run()`` has not been called or if
|
||||
``train.checkpoint()`` has not been called from ``train_func``.
|
||||
"""
|
||||
return self._executor.latest_checkpoint
|
||||
return self.checkpoint_manager.latest_checkpoint
|
||||
|
||||
def shutdown(self):
|
||||
"""Shuts down the training execution service."""
|
||||
self._executor.shutdown()
|
||||
ray.get(self._backend_executor_actor.shutdown.remote())
|
||||
|
||||
def to_tune_trainable(
|
||||
self,
|
||||
|
@ -442,7 +462,7 @@ class Trainer:
|
|||
raise ValueError("Tune is not installed. Please install ray["
|
||||
"tune] to use the Tune integration.")
|
||||
|
||||
if self._executor.is_started:
|
||||
if ray.get(self._backend_executor_actor.is_started.remote()):
|
||||
raise RuntimeError("The Trainer must not be active to use "
|
||||
"`to_tune_trainable`. Either shutdown the "
|
||||
"Trainer or don't start it in the first place.")
|
||||
|
@ -483,13 +503,18 @@ class Trainer:
|
|||
args, kwargs: Arguments to pass into the ``__init__`` of the
|
||||
provided ``train_cls``.
|
||||
"""
|
||||
if self._executor.is_started:
|
||||
if ray.get(self._backend_executor_actor.is_started.remote()):
|
||||
raise RuntimeError("The Trainer must not be active to use "
|
||||
"`to_worker_group`. Either shutdown the "
|
||||
"Trainer or don't start it in the first place.")
|
||||
self._executor.start(
|
||||
train_cls=train_cls, train_cls_args=args, train_cls_kwargs=kwargs)
|
||||
return TrainWorkerGroup(self._executor.worker_group)
|
||||
ray.get(
|
||||
self._backend_executor_actor.start.remote(
|
||||
train_cls=train_cls,
|
||||
train_cls_args=args,
|
||||
train_cls_kwargs=kwargs))
|
||||
worker_group = ray.get(
|
||||
self._backend_executor_actor.get_worker_group.remote())
|
||||
return TrainWorkerGroup(worker_group)
|
||||
|
||||
|
||||
class TrainWorkerGroup:
|
||||
|
@ -541,16 +566,18 @@ class TrainingIterator:
|
|||
"""An iterator over Train results. Returned by ``trainer.run_iterator``."""
|
||||
|
||||
def __init__(
|
||||
self, backend_executor: BackendExecutor,
|
||||
self, backend_executor_actor: ActorHandle,
|
||||
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
|
||||
run_dir: Path,
|
||||
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]],
|
||||
checkpoint: Optional[Dict],
|
||||
checkpoint_manager: CheckpointManager,
|
||||
checkpoint: Optional[Union[Dict, str, Path]],
|
||||
checkpoint_strategy: Optional[CheckpointStrategy]):
|
||||
self._executor = backend_executor
|
||||
self._backend_executor_actor = backend_executor_actor
|
||||
self._train_func = train_func
|
||||
self._dataset = dataset
|
||||
self._run_dir = run_dir
|
||||
self._checkpoint_manager = checkpoint_manager
|
||||
self._checkpoint_strategy = checkpoint_strategy
|
||||
self._start_training(
|
||||
train_func=train_func,
|
||||
|
@ -572,15 +599,18 @@ class TrainingIterator:
|
|||
checkpoint,
|
||||
checkpoint_strategy,
|
||||
latest_checkpoint_id=None):
|
||||
self._checkpoint_manager.on_start_training(
|
||||
checkpoint_strategy=checkpoint_strategy,
|
||||
run_dir=run_dir,
|
||||
latest_checkpoint_id=latest_checkpoint_id)
|
||||
checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint)
|
||||
self._run_with_error_handling(
|
||||
lambda: self._executor.start_training(
|
||||
lambda: ray.get(self._backend_executor_actor.start_training.remote(
|
||||
train_func=train_func,
|
||||
run_dir=run_dir,
|
||||
dataset=dataset,
|
||||
checkpoint=checkpoint,
|
||||
checkpoint_strategy=checkpoint_strategy,
|
||||
latest_checkpoint_id=latest_checkpoint_id
|
||||
)
|
||||
checkpoint=checkpoint_dict
|
||||
))
|
||||
)
|
||||
|
||||
def _run_with_error_handling(self, func: Callable):
|
||||
|
@ -592,9 +622,10 @@ class TrainingIterator:
|
|||
self._train_func,
|
||||
self._run_dir,
|
||||
self._dataset,
|
||||
self._executor.latest_checkpoint,
|
||||
self._checkpoint_manager.latest_checkpoint,
|
||||
self._checkpoint_strategy,
|
||||
latest_checkpoint_id=self._executor.latest_checkpoint_id)
|
||||
latest_checkpoint_id=self._checkpoint_manager.
|
||||
latest_checkpoint_id)
|
||||
return self._run_with_error_handling(func)
|
||||
except InactiveWorkerGroupError:
|
||||
raise RuntimeError(
|
||||
|
@ -611,19 +642,78 @@ class TrainingIterator:
|
|||
def __next__(self):
|
||||
if self.is_finished():
|
||||
raise StopIteration
|
||||
next_results = self._run_with_error_handling(
|
||||
self._executor.fetch_next_result)
|
||||
next_results = self._run_with_error_handling(self._fetch_next_result)
|
||||
if next_results is None:
|
||||
try:
|
||||
self._final_results = \
|
||||
self._run_with_error_handling(
|
||||
self._executor.finish_training)
|
||||
self._final_results = self._run_with_error_handling(
|
||||
self._finish_training)
|
||||
finally:
|
||||
self._finished_training = True
|
||||
raise StopIteration
|
||||
else:
|
||||
|
||||
return next_results
|
||||
|
||||
def _fetch_next_result(self) -> Optional[List[Dict]]:
|
||||
"""Fetch next results produced by ``train.report()`` from each worker.
|
||||
|
||||
Assumes ``start_training`` has already been called.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries of values passed to ``train.report()`` from
|
||||
each worker. Each item corresponds to an intermediate result
|
||||
a single worker. If there are no more items to fetch,
|
||||
returns None.
|
||||
"""
|
||||
|
||||
while True:
|
||||
results = ray.get(
|
||||
self._backend_executor_actor.get_next_results.remote())
|
||||
if results is None:
|
||||
return None
|
||||
first_result = results[0]
|
||||
result_type = first_result.type
|
||||
if result_type is TrainingResultType.REPORT:
|
||||
result_data = [r.data for r in results]
|
||||
return result_data
|
||||
elif result_type is TrainingResultType.CHECKPOINT:
|
||||
self._checkpoint_manager._process_checkpoint(results)
|
||||
# Iterate until next REPORT call or training has finished.
|
||||
else:
|
||||
raise TrainBackendError(f"Unexpected result type: "
|
||||
f"{result_type}. "
|
||||
f"Expected one of "
|
||||
f"{[type in TrainingResultType]}")
|
||||
|
||||
def _finish_checkpointing(self):
|
||||
while True:
|
||||
results = ray.get(
|
||||
self._backend_executor_actor.get_next_results.remote())
|
||||
if results is None:
|
||||
break
|
||||
result_type = results[0].type
|
||||
# Process checkpoints and ignore other result types.
|
||||
if result_type is TrainingResultType.CHECKPOINT:
|
||||
self._checkpoint_manager._process_checkpoint(results)
|
||||
|
||||
def _finish_training(self):
|
||||
"""Finish training and return final results. Propagate any exceptions.
|
||||
|
||||
Blocks until training is finished on all workers.
|
||||
|
||||
Assumes `start_training` has already been called.
|
||||
|
||||
Returns:
|
||||
A list of return values from calling ``train_func`` on each worker.
|
||||
Each item corresponds to the return value from a single worker.
|
||||
"""
|
||||
|
||||
ray.get(self._backend_executor_actor.pause_reporting.remote())
|
||||
# Finish up processing checkpoints. Reporting has been disabled.
|
||||
# Results will not be processed.
|
||||
self._finish_checkpointing()
|
||||
return ray.get(self._backend_executor_actor.finish_training.remote())
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return self._finished_training
|
||||
|
||||
|
@ -640,9 +730,8 @@ class TrainingIterator:
|
|||
assert self._final_results is None
|
||||
if force:
|
||||
try:
|
||||
self._final_results = \
|
||||
self._run_with_error_handling(
|
||||
self._executor.finish_training)
|
||||
self._final_results = self._run_with_error_handling(
|
||||
self._finish_training)
|
||||
finally:
|
||||
self._finished_training = True
|
||||
else:
|
||||
|
@ -697,15 +786,17 @@ def _create_tune_trainable(train_func, dataset, backend, num_workers, use_gpu,
|
|||
@classmethod
|
||||
def default_resource_request(cls,
|
||||
config: Dict) -> PlacementGroupFactory:
|
||||
head_bundle = [{"CPU": 1}] # driver
|
||||
node_resource_key = get_current_node_resource_key()
|
||||
trainer_bundle = [{"CPU": 1}]
|
||||
backend_executor_bundle = [{node_resource_key: 0.01}]
|
||||
worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
|
||||
worker_resources_extra = {} if resources_per_worker is None else\
|
||||
worker_resources_extra = {} if resources_per_worker is None else \
|
||||
resources_per_worker
|
||||
worker_bundles = [{
|
||||
**worker_resources,
|
||||
**worker_resources_extra
|
||||
} for _ in range(num_workers)]
|
||||
bundles = head_bundle + worker_bundles
|
||||
bundles = trainer_bundle + backend_executor_bundle + worker_bundles
|
||||
return PlacementGroupFactory(bundles, strategy="PACK")
|
||||
|
||||
return TrainTrainable
|
||||
|
|
|
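Because the BackendExecutor actor must also be schedulable when Train runs under Tune, the trainable's default_resource_request above now reserves a tiny bundle keyed on the current node's resource, alongside the trainer bundle and the per-worker bundles. A hedged sketch of the resulting request for num_workers=2 on CPU (the node resource key value is a placeholder; at runtime it comes from get_current_node_resource_key()):

from ray.tune.utils.placement_groups import PlacementGroupFactory

node_resource_key = "node:10.0.0.1"  # placeholder for illustration

bundles = (
    [{"CPU": 1}]                   # trainer (driver-side) bundle
    + [{node_resource_key: 0.01}]  # pins the BackendExecutor to that node
    + [{"CPU": 1, "GPU": 0} for _ in range(2)]  # one bundle per worker
)
pgf = PlacementGroupFactory(bundles, strategy="PACK")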
@@ -34,6 +34,7 @@
  driver_setup: train/driver_setup.sh

  run:
    use_connect: True
    timeout: 36000
    script: python train/train_tensorflow_mnist_test.py

@@ -43,6 +44,7 @@
  compute_template: train/compute_tpl.yaml

  run:
    use_connect: True
    timeout: 36000
    script: python train/train_torch_linear_test.py
||||
|
||||
|
|
|
@@ -1,6 +1,6 @@
base_image: "anyscale/ray-ml:nightly-py37-gpu"
env_vars:
  TRAIN_PLACEMENT_GROUP_TIMEOUT_S: 2000
  TRAIN_PLACEMENT_GROUP_TIMEOUT_S: "2000"

debian_packages:
  - curl
|
||||
|
|
|
@@ -9,7 +9,7 @@ if __name__ == "__main__":
    start = time.time()

    addr = os.environ.get("RAY_ADDRESS")
    job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test")
    job_name = os.environ.get("RAY_JOB_NAME", "train_tensorflow_mnist_test")

    if addr is not None and addr.startswith("anyscale://"):
        ray.init(address=addr, job_name=job_name)

@@ -23,7 +23,7 @@ if __name__ == "__main__":
        "time_taken": taken,
    }
    test_output_json = os.environ.get("TEST_OUTPUT_JSON",
                                      "/tmp/train_torc_linear_test.json")
                                      "/tmp/train_tensorflow_mnist_test.json")

    with open(test_output_json, "wt") as f:
        json.dump(result, f)
|
||||
|
|
|
@@ -10,7 +10,7 @@ if __name__ == "__main__":
    start = time.time()

    addr = os.environ.get("RAY_ADDRESS")
    job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test")
    job_name = os.environ.get("RAY_JOB_NAME", "train_torch_linear_test")

    if addr is not None and addr.startswith("anyscale://"):
        ray.init(address=addr, job_name=job_name)

@@ -22,7 +22,7 @@ if __name__ == "__main__":
    taken = time.time() - start
    result = {"time_taken": taken}
    test_output_json = os.environ.get("TEST_OUTPUT_JSON",
                                      "/tmp/train_torc_linear_test.json")
                                      "/tmp/train_torch_linear_test.json")

    with open(test_output_json, "wt") as f:
        json.dump(result, f)
|
||||
|
|