[train] wrap BackendExecutor in ray.remote() (#20123)

* [train] wrap BackendExecutor in ray.remote()

* wip

* fix trainer tests

* move CheckpointManager to Trainer

* [tune] move force_on_current_node to ml_utils

* fix import

* force on head node

* init ray

* split test files

* update example

* move tests to ray client

* address comments

* move comment

* address comments
matthewdeng 2021-11-13 15:30:44 -08:00 committed by GitHub
parent 9e2bd508d7
commit e22632dabc
13 changed files with 422 additions and 452 deletions
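
At a high level, this change turns the in-process BackendExecutor into a Ray actor pinned to the node the Trainer is created on, so every executor call becomes a remote call on an actor handle. Below is a minimal sketch of the pattern, simplified from the trainer.py changes in this diff (TorchConfig stands in for whatever concrete BackendConfig a real Trainer would use):

import ray
from ray.train.backend import BackendExecutor
from ray.train.torch import TorchConfig
from ray.util.ml_utils.node import force_on_current_node

if not ray.is_initialized():
    ray.init()

# Wrap the executor class in ray.remote() and pin the resulting actor to the
# current (head) node, mirroring what Trainer.__init__ now does.
remote_executor_cls = ray.remote(num_cpus=0)(BackendExecutor)
remote_executor_cls = force_on_current_node(remote_executor_cls)

executor_actor = remote_executor_cls.remote(
    backend_config=TorchConfig(), num_workers=2)

# Executor methods are now invoked through the actor handle.
ray.get(executor_actor.start.remote())

Pinning via force_on_current_node keeps the executor co-located with the driver, which is the head node when running outside of Tune.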

View file

@ -373,7 +373,10 @@ def train_func(config):
# Checkpoint model.
if is_distributed:
train.save_checkpoint(model_state_dict=net.module.state_dict())
import copy
model_copy = copy.deepcopy(net.module)
train.save_checkpoint(
model_state_dict=model_copy.cpu().state_dict())
else:
torch.save(net.state_dict(), f"models/model-epoch-{epoch}.torch")
@ -386,7 +389,7 @@ def train_func(config):
if is_distributed:
if train.world_rank() == 0:
return net.module
return net.module.cpu()
else:
return None
else:

View file

@ -93,6 +93,13 @@ py_test(
deps = [":train_lib"]
)
py_test(
name = "test_examples",
size = "large",
srcs = ["tests/test_examples.py"],
tags = ["team:ml", "exclusive"],
deps = [":train_lib"]
)
py_test(
name = "test_gpu",

View file

@ -8,13 +8,10 @@ from typing import Callable, TypeVar, List, Optional, Dict, Union, Type, Tuple
import ray
from ray.exceptions import RayActorError
from ray.ray_constants import env_integer
from ray.train.checkpoint import CheckpointManager, CheckpointStrategy, \
TuneCheckpointManager
from ray.train.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
TUNE_INSTALLED, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, \
TRAIN_ENABLE_WORKER_SPREAD_ENV, \
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV
from ray.train.session import TrainingResultType, TrainingResult
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, \
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, TRAIN_ENABLE_WORKER_SPREAD_ENV
from ray.train.session import TrainingResult
from ray.train.session import init_session, get_session, shutdown_session
from ray.train.utils import RayDataset
from ray.train.utils import check_for_failure
@ -22,11 +19,6 @@ from ray.train.worker_group import WorkerGroup
from ray.util.placement_group import get_current_placement_group, \
remove_placement_group
if TUNE_INSTALLED:
from ray import tune
else:
tune = None
T = TypeVar("T")
logger = logging.getLogger(__name__)
@ -67,15 +59,6 @@ class BackendExecutor:
and ``num_gpus_per_worker``.
max_retries (int): Number of retries when Ray actors fail.
Defaults to 3. Set to -1 for unlimited retries.
Attributes:
latest_checkpoint_dir (Optional[Path]): Path to the file directory for
the checkpoints from the latest run. Configured through
``start_training``
best_checkpoint_path (Optional[Path]): Path to the best persisted
checkpoint from the latest run.
latest_checkpoint (Optional[Dict]): The latest saved checkpoint. This
checkpoint may not be saved to disk.
"""
def __init__(
@ -99,16 +82,9 @@ class BackendExecutor:
self._initialization_hook = None
self._placement_group = None
if tune is not None and tune.is_session_enabled():
self.checkpoint_manager = TuneCheckpointManager()
else:
self.checkpoint_manager = CheckpointManager()
self.worker_group = InactiveWorkerGroup()
self.dataset_shards = None
self.checkpoint_manager.on_init()
def start(self,
initialization_hook: Optional[Callable[[], None]] = None,
train_cls: Optional[Type] = None,
@ -304,10 +280,7 @@ class BackendExecutor:
train_func: Callable[[], T],
run_dir: Path,
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
checkpoint: Optional[Union[Dict, str, Path]] = None,
checkpoint_strategy: Optional[CheckpointStrategy] = None,
latest_checkpoint_id: Optional[int] = None,
) -> None:
checkpoint: Optional[Dict] = None) -> None:
"""Executes a training function on all workers in a separate thread.
``finish_training`` should be called after this.
@ -324,22 +297,11 @@ class BackendExecutor:
and each Dataset can be accessed from the training function
by passing in a `dataset_name` argument to
``train.get_dataset_shard()``.
checkpoint (Optional[Dict|str|Path]): The checkpoint data that
checkpoint (Optional[Dict]): The checkpoint data that
should be loaded onto each worker and accessed by the
training function via ``train.load_checkpoint()``. If this is a
``str`` or ``Path`` then the value is expected to be a path
to a file that contains a serialized checkpoint dict. If this
training function via ``train.load_checkpoint()``. If this
is ``None`` then no checkpoint will be loaded.
checkpoint_strategy (Optional[CheckpointStrategy]): The
configurations for saving checkpoints.
latest_checkpoint_id (Optional[int]): The checkpoint id of the
most recently saved checkpoint.
"""
self.checkpoint_manager.on_start_training(
checkpoint_strategy=checkpoint_strategy,
run_dir=run_dir,
latest_checkpoint_id=latest_checkpoint_id)
use_detailed_autofilled_metrics = env_integer(
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)
@ -365,8 +327,6 @@ class BackendExecutor:
if self.dataset_shards is None:
self.dataset_shards = self._get_dataset_shards(dataset)
checkpoint_dict = self.checkpoint_manager._load_checkpoint(checkpoint)
local_rank_map = self._create_local_rank_map()
futures = []
@ -379,7 +339,7 @@ class BackendExecutor:
local_rank=local_rank_map[index],
train_func=train_func,
dataset_shard=self.dataset_shards[index],
checkpoint=checkpoint_dict))
checkpoint=checkpoint))
self.get_with_failure_handling(futures)
@ -390,7 +350,7 @@ class BackendExecutor:
self.worker_group.execute_async(train_async)
def _get_next_results(self) -> Optional[List[TrainingResult]]:
def get_next_results(self) -> Optional[List[TrainingResult]]:
"""Fetches the next ``TrainingResult`` from each worker.
Each ``TrainingResult`` is expected to correspond to the same step from
@ -403,24 +363,15 @@ class BackendExecutor:
"""
def get_next():
# Get the session for this worker.
try:
session = get_session()
except ValueError:
# Session is not initialized yet.
raise TrainBackendError("`fetch_next_result` has been called "
"before `start_training`. Please call "
"`start_training` before "
"`fetch_next_result`.")
session = _get_session("get_next_results")
try:
result = session.get_next()
except RuntimeError:
# Training thread has not been started yet.
raise TrainBackendError("`fetch_next_result` has been called "
raise TrainBackendError("`get_next_results` has been called "
"before `start_training`. Please call "
"`start_training` before "
"`fetch_next_result`.")
"`get_next_results`.")
return result
@ -451,37 +402,21 @@ class BackendExecutor:
"each worker.")
return results
def fetch_next_result(self) -> Optional[List[Dict]]:
"""Fetch next results produced by ``train.report()`` from each worker.
def pause_reporting(self):
""" Disable workers from enqueuing results from `train.report()`.
Assumes ``start_training`` has already been called.
Returns:
A list of dictionaries of values passed to ``train.report()`` from
each worker. Each item corresponds to an intermediate result
a single worker. If there are no more items to fetch,
returns None.
Note: Already reported results may still be enqueued at this point,
and should be handled appropriately.
"""
while True:
results = self._get_next_results()
if results is None:
return None
first_result = results[0]
result_type = first_result.type
if result_type is TrainingResultType.REPORT:
result_data = [r.data for r in results]
return result_data
elif result_type is TrainingResultType.CHECKPOINT:
self.checkpoint_manager._process_checkpoint(results)
# Iterate until next REPORT call or training has finished.
else:
raise TrainBackendError(f"Unexpected result type: "
f"{result_type}. "
f"Expected one of "
f"{[type in TrainingResultType]}")
def pause_session_reporting():
session = _get_session("pause_reporting")
return session.pause_reporting()
def finish_training(self) -> List[T]:
futures = self.worker_group.execute_async(pause_session_reporting)
self.get_with_failure_handling(futures)
def finish_training(self):
"""Finish training and return final results. Propagate any exceptions.
Blocks until training is finished on all workers.
@ -493,30 +428,8 @@ class BackendExecutor:
Each item corresponds to the return value from a single worker.
"""
def pause_reporting():
# Get the session for this worker.
try:
session = get_session()
except ValueError:
# Session is not initialized yet.
raise TrainBackendError("`finish_training` has been called "
"before `start_training`. Please call "
"`start_training` before "
"`finish_training`.")
return session.pause_reporting()
def end_training():
# Get the session for this worker.
try:
session = get_session()
except ValueError:
# Session is not initialized yet.
raise TrainBackendError("`finish_training` has been called "
"before `start_training`. Please call "
"`start_training` before "
"`finish_training`.")
session = _get_session("finish_training")
try:
# session.finish raises any Exceptions from training.
output = session.finish()
@ -527,23 +440,6 @@ class BackendExecutor:
return output
# Disable workers from enqueuing results from `train.report()`.
# Results will not be processed during the execution of `finish`.
# Note: Reported results may still be enqueued at this point,
# and should be handled appropriately.
futures = self.worker_group.execute_async(pause_reporting)
self.get_with_failure_handling(futures)
# Finish up processing checkpoints. Reporting has been disabled.
while True:
results = self._get_next_results()
if results is None:
break
result_type = results[0].type
# Process checkpoints and ignore other result types.
if result_type is TrainingResultType.CHECKPOINT:
self.checkpoint_manager._process_checkpoint(results)
futures = self.worker_group.execute_async(end_training)
results = self.get_with_failure_handling(futures)
return results
@ -594,37 +490,9 @@ class BackendExecutor:
self.dataset_shards = None
@property
def is_started(self):
return not isinstance(self.worker_group, InactiveWorkerGroup)
@property
def latest_checkpoint_dir(self) -> Optional[Path]:
"""Path to the latest checkpoint directory."""
return self.checkpoint_manager.latest_checkpoint_dir
@property
def best_checkpoint_path(self) -> Optional[Path]:
"""Path to the best persisted checkpoint."""
return self.checkpoint_manager.best_checkpoint_path
@property
def latest_checkpoint_id(self) -> Optional[int]:
"""The checkpoint id of most recently saved checkpoint.
If no checkpoint has been saved yet, then return None.
"""
checkpoint_id = self.checkpoint_manager._latest_checkpoint_id
if checkpoint_id == 0:
return None
else:
return checkpoint_id
@property
def latest_checkpoint(self) -> Optional[Dict]:
"""Latest checkpoint object."""
return self.checkpoint_manager.latest_checkpoint
def _restart(self):
self.worker_group.shutdown()
if self._initialization_hook is not None:
@ -646,6 +514,12 @@ class BackendExecutor:
"`max_retries` arg in your `Trainer`.") \
from None
def get_worker_group(self):
return self.worker_group
def _get_num_failures(self):
return self._num_failures
class Backend(metaclass=abc.ABCMeta):
"""Metaclass for distributed communication backend.
@ -702,3 +576,15 @@ class InactiveWorkerGroup():
def __len__(self):
raise InactiveWorkerGroupError()
def _get_session(method_name: str):
try:
# Get the session for this worker.
return get_session()
except ValueError:
# Session is not initialized yet.
raise TrainBackendError(f"`{method_name}` has been called "
"before `start_training`. Please call "
"`start_training` before "
f"`{method_name}`.")

View file

@ -274,6 +274,18 @@ class CheckpointManager:
else:
return None
@property
def latest_checkpoint_id(self) -> Optional[int]:
"""The checkpoint id of the most recently saved checkpoint.
If no checkpoint has been saved yet, then return None.
"""
checkpoint_id = self._latest_checkpoint_id
if checkpoint_id == 0:
return None
else:
return checkpoint_id
class TuneCheckpointManager(CheckpointManager):
def create_logdir(self, log_dir: Optional[Union[str, Path]]):

View file

@ -49,7 +49,9 @@ def validate_epoch(dataloader, model, loss_fn, device):
pred = model(X)
loss += loss_fn(pred, y).item()
loss /= num_batches
result = {"model": model.state_dict(), "loss": loss}
import copy
model_copy = copy.deepcopy(model)
result = {"model": model_copy.cpu().state_dict(), "loss": loss}
return result

View file

@ -1,6 +1,5 @@
import math
import os
import time
from unittest.mock import patch
import pytest
@ -160,7 +159,10 @@ def test_train_failure(ray_start_2_cpus, tmp_path):
e.start()
with pytest.raises(TrainBackendError):
e.fetch_next_result()
e.get_next_results()
with pytest.raises(TrainBackendError):
e.pause_reporting()
with pytest.raises(TrainBackendError):
e.finish_training()
@ -188,91 +190,6 @@ def test_worker_failure(ray_start_2_cpus, tmp_path):
e.finish_training()
def test_no_exhaust(ray_start_2_cpus, tmp_path):
"""Tests if training can finish even if queue is not exhausted."""
def train_func():
for _ in range(2):
train.report(loss=1)
return 2
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()
e.start_training(train_func, run_dir=tmp_path)
output = e.finish_training()
assert output == [2, 2]
def test_checkpoint(ray_start_2_cpus, tmp_path):
def train_func():
for i in range(2):
train.save_checkpoint(epoch=i)
config = TestConfig()
e = BackendExecutor(config, num_workers=1)
e.start()
e.start_training(train_func, run_dir=tmp_path)
e.finish_training()
assert e.latest_checkpoint is not None
assert e.latest_checkpoint["epoch"] == 1
def test_persisted_checkpoint(ray_start_2_cpus, tmp_path):
def train_func():
for i in range(2):
train.save_checkpoint(epoch=i)
time.sleep(1)
config = TestConfig()
e = BackendExecutor(config)
e.start()
e.start_training(train_func, run_dir=tmp_path)
e.finish_training()
assert e.latest_checkpoint_id == 2
assert e.latest_checkpoint is not None
assert e.latest_checkpoint["epoch"] == 1
assert e.best_checkpoint_path is not None
assert os.path.exists(e.best_checkpoint_path)
def validate():
checkpoint = train.load_checkpoint()
assert checkpoint is not None
assert checkpoint["epoch"] == 1
e2 = BackendExecutor(config)
e2.start()
e2.start_training(
validate, checkpoint=e.best_checkpoint_path, run_dir=tmp_path)
e2.finish_training()
def test_persisted_checkpoint_id(ray_start_2_cpus, tmp_path):
def train_func():
for i in range(2):
train.save_checkpoint(epoch=i)
time.sleep(1)
config = TestConfig()
e = BackendExecutor(config)
e.start()
e.start_training(train_func, run_dir=tmp_path, latest_checkpoint_id=100)
e.finish_training()
assert e.latest_checkpoint_id == 102
assert e.latest_checkpoint is not None
assert e.latest_checkpoint["epoch"] == 1
assert e.best_checkpoint_path is not None
assert os.path.exists(e.best_checkpoint_path)
def test_mismatch_checkpoint_report(ray_start_2_cpus, tmp_path):
def train_func():
if (train.world_rank()) == 0:
@ -285,7 +202,7 @@ def test_mismatch_checkpoint_report(ray_start_2_cpus, tmp_path):
e.start()
e.start_training(train_func, run_dir=tmp_path)
with pytest.raises(RuntimeError):
e.finish_training()
e.get_next_results()
def test_tensorflow_start(ray_start_2_cpus, tmp_path):

View file

@ -0,0 +1,145 @@
import pytest
import ray
from ray.train import Trainer
from ray.train.examples.horovod.horovod_example import train_func as \
horovod_torch_train_func, HorovodTrainClass
from ray.train.examples.tensorflow_mnist_example import train_func as \
tensorflow_mnist_train_func
from ray.train.examples.tensorflow_quick_start import train_func as \
tf_quick_start_train_func
from ray.train.examples.torch_quick_start import train_func as \
torch_quick_start_train_func
from ray.train.examples.train_fashion_mnist_example import train_func \
as fashion_mnist_train_func
from ray.train.examples.train_linear_example import train_func as \
linear_train_func
@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2])
def test_tensorflow_mnist(ray_start_2_cpus, num_workers):
num_workers = num_workers
epochs = 3
trainer = Trainer("tensorflow", num_workers=num_workers)
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
trainer.start()
results = trainer.run(tensorflow_mnist_train_func, config)
trainer.shutdown()
assert len(results) == num_workers
result = results[0]
loss = result["loss"]
assert len(loss) == epochs
assert loss[-1] < loss[0]
accuracy = result["accuracy"]
assert len(accuracy) == epochs
assert accuracy[-1] > accuracy[0]
def test_tf_non_distributed(ray_start_2_cpus):
"""Make sure Ray Train works without TF MultiWorkerMirroredStrategy."""
trainer = Trainer(backend="torch", num_workers=1)
trainer.start()
trainer.run(tf_quick_start_train_func)
trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2])
def test_torch_linear(ray_start_2_cpus, num_workers):
num_workers = num_workers
epochs = 3
trainer = Trainer("torch", num_workers=num_workers)
config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs}
trainer.start()
results = trainer.run(linear_train_func, config)
trainer.shutdown()
assert len(results) == num_workers
for result in results:
assert len(result) == epochs
assert result[-1]["loss"] < result[0]["loss"]
def test_torch_fashion_mnist(ray_start_2_cpus):
num_workers = 2
epochs = 3
trainer = Trainer("torch", num_workers=num_workers)
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
trainer.start()
results = trainer.run(fashion_mnist_train_func, config)
trainer.shutdown()
assert len(results) == num_workers
for result in results:
assert len(result) == epochs
assert result[-1] < result[0]
def test_torch_non_distributed(ray_start_2_cpus):
"""Make sure Ray Train works without torch DDP."""
trainer = Trainer(backend="torch", num_workers=1)
trainer.start()
trainer.run(torch_quick_start_train_func)
trainer.shutdown()
def test_horovod_torch_mnist(ray_start_2_cpus):
num_workers = 2
num_epochs = 2
trainer = Trainer("horovod", num_workers)
trainer.start()
results = trainer.run(
horovod_torch_train_func,
config={
"num_epochs": num_epochs,
"lr": 1e-3
})
trainer.shutdown()
assert len(results) == num_workers
for worker_result in results:
assert len(worker_result) == num_epochs
assert worker_result[num_epochs - 1] < worker_result[0]
def test_horovod_torch_mnist_stateful(ray_start_2_cpus):
num_workers = 2
num_epochs = 2
trainer = Trainer("horovod", num_workers)
workers = trainer.to_worker_group(
HorovodTrainClass, config={
"num_epochs": num_epochs,
"lr": 1e-3
})
results = []
for epoch in range(num_epochs):
results.append(ray.get([w.train.remote(epoch=epoch) for w in workers]))
trainer.shutdown()
assert len(results) == num_epochs
for i in range(num_workers):
assert results[num_epochs - 1][i] < results[0][i]
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", "-x", __file__]))

View file

@ -16,19 +16,6 @@ from ray.train.torch import TorchConfig
from ray.train.tensorflow import TensorflowConfig
from ray.train.horovod import HorovodConfig
from ray.train.callbacks.callback import TrainingCallback
from ray.train.constants import ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV
from ray.train.examples.horovod.horovod_example import train_func as \
horovod_torch_train_func, HorovodTrainClass
from ray.train.examples.tensorflow_mnist_example import train_func as \
tensorflow_mnist_train_func
from ray.train.examples.train_fashion_mnist_example import train_func \
as fashion_mnist_train_func
from ray.train.examples.train_linear_example import train_func as \
linear_train_func
from ray.train.examples.torch_quick_start import train_func as \
torch_quick_start_train_func
from ray.train.examples.tensorflow_quick_start import train_func as \
tf_quick_start_train_func
from ray.train.worker_group import WorkerGroup
@ -129,10 +116,11 @@ def gen_new_backend_executor(special_f):
class KillCallback(TrainingCallback):
def __init__(self, fail_on, worker_group):
def __init__(self, fail_on, trainer):
self.counter = 0
self.fail_on = fail_on
self.worker_group = worker_group
self.worker_group = ray.get(
trainer._backend_executor_actor.get_worker_group.remote())
def handle_result(self, results):
print(results)
@ -333,6 +321,24 @@ def test_run_iterator_error(ray_start_2_cpus):
assert iterator.is_finished()
def test_no_exhaust(ray_start_2_cpus, tmp_path):
"""Tests if training can finish even if queue is not exhausted."""
def train_func():
for _ in range(2):
train.report(loss=1)
return 2
config = TestConfig()
trainer = Trainer(config, num_workers=2)
trainer.start()
iterator = trainer.run_iterator(train_func)
output = iterator.get_final_results(force=True)
assert output == [2, 2]
def test_checkpoint(ray_start_2_cpus):
config = TestConfig()
@ -556,82 +562,6 @@ def test_world_rank(ray_start_2_cpus):
assert set(results) == {0, 1}
@pytest.mark.parametrize("num_workers", [1, 2])
def test_tensorflow_mnist(ray_start_2_cpus, num_workers):
num_workers = num_workers
epochs = 3
trainer = Trainer("tensorflow", num_workers=num_workers)
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
trainer.start()
results = trainer.run(tensorflow_mnist_train_func, config)
trainer.shutdown()
assert len(results) == num_workers
result = results[0]
loss = result["loss"]
assert len(loss) == epochs
assert loss[-1] < loss[0]
accuracy = result["accuracy"]
assert len(accuracy) == epochs
assert accuracy[-1] > accuracy[0]
def test_tf_non_distributed(ray_start_2_cpus):
"""Make sure Ray Train works without TF MultiWorkerMirroredStrategy."""
trainer = Trainer(backend="torch", num_workers=1)
trainer.start()
trainer.run(tf_quick_start_train_func)
trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2])
def test_torch_linear(ray_start_2_cpus, num_workers):
num_workers = num_workers
epochs = 3
trainer = Trainer("torch", num_workers=num_workers)
config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs}
trainer.start()
results = trainer.run(linear_train_func, config)
trainer.shutdown()
assert len(results) == num_workers
for result in results:
assert len(result) == epochs
assert result[-1]["loss"] < result[0]["loss"]
def test_torch_fashion_mnist(ray_start_2_cpus):
num_workers = 2
epochs = 3
trainer = Trainer("torch", num_workers=num_workers)
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
trainer.start()
results = trainer.run(fashion_mnist_train_func, config)
trainer.shutdown()
assert len(results) == num_workers
for result in results:
assert len(result) == epochs
assert result[-1] < result[0]
def test_torch_non_distributed(ray_start_2_cpus):
"""Make sure Ray Train works without torch DDP."""
trainer = Trainer(backend="torch", num_workers=1)
trainer.start()
trainer.run(torch_quick_start_train_func)
trainer.shutdown()
def test_horovod_simple(ray_start_2_cpus):
def simple_fn():
hvd_torch.init()
@ -646,44 +576,6 @@ def test_horovod_simple(ray_start_2_cpus):
assert result == list(range(num_workers))
def test_horovod_torch_mnist(ray_start_2_cpus):
num_workers = 2
num_epochs = 2
trainer = Trainer("horovod", num_workers)
trainer.start()
results = trainer.run(
horovod_torch_train_func,
config={
"num_epochs": num_epochs,
"lr": 1e-3
})
trainer.shutdown()
assert len(results) == num_workers
for worker_result in results:
assert len(worker_result) == num_epochs
assert worker_result[num_epochs - 1] < worker_result[0]
def test_horovod_torch_mnist_stateful(ray_start_2_cpus):
num_workers = 2
num_epochs = 2
trainer = Trainer("horovod", num_workers)
workers = trainer.to_worker_group(
HorovodTrainClass, config={
"num_epochs": num_epochs,
"lr": 1e-3
})
results = []
for epoch in range(num_epochs):
results.append(ray.get([w.train.remote(epoch=epoch) for w in workers]))
trainer.shutdown()
assert len(results) == num_epochs
for i in range(num_workers):
assert results[num_epochs - 1][i] < results[0][i]
def test_init_failure(ray_start_2_cpus):
with pytest.raises(TypeError):
Trainer(5, num_workers=2)
@ -809,23 +701,27 @@ def test_worker_failure_local_rank(ray_start_2_cpus):
def test_worker_start_failure(ray_start_2_cpus):
test_config = TestConfig()
trainer = Trainer(test_config, num_workers=2)
restart = trainer._executor._restart
def init_hook():
pass
def init_hook_fail():
ray.actor.exit_actor()
def restart_patched(self):
self._initialization_hook = init_hook
restart()
class TestBackendExecutor(BackendExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
with patch.object(BackendExecutor, "_restart", restart_patched):
def _restart(self):
self._initialization_hook = init_hook
super()._restart()
with patch.object(ray.train.trainer, "BackendExecutor",
TestBackendExecutor):
trainer = Trainer(test_config, num_workers=2)
trainer.start(initialization_hook=init_hook_fail)
assert len(trainer._executor.worker_group) == 2
assert len(
ray.get(trainer._backend_executor_actor.get_worker_group.remote())
) == 2
def test_max_failures(ray_start_2_cpus):
@ -840,7 +736,8 @@ def test_max_failures(ray_start_2_cpus):
iterator = trainer.run_iterator(train_func)
with pytest.raises(RuntimeError):
iterator.get_final_results(force=True)
assert iterator._executor._num_failures == 3
assert ray.get(
iterator._backend_executor_actor._get_num_failures.remote()) == 3
def test_start_max_failures(ray_start_2_cpus):
@ -874,8 +771,7 @@ def test_worker_kill(ray_start_2_cpus, backend):
train.report(loss=1, iter=i)
trainer.start()
kill_callback = KillCallback(
fail_on=0, worker_group=trainer._executor.worker_group)
kill_callback = KillCallback(fail_on=0, trainer=trainer)
trainer.run(train_func, callbacks=[kill_callback])
# Run 1: iter=0, counter=1, Successful
# Run 2: iter=1, counter=1, Unsuccessful, starts training from beginning
@ -886,8 +782,7 @@ def test_worker_kill(ray_start_2_cpus, backend):
trainer.shutdown()
trainer.start()
kill_callback = KillCallback(
fail_on=1, worker_group=trainer._executor.worker_group)
kill_callback = KillCallback(fail_on=1, trainer=trainer)
trainer.run(train_func, callbacks=[kill_callback])
# Run 1: iter=0, counter=1, Successful
# Run 2: iter=1, counter=2, Successful
@ -919,8 +814,7 @@ def test_worker_kill_checkpoint(ray_start_2_cpus):
trainer = Trainer(test_config, num_workers=2)
trainer.start()
kill_callback = KillCallback(
fail_on=0, worker_group=trainer._executor.worker_group)
kill_callback = KillCallback(fail_on=0, trainer=trainer)
trainer.run(train_func, callbacks=[kill_callback])
@ -936,8 +830,7 @@ def test_worker_kill_checkpoint(ray_start_2_cpus):
trainer.shutdown()
trainer.start()
kill_callback = KillCallback(
fail_on=1, worker_group=trainer._executor.worker_group)
kill_callback = KillCallback(fail_on=1, trainer=trainer)
trainer.run(train_func, callbacks=[kill_callback])
# Run 1: epoch=0, counter=1, Successful
# *Checkpoint saved*
@ -1131,11 +1024,10 @@ def test_dataset_pipeline_shuffle(ray_start_4_cpus):
def test_dataset_fault_tolerance(ray_start_4_cpus):
dataset = ray.data.range(10)
dataset_splits = dataset.split(n=2, equal=True)
test_config = TestConfig()
def train_func():
return 1
return train.get_dataset_shard()
def train_actor_failure():
import sys
@ -1143,16 +1035,23 @@ def test_dataset_fault_tolerance(ray_start_4_cpus):
new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
class SingleGetDatasetShardsBackendExecutor(new_backend_executor_cls):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._has_called_get_dataset_shards = False
def _get_dataset_shards(self, dataset_or_dict):
if self._has_called_get_dataset_shards:
raise Exception
self._has_called_get_dataset_shards = True
return super()._get_dataset_shards(dataset_or_dict)
with patch.object(ray.train.trainer, "BackendExecutor",
new_backend_executor_cls):
with patch.object(
new_backend_executor_cls,
"_get_dataset_shards",
return_value=dataset_splits) as mock_method:
trainer = Trainer(test_config, num_workers=2)
trainer.start()
trainer.run(train_func, dataset=dataset)
mock_method.assert_called_once()
SingleGetDatasetShardsBackendExecutor):
trainer = Trainer(test_config, num_workers=2)
trainer.start()
trainer.run(train_func, dataset=dataset)
# No exception is raised by _get_dataset_shards
@pytest.mark.parametrize("resource", ["CPU", "GPU", "extra"])
@ -1180,10 +1079,18 @@ def test_resources(ray_start_4_cpus_4_gpus_4_extra, resource, num_requested):
def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
class CudaTestBackend(TestBackend):
share_cuda_visible_devices = True
class CudaTestConfig(TestConfig):
@property
def backend_cls(self):
return CudaTestBackend
# GPUs should not be requested if `use_gpu` is False.
with pytest.raises(ValueError):
Trainer(
TestConfig(),
CudaTestConfig(),
num_workers=2,
use_gpu=False,
resources_per_worker={"GPU": 1})
@ -1191,7 +1098,7 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
# GPUs should not be set to 0 if `use_gpu` is True.
with pytest.raises(ValueError):
Trainer(
TestConfig(),
CudaTestConfig(),
num_workers=2,
use_gpu=True,
resources_per_worker={"GPU": 0})
@ -1199,17 +1106,15 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
def get_resources():
return os.environ["CUDA_VISIBLE_DEVICES"]
os.environ[ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV] = "1"
# 0 GPUs will be requested and should not raise an error.
trainer = Trainer(TestConfig(), num_workers=2, use_gpu=False)
trainer = Trainer(CudaTestConfig(), num_workers=2, use_gpu=False)
trainer.start()
result = trainer.run(get_resources)
assert result == ["", ""]
trainer.shutdown()
# 1 GPU will be requested and should not raise an error.
trainer = Trainer(TestConfig(), num_workers=2, use_gpu=True)
trainer = Trainer(CudaTestConfig(), num_workers=2, use_gpu=True)
trainer.start()
result = trainer.run(get_resources)
assert result == ["0,1", "0,1"]
@ -1217,7 +1122,7 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
# Partial GPUs should not raise an error.
trainer = Trainer(
TestConfig(),
CudaTestConfig(),
num_workers=2,
use_gpu=True,
resources_per_worker={"GPU": 0.1})
@ -1228,7 +1133,7 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
# Multiple GPUs should not raise an error.
trainer = Trainer(
TestConfig(),
CudaTestConfig(),
num_workers=2,
use_gpu=True,
resources_per_worker={"GPU": 2})

View file

@ -6,19 +6,24 @@ from pathlib import Path
from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, \
Type
import ray
from ray.actor import ActorHandle
from ray.train.backend import BackendConfig, BackendExecutor, \
InactiveWorkerGroupError, TrainBackendError, TrainingWorkerError
from ray.train.callbacks.callback import TrainingCallback
from ray.train.session import TrainingResultType
from ray.train.utils import RayDataset
from ray.train.checkpoint import CheckpointStrategy
from ray.train.checkpoint import CheckpointStrategy, TuneCheckpointManager, \
CheckpointManager
from ray.train.constants import TUNE_INSTALLED, DEFAULT_RESULTS_DIR, \
TUNE_CHECKPOINT_FILE_NAME
# Ray Train should be usable even if Tune is not installed.
from ray.train.utils import construct_path
from ray.train.worker_group import WorkerGroup
from ray.util.ml_utils.node import force_on_current_node, \
get_current_node_resource_key
if TUNE_INSTALLED:
from ray import tune
@ -141,7 +146,14 @@ class Trainer:
"request a positive number of `GPU` in "
"`resources_per_worker.")
self._executor = BackendExecutor(
remote_executor = ray.remote(num_cpus=0)(BackendExecutor)
if not ray.is_initialized():
ray.init()
# Assign BackendExecutor to head node.
remote_executor = force_on_current_node(remote_executor)
self._backend_executor_actor = remote_executor.remote(
backend_config=backend_config,
num_workers=num_workers,
num_cpus_per_worker=num_cpus,
@ -149,6 +161,12 @@ class Trainer:
additional_resources_per_worker=resources_per_worker,
max_retries=max_retries)
if tune is not None and tune.is_session_enabled():
self.checkpoint_manager = TuneCheckpointManager()
else:
self.checkpoint_manager = CheckpointManager()
self.checkpoint_manager.on_init()
def create_logdir(self, log_dir: Optional[Union[str, Path]]) -> Path:
"""Create logdir for the Trainer."""
# Create directory for logs.
@ -196,7 +214,7 @@ class Trainer:
initialization_hook (Optional[Callable]): The function to call on
each worker when it is instantiated.
"""
self._executor.start(initialization_hook)
ray.get(self._backend_executor_actor.start.remote(initialization_hook))
def run(self,
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
@ -255,9 +273,10 @@ class Trainer:
try:
iterator = TrainingIterator(
backend_executor=self._executor,
backend_executor_actor=self._backend_executor_actor,
train_func=train_func,
dataset=dataset,
checkpoint_manager=self.checkpoint_manager,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy,
run_dir=self.latest_run_dir,
@ -329,10 +348,11 @@ class Trainer:
train_func = self._get_train_func(train_func, config)
return TrainingIterator(
backend_executor=self._executor,
backend_executor_actor=self._backend_executor_actor,
train_func=train_func,
run_dir=self.latest_run_dir,
dataset=dataset,
checkpoint_manager=self.checkpoint_manager,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy)
@ -384,7 +404,7 @@ class Trainer:
``train.checkpoint()`` has not been called from ``train_func`` within
the most recent call to ``run``.
"""
return self._executor.latest_checkpoint_dir
return self.checkpoint_manager.latest_checkpoint_dir
@property
def best_checkpoint_path(self) -> Optional[Path]:
@ -397,7 +417,7 @@ class Trainer:
``train.checkpoint()`` has not been called from ``train_func`` within
the most recent call to ``run``.
"""
return self._executor.best_checkpoint_path
return self.checkpoint_manager.best_checkpoint_path
@property
def latest_checkpoint(self) -> Optional[Dict]:
@ -408,11 +428,11 @@ class Trainer:
Returns ``None`` if ``run()`` has not been called or if
``train.checkpoint()`` has not been called from ``train_func``.
"""
return self._executor.latest_checkpoint
return self.checkpoint_manager.latest_checkpoint
def shutdown(self):
"""Shuts down the training execution service."""
self._executor.shutdown()
ray.get(self._backend_executor_actor.shutdown.remote())
def to_tune_trainable(
self,
@ -442,7 +462,7 @@ class Trainer:
raise ValueError("Tune is not installed. Please install ray["
"tune] to use the Tune integration.")
if self._executor.is_started:
if ray.get(self._backend_executor_actor.is_started.remote()):
raise RuntimeError("The Trainer must not be active to use "
"`to_tune_trainable`. Either shutdown the "
"Trainer or don't start it in the first place.")
@ -483,13 +503,18 @@ class Trainer:
args, kwargs: Arguments to pass into the ``__init__`` of the
provided ``train_cls``.
"""
if self._executor.is_started:
if ray.get(self._backend_executor_actor.is_started.remote()):
raise RuntimeError("The Trainer must not be active to use "
"`to_worker_group`. Either shutdown the "
"Trainer or don't start it in the first place.")
self._executor.start(
train_cls=train_cls, train_cls_args=args, train_cls_kwargs=kwargs)
return TrainWorkerGroup(self._executor.worker_group)
ray.get(
self._backend_executor_actor.start.remote(
train_cls=train_cls,
train_cls_args=args,
train_cls_kwargs=kwargs))
worker_group = ray.get(
self._backend_executor_actor.get_worker_group.remote())
return TrainWorkerGroup(worker_group)
class TrainWorkerGroup:
@ -541,16 +566,18 @@ class TrainingIterator:
"""An iterator over Train results. Returned by ``trainer.run_iterator``."""
def __init__(
self, backend_executor: BackendExecutor,
self, backend_executor_actor: ActorHandle,
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
run_dir: Path,
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]],
checkpoint: Optional[Dict],
checkpoint_manager: CheckpointManager,
checkpoint: Optional[Union[Dict, str, Path]],
checkpoint_strategy: Optional[CheckpointStrategy]):
self._executor = backend_executor
self._backend_executor_actor = backend_executor_actor
self._train_func = train_func
self._dataset = dataset
self._run_dir = run_dir
self._checkpoint_manager = checkpoint_manager
self._checkpoint_strategy = checkpoint_strategy
self._start_training(
train_func=train_func,
@ -572,15 +599,18 @@ class TrainingIterator:
checkpoint,
checkpoint_strategy,
latest_checkpoint_id=None):
self._checkpoint_manager.on_start_training(
checkpoint_strategy=checkpoint_strategy,
run_dir=run_dir,
latest_checkpoint_id=latest_checkpoint_id)
checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint)
self._run_with_error_handling(
lambda: self._executor.start_training(
lambda: ray.get(self._backend_executor_actor.start_training.remote(
train_func=train_func,
run_dir=run_dir,
dataset=dataset,
checkpoint=checkpoint,
checkpoint_strategy=checkpoint_strategy,
latest_checkpoint_id=latest_checkpoint_id
)
checkpoint=checkpoint_dict
))
)
def _run_with_error_handling(self, func: Callable):
@ -592,9 +622,10 @@ class TrainingIterator:
self._train_func,
self._run_dir,
self._dataset,
self._executor.latest_checkpoint,
self._checkpoint_manager.latest_checkpoint,
self._checkpoint_strategy,
latest_checkpoint_id=self._executor.latest_checkpoint_id)
latest_checkpoint_id=self._checkpoint_manager.
latest_checkpoint_id)
return self._run_with_error_handling(func)
except InactiveWorkerGroupError:
raise RuntimeError(
@ -611,19 +642,78 @@ class TrainingIterator:
def __next__(self):
if self.is_finished():
raise StopIteration
next_results = self._run_with_error_handling(
self._executor.fetch_next_result)
next_results = self._run_with_error_handling(self._fetch_next_result)
if next_results is None:
try:
self._final_results = \
self._run_with_error_handling(
self._executor.finish_training)
self._final_results = self._run_with_error_handling(
self._finish_training)
finally:
self._finished_training = True
raise StopIteration
else:
return next_results
def _fetch_next_result(self) -> Optional[List[Dict]]:
"""Fetch next results produced by ``train.report()`` from each worker.
Assumes ``start_training`` has already been called.
Returns:
A list of dictionaries of values passed to ``train.report()`` from
each worker. Each item corresponds to an intermediate result from
a single worker. If there are no more items to fetch,
returns None.
"""
while True:
results = ray.get(
self._backend_executor_actor.get_next_results.remote())
if results is None:
return None
first_result = results[0]
result_type = first_result.type
if result_type is TrainingResultType.REPORT:
result_data = [r.data for r in results]
return result_data
elif result_type is TrainingResultType.CHECKPOINT:
self._checkpoint_manager._process_checkpoint(results)
# Iterate until next REPORT call or training has finished.
else:
raise TrainBackendError(f"Unexpected result type: "
f"{result_type}. "
f"Expected one of "
f"{[type in TrainingResultType]}")
def _finish_checkpointing(self):
while True:
results = ray.get(
self._backend_executor_actor.get_next_results.remote())
if results is None:
break
result_type = results[0].type
# Process checkpoints and ignore other result types.
if result_type is TrainingResultType.CHECKPOINT:
self._checkpoint_manager._process_checkpoint(results)
def _finish_training(self):
"""Finish training and return final results. Propagate any exceptions.
Blocks until training is finished on all workers.
Assumes `start_training` has already been called.
Returns:
A list of return values from calling ``train_func`` on each worker.
Each item corresponds to the return value from a single worker.
"""
ray.get(self._backend_executor_actor.pause_reporting.remote())
# Finish up processing checkpoints. Reporting has been disabled.
# Results will not be processed.
self._finish_checkpointing()
return ray.get(self._backend_executor_actor.finish_training.remote())
def is_finished(self) -> bool:
return self._finished_training
@ -640,9 +730,8 @@ class TrainingIterator:
assert self._final_results is None
if force:
try:
self._final_results = \
self._run_with_error_handling(
self._executor.finish_training)
self._final_results = self._run_with_error_handling(
self._finish_training)
finally:
self._finished_training = True
else:
@ -697,15 +786,17 @@ def _create_tune_trainable(train_func, dataset, backend, num_workers, use_gpu,
@classmethod
def default_resource_request(cls,
config: Dict) -> PlacementGroupFactory:
head_bundle = [{"CPU": 1}] # driver
node_resource_key = get_current_node_resource_key()
trainer_bundle = [{"CPU": 1}]
backend_executor_bundle = [{node_resource_key: 0.01}]
worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
worker_resources_extra = {} if resources_per_worker is None else\
worker_resources_extra = {} if resources_per_worker is None else \
resources_per_worker
worker_bundles = [{
**worker_resources,
**worker_resources_extra
} for _ in range(num_workers)]
bundles = head_bundle + worker_bundles
bundles = trainer_bundle + backend_executor_bundle + worker_bundles
return PlacementGroupFactory(bundles, strategy="PACK")
return TrainTrainable
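
The Tune trainable's default_resource_request changes accordingly: instead of a single head bundle, it reserves one bundle for the Trainer driver plus a small slice of the current node's resource so the pinned BackendExecutor actor can be scheduled on that same node. A sketch of the resulting bundle layout with illustrative values (the real method gets num_workers, use_gpu, and resources_per_worker from the enclosing closure):

import ray
from ray.tune import PlacementGroupFactory
from ray.util.ml_utils.node import get_current_node_resource_key

if not ray.is_initialized():
    ray.init()

num_workers, use_gpu = 2, False  # illustrative values

node_resource_key = get_current_node_resource_key()  # e.g. "node:<head-node-ip>"
bundles = (
    [{"CPU": 1}]                   # Trainer driver
    + [{node_resource_key: 0.01}]  # BackendExecutor actor
    + [{"CPU": 1, "GPU": int(use_gpu)} for _ in range(num_workers)]  # workers
)
pg_factory = PlacementGroupFactory(bundles, strategy="PACK")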

View file

@ -34,6 +34,7 @@
driver_setup: train/driver_setup.sh
run:
use_connect: True
timeout: 36000
script: python train/train_tensorflow_mnist_test.py
@ -43,6 +44,7 @@
compute_template: train/compute_tpl.yaml
run:
use_connect: True
timeout: 36000
script: python train/train_torch_linear_test.py

View file

@ -1,6 +1,6 @@
base_image: "anyscale/ray-ml:nightly-py37-gpu"
env_vars:
TRAIN_PLACEMENT_GROUP_TIMEOUT_S: 2000
TRAIN_PLACEMENT_GROUP_TIMEOUT_S: "2000"
debian_packages:
- curl

View file

@ -9,7 +9,7 @@ if __name__ == "__main__":
start = time.time()
addr = os.environ.get("RAY_ADDRESS")
job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test")
job_name = os.environ.get("RAY_JOB_NAME", "train_tensorflow_mnist_test")
if addr is not None and addr.startswith("anyscale://"):
ray.init(address=addr, job_name=job_name)
@ -23,7 +23,7 @@ if __name__ == "__main__":
"time_taken": taken,
}
test_output_json = os.environ.get("TEST_OUTPUT_JSON",
"/tmp/train_torc_linear_test.json")
"/tmp/train_tensorflow_mnist_test.json")
with open(test_output_json, "wt") as f:
json.dump(result, f)

View file

@ -10,7 +10,7 @@ if __name__ == "__main__":
start = time.time()
addr = os.environ.get("RAY_ADDRESS")
job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test")
job_name = os.environ.get("RAY_JOB_NAME", "train_torch_linear_test")
if addr is not None and addr.startswith("anyscale://"):
ray.init(address=addr, job_name=job_name)
@ -22,7 +22,7 @@ if __name__ == "__main__":
taken = time.time() - start
result = {"time_taken": taken}
test_output_json = os.environ.get("TEST_OUTPUT_JSON",
"/tmp/train_torc_linear_test.json")
"/tmp/train_torch_linear_test.json")
with open(test_output_json, "wt") as f:
json.dump(result, f)