[SGD] Worker Startup Fault Tolerance (#14724)

Amog Kamsetty 2021-03-18 22:53:56 -07:00 committed by GitHub
parent c30d5f445c
commit 47300d5a53
3 changed files with 188 additions and 61 deletions


@@ -1,6 +1,8 @@
from unittest.mock import patch
import pytest
import time
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
@@ -59,12 +61,12 @@ start_workers = TorchTrainer._start_workers
def gen_start_with_fail(num_fails):
def start_with_fail(self, *args, **kwargs):
start_workers(self, *args, **kwargs)
success = start_workers(self, *args, **kwargs)
fail = self._num_failures < num_fails
if self.use_local:
self.worker_group.remote_worker_group.should_fail = fail
else:
self.worker_group.should_fail = fail
fail_worker_group = self.worker_group.remote_worker_group if \
self.use_local else self.worker_group
fail_worker_group.should_fail = fail
return success
return start_with_fail
@@ -178,6 +180,112 @@ def test_fail_with_recover(ray_start_2_cpus, use_local): # noqa: F811
trainer1.shutdown(force=True)
@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
def test_fail_state(ray_start_2_cpus): # noqa: F811
"""Tests if state of training with failure is same as training without."""
if not dist.is_available():
return
torch.manual_seed(0)
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
def init_hook():
torch.manual_seed(0)
trainer1 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
timeout_s=5,
initialization_hook=init_hook,
num_workers=2)
initial_state = trainer1.state_dict()
trainer1.train()
trainer1_state = trainer1.state_dict()
assert trainer1_state != initial_state
trainer1.shutdown()
trainer2 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
timeout_s=5,
initialization_hook=init_hook,
num_workers=2)
trainer2.load_state_dict(initial_state)
trainer2.train()
assert trainer2.state_dict() == trainer1_state
trainer2.shutdown()
start_with_fail = gen_start_with_fail(1)
with patch.object(TorchTrainer, "_start_workers", start_with_fail):
trainer3 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
timeout_s=5,
initialization_hook=init_hook,
num_workers=2)
trainer3.load_state_dict(initial_state)
trainer3.train()
assert trainer3.state_dict() == trainer1_state
trainer3.shutdown()
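
The init_hook above reseeds every worker with torch.manual_seed(0); without it, each trainer (and each restarted worker group) would initialize its model with different random weights and the state-dict comparisons above would be meaningless. A standalone illustration of why identical seeding makes the comparison valid (not part of the test file):

import torch
import torch.nn as nn

torch.manual_seed(0)
a = nn.Linear(1, 1)
torch.manual_seed(0)
b = nn.Linear(1, 1)
# Identical seeds give identical initial parameters, so state dicts from
# identically trained models can be compared for equality.
assert all(
    torch.equal(x, y)
    for x, y in zip(a.state_dict().values(), b.state_dict().values()))
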
def gen_start_with_startup_fail(num_fails):
fail_start = gen_start_with_fail(num_fails)
def start_with_fail(self, *args, **kwargs):
if hasattr(self, "worker_group"):
# Fail during worker startup, but only on the first training attempt.
def _raise():
import sys
sys.exit(1)
if not self.initialization_hook:
self.initialization_hook = _raise
else:
self.initialization_hook = None
return fail_start(self, *args, **kwargs)
return start_with_fail
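
The sys.exit(1) trick works because the initialization hook runs inside each Ray worker process: exiting kills that worker, and the blocking ray.get calls during startup surface the death as a RayActorError, which the worker-group changes further down turn into a False return value. A minimal standalone sketch of that mechanism (toy actor, not part of this test; exact error types can vary across Ray versions):

import sys
import ray
from ray.exceptions import RayActorError

@ray.remote
class Worker:
    def apply(self, fn):
        return fn()

def crash():
    sys.exit(1)

ray.init(num_cpus=2)
w = Worker.remote()
try:
    ray.get(w.apply.remote(crash))
    started = True
except RayActorError:
    # The actor process died while running the hook, so startup is
    # reported as failed instead of crashing the driver.
    started = False
assert not started
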
@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
def test_failure_during_resize(ray_start_2_cpus): # noqa: F811
"""Tests if training succeeds even with failures during worker resizing."""
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
start_with_fail = gen_start_with_startup_fail(1)
with patch.object(TorchTrainer, "_start_workers", start_with_fail):
trainer1 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
timeout_s=5,
use_local=False,
num_workers=2)
trainer1.train()
trainer1.shutdown()
if __name__ == "__main__":
import pytest
import sys


@@ -263,7 +263,11 @@ class TorchTrainer:
"multi-node training, be sure to run `ray.init("
"address='auto')` before instantiating the Trainer.")
ray.init()
self._start_workers(self.max_replicas)
startup_success = self._start_workers(self.max_replicas)
if not startup_success:
raise RuntimeError("Worker startup failed. "
"Are you sure you have enough resources to "
"start the specified number of workers?")
def _configure_and_split_batch(self, num_workers):
"""If sgd.utils.BATCH_SIZE is provided, split among workers."""
@@ -323,17 +327,17 @@ class TorchTrainer:
# num_workers workers, this command will hang. Instead,
# start_workers should take into account available resources when
# determining how many workers to create.
self.worker_group.start_workers(num_workers)
return self.worker_group.start_workers(num_workers)
def _resize_worker_group(self, max_retries=10):
def _resize_worker_group(self, state_dict, max_retries=10):
"""Resizes the number of remote workers based on available resources.
The total number of workers will never exceed `num_workers`.
Args:
state_dict (dict): The state dict to load to all workers.
max_retries (int): How many times to attempt to resize workers
before failing.
"""
state_dict = self.state_dict()
old_workers = self.worker_group.num_workers
self.worker_group.reset()
@@ -342,7 +346,12 @@
new_workers = self.worker_group.new_workers_size()
if new_workers:
self._last_resize = time.time()
self._start_workers(int(new_workers))
startup_success = self._start_workers(int(new_workers))
if not startup_success:
logger.info(f"Worker startup failed. Retrying "
f"{max_retries-i-1} more times.")
self.worker_group.reset()
continue
self.load_state_dict(state_dict, blocking=True)
if self.use_local and new_workers == 1 and old_workers > 1:
# Major hack. If we go from LocalDistributedRunner to a
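
The hunk above only shows the new failure branch. Condensed, the retry loop inside _resize_worker_group now follows the pattern sketched below (a simplification, not the verbatim Ray source; the no-resource branch is elided):

def resize_with_retries(group, start_workers, load_state, state_dict,
                        max_retries=10):
    group.reset()
    for _ in range(max_retries):
        new_workers = group.new_workers_size()
        if not new_workers:
            continue  # the no-resource case is handled separately in the real code
        if not start_workers(int(new_workers)):
            # Startup died mid-setup (RayActorError): discard the
            # half-started group and try again.
            group.reset()
            continue
        # Only a fully started group gets the snapshotted state loaded back.
        load_state(state_dict, blocking=True)
        return
    raise RuntimeError("Could not resize the worker group within max_retries.")
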
@@ -408,9 +417,10 @@
assert isinstance(dataset, Dataset) is not None \
or self.data_creator, \
"Must specify either a data creator or a dataset"
state_dict = self.state_dict()
if self.worker_group.should_scale_up():
logger.info("Resize opportunity detected. Attempting to scale up.")
self._resize_worker_group()
self._resize_worker_group(state_dict)
success, worker_stats = self.worker_group.train(
num_steps=num_steps, profile=profile, info=info, dataset=dataset)
# Fault handling
@@ -419,7 +429,7 @@
break
else:
self._num_failures += 1
self._resize_worker_group()
self._resize_worker_group(state_dict)
logger.info("Retrying training step with %d workers." %
self.worker_group.num_workers)
success, worker_stats = self.worker_group.train(
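
Taking the state snapshot at the top of train() (instead of inside _resize_worker_group, as before) is what makes test_fail_state above pass: the snapshot is captured while the group is still healthy, and a failed step is rolled back to it and retried, so the final state matches an uninterrupted run. A condensed sketch of the resulting flow (simplified; the retry limit and argument plumbing are omitted):

def train_step_with_recovery(trainer, **train_kwargs):
    state_dict = trainer.state_dict()            # snapshot before anything runs
    if trainer.worker_group.should_scale_up():
        trainer._resize_worker_group(state_dict)
    success, worker_stats = trainer.worker_group.train(**train_kwargs)
    while not success:                           # fault handling
        trainer._num_failures += 1
        trainer._resize_worker_group(state_dict)  # rebuild, then restore snapshot
        success, worker_stats = trainer.worker_group.train(**train_kwargs)
    return worker_stats
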


@@ -21,7 +21,7 @@ class WorkerGroupInterface:
"""Manages a group of TorchRunner workers."""
def start_workers(self, num_workers):
"""Start workers for training.
"""Start workers for training. Returns if startup is successful.
This method has 4 steps.
1. Creates `num_workers` TorchRunner objects, either all as remote
@@ -207,28 +207,33 @@ class RemoteWorkerGroup(WorkerGroupInterface):
def start_workers(self, num_workers):
logger.debug(f"start_workers: Setting %d workers." % num_workers)
if num_workers == 1:
RemoteRunner = ray.remote(
num_cpus=self._num_cpus_per_worker,
num_gpus=int(self._use_gpu))(TorchRunner)
self.remote_workers = [RemoteRunner.remote(**self._params)]
ray.get(self.remote_workers[0].setup_operator.remote())
else:
self._init_dist_workers(num_workers)
try:
if num_workers == 1:
RemoteRunner = ray.remote(
num_cpus=self._num_cpus_per_worker,
num_gpus=int(self._use_gpu))(TorchRunner)
self.remote_workers = [RemoteRunner.remote(**self._params)]
ray.get(self.remote_workers[0].setup_operator.remote())
else:
self._init_dist_workers(num_workers)
if self._initialization_hook:
self.apply_all_workers(self._initialization_hook)
if self._initialization_hook:
self.apply_all_workers(self._initialization_hook)
# Make sure to get the IP address of the rank 0 worker node.
address = ray.get(self.remote_workers[0].setup_address.remote())
# Make sure to get the IP address of the rank 0 worker node.
address = ray.get(
self.remote_workers[0].setup_address.remote())
ray.get(
self._setup_process_group(
address=address, world_size=num_workers))
ray.get(
self._setup_process_group(
address=address, world_size=num_workers))
ray.get(self._setup_local_rank())
ray.get(self._setup_local_rank())
ray.get(self._setup_operator())
ray.get(self._setup_operator())
return True
except RayActorError:
return False
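
The pattern above, blocking ray.get calls wrapped in one try/except that maps RayActorError to False, is what lets the trainer retry instead of crashing when a worker dies during setup (the corresponding from ray.exceptions import RayActorError presumably lands at the top of the file, outside this hunk). A self-contained miniature of the same pattern, using a toy actor rather than TorchRunner:

import ray
from ray.exceptions import RayActorError

@ray.remote
class Setup:
    def ready(self):
        return True

def start_group(num_workers):
    workers = [Setup.remote() for _ in range(num_workers)]
    try:
        # A worker dying here raises RayActorError out of ray.get.
        ray.get([w.ready.remote() for w in workers])
        return True
    except RayActorError:
        return False

# Usage: ray.init(); assert start_group(2)
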
def _apply_all_operators(self, fn):
remote_calls = [
@@ -443,42 +448,46 @@ class LocalWorkerGroup(WorkerGroupInterface):
if self._initialization_hook:
self.apply_all_workers(self._initialization_hook)
self.local_worker.setup_operator()
return True
else:
try:
# Start local worker
self.local_worker = LocalDistributedRunner(
num_cpus=self._num_cpus_per_worker,
num_gpus=int(self._use_gpu),
**{
**self._params,
**self._dist_params
})
self.remote_worker_group._init_dist_workers(num_workers - 1)
if self._initialization_hook:
self.apply_all_workers(self._initialization_hook)
# Start local worker
self.local_worker = LocalDistributedRunner(
num_cpus=self._num_cpus_per_worker,
num_gpus=int(self._use_gpu),
**{
**self._params,
**self._dist_params
})
self.remote_worker_group._init_dist_workers(num_workers - 1)
if self._initialization_hook:
self.apply_all_workers(self._initialization_hook)
# Compute URL for initializing distributed PyTorch.
address = setup_address()
# Compute URL for initializing distributed PyTorch.
address = setup_address()
remote_pgs = self.remote_worker_group._setup_process_group(
address=address, world_size=num_workers, starting_rank=1)
# Use the local worker as rank 0. Helps with debugging.
self.local_worker.setup_process_group(
url=address,
world_rank=0,
world_size=num_workers,
timeout=timedelta(seconds=self._timeout_s))
ray.get(remote_pgs)
remote_pgs = self.remote_worker_group._setup_process_group(
address=address, world_size=num_workers, starting_rank=1)
# Use the local worker as rank 0. This will help with debugging.
self.local_worker.setup_process_group(
url=address,
world_rank=0,
world_size=num_workers,
timeout=timedelta(seconds=self._timeout_s))
ray.get(remote_pgs)
local_node_ip = ray.util.get_node_ip_address()
rank_dict = defaultdict(int)
self.local_worker.set_local_rank(local_rank=0)
rank_dict[local_node_ip] += 1
self.remote_worker_group._setup_local_rank(rank_dict)
local_node_ip = ray.util.get_node_ip_address()
rank_dict = defaultdict(int)
self.local_worker.set_local_rank(local_rank=0)
rank_dict[local_node_ip] += 1
self.remote_worker_group._setup_local_rank(rank_dict)
remote_operators = self.remote_worker_group._setup_operator()
self.local_worker.setup_operator()
ray.get(remote_operators)
remote_operators = self.remote_worker_group._setup_operator()
self.local_worker.setup_operator()
ray.get(remote_operators)
return True
except RayActorError:
return False
def apply_all_operators(self, fn):
remote_calls = self.remote_worker_group._apply_all_operators(fn)