Mirror of https://github.com/vale981/ray, synced 2025-03-06 18:41:40 -05:00
[SGD] Worker Startup Fault Tolerance (#14724)
parent c30d5f445c
commit 47300d5a53
3 changed files with 188 additions and 61 deletions
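
The change below makes worker startup report success instead of assuming it: WorkerGroupInterface.start_workers now returns a boolean, TorchTrainer raises a RuntimeError if the initial startup fails, and _resize_worker_group(state_dict, max_retries=10) retries failed startups before reloading state. A minimal sketch of that retry pattern, using hypothetical stand-in callables rather than the Ray SGD API:

def start_with_retries(start_workers, load_state, state_dict, num_workers,
                       max_retries=10):
    # start_workers and load_state are hypothetical callables standing in
    # for the worker-group/trainer methods touched by this commit.
    for _ in range(max_retries):
        if start_workers(num_workers):    # startup now reports a boolean
            load_state(state_dict)        # reload state after a clean start
            return True
    return False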
@@ -1,6 +1,8 @@
 from unittest.mock import patch
 import pytest
 import time
 
 import torch
 import torch.nn as nn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
@@ -59,12 +61,12 @@ start_workers = TorchTrainer._start_workers
 
 def gen_start_with_fail(num_fails):
     def start_with_fail(self, *args, **kwargs):
-        start_workers(self, *args, **kwargs)
+        success = start_workers(self, *args, **kwargs)
         fail = self._num_failures < num_fails
-        if self.use_local:
-            self.worker_group.remote_worker_group.should_fail = fail
-        else:
-            self.worker_group.should_fail = fail
+        fail_worker_group = self.worker_group.remote_worker_group if \
+            self.use_local else self.worker_group
+        fail_worker_group.should_fail = fail
+        return success
 
     return start_with_fail
 
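
gen_start_with_fail wraps the real TorchTrainer._start_workers: it forwards the call, then arms the worker group to fail a later training step. The tests install the wrapper with unittest.mock.patch.object; a stripped-down sketch of that monkey-patching pattern, using a toy class rather than the Ray trainer:

from unittest.mock import patch

class Trainer:
    def _start_workers(self, n):
        return True

real_start = Trainer._start_workers          # keep a handle to the original

def failing_start(self, n):
    ok = real_start(self, n)                 # run the real startup first
    self.should_fail = True                  # then arm an injected failure
    return ok

with patch.object(Trainer, "_start_workers", failing_start):
    t = Trainer()
    assert t._start_workers(2) is True
    assert t.should_fail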
@@ -178,6 +180,112 @@ def test_fail_with_recover(ray_start_2_cpus, use_local):  # noqa: F811
     trainer1.shutdown(force=True)
 
 
+@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
+def test_fail_state(ray_start_2_cpus):  # noqa: F811
+    """Tests if state of training with failure is same as training without."""
+    if not dist.is_available():
+        return
+
+    torch.manual_seed(0)
+
+    def single_loader(config):
+        dataset = LinearDataset(2, 5, size=1000000)
+        return DataLoader(dataset, batch_size=config.get("batch_size", 32))
+
+    TestOperator = TrainingOperator.from_creators(
+        model_creator,
+        optimizer_creator,
+        single_loader,
+        loss_creator=lambda config: nn.MSELoss())
+
+    def init_hook():
+        torch.manual_seed(0)
+
+    trainer1 = TorchTrainer(
+        training_operator_cls=TestOperator,
+        config={"batch_size": 100000},
+        timeout_s=5,
+        initialization_hook=init_hook,
+        num_workers=2)
+    initial_state = trainer1.state_dict()
+    trainer1.train()
+    trainer1_state = trainer1.state_dict()
+    assert trainer1_state != initial_state
+    trainer1.shutdown()
+
+    trainer2 = TorchTrainer(
+        training_operator_cls=TestOperator,
+        config={"batch_size": 100000},
+        timeout_s=5,
+        initialization_hook=init_hook,
+        num_workers=2)
+    trainer2.load_state_dict(initial_state)
+    trainer2.train()
+    assert trainer2.state_dict() == trainer1_state
+    trainer2.shutdown()
+
+    start_with_fail = gen_start_with_fail(1)
+    with patch.object(TorchTrainer, "_start_workers", start_with_fail):
+        trainer3 = TorchTrainer(
+            training_operator_cls=TestOperator,
+            config={"batch_size": 100000},
+            timeout_s=5,
+            initialization_hook=init_hook,
+            num_workers=2)
+        trainer3.load_state_dict(initial_state)
+        trainer3.train()
+        assert trainer3.state_dict() == trainer1_state
+        trainer3.shutdown()
+
+
+def gen_start_with_startup_fail(num_fails):
+    fail_start = gen_start_with_fail(num_fails)
+
+    def start_with_fail(self, *args, **kwargs):
+        if hasattr(self, "worker_group"):
+            # Fail during worker start just during the first training attempt.
+            def _raise():
+                import sys
+                sys.exit(1)
+
+            if not self.initialization_hook:
+                self.initialization_hook = _raise
+            else:
+                self.initialization_hook = None
+        return fail_start(self, *args, **kwargs)
+
+    return start_with_fail
+
+
+@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
+def test_failure_during_resize(ray_start_2_cpus):  # noqa: F811
+    """Tests if training succeeds even with failures during worker resizing."""
+    if not dist.is_available():
+        return
+
+    def single_loader(config):
+        dataset = LinearDataset(2, 5, size=1000000)
+        return DataLoader(dataset, batch_size=config.get("batch_size", 32))
+
+    TestOperator = TrainingOperator.from_creators(
+        model_creator,
+        optimizer_creator,
+        single_loader,
+        loss_creator=lambda config: nn.MSELoss())
+
+    start_with_fail = gen_start_with_startup_fail(1)
+    with patch.object(TorchTrainer, "_start_workers", start_with_fail):
+        trainer1 = TorchTrainer(
+            training_operator_cls=TestOperator,
+            config={"batch_size": 100000},
+            timeout_s=5,
+            use_local=False,
+            num_workers=2)
+        trainer1.train()
+
+        trainer1.shutdown()
+
+
 if __name__ == "__main__":
     import pytest
     import sys
@@ -263,7 +263,11 @@ class TorchTrainer:
                         "multi-node training, be sure to run `ray.init("
                         "address='auto')` before instantiating the Trainer.")
             ray.init()
-        self._start_workers(self.max_replicas)
+        startup_success = self._start_workers(self.max_replicas)
+        if not startup_success:
+            raise RuntimeError("Worker startup failed. "
+                               "Are you sure you have enough resources to "
+                               "start the specified number of workers?")
 
     def _configure_and_split_batch(self, num_workers):
         """If sgd.utils.BATCH_SIZE is provided, split among workers."""
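
With this hunk, a TorchTrainer whose workers cannot be started fails fast with a RuntimeError at construction time. A hypothetical caller-side sketch (the operator class and worker counts are placeholders, not part of this commit):

from ray.util.sgd import TorchTrainer

# MyOperator stands in for a user-defined TrainingOperator subclass.
try:
    trainer = TorchTrainer(training_operator_cls=MyOperator, num_workers=8)
except RuntimeError:
    # Worker startup failed, e.g. not enough CPUs/GPUs for 8 workers.
    trainer = TorchTrainer(training_operator_cls=MyOperator, num_workers=2)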
@@ -323,17 +327,17 @@ class TorchTrainer:
         # num_workers workers, this command will hang. Instead,
         # start_workers should take into account available resources when
         # determining how many workers to create.
-        self.worker_group.start_workers(num_workers)
+        return self.worker_group.start_workers(num_workers)
 
-    def _resize_worker_group(self, max_retries=10):
+    def _resize_worker_group(self, state_dict, max_retries=10):
         """Resizes the number of remote workers based on available resources.
         Total number of workers will never exceed `num_workers` amount.
 
         Args:
+            state_dict (dict): The state dict to load to all workers.
             max_retries (int): How many times to attempt to resize workers
                 before failing.
         """
-        state_dict = self.state_dict()
         old_workers = self.worker_group.num_workers
         self.worker_group.reset()
 
@@ -342,7 +346,12 @@ class TorchTrainer:
             new_workers = self.worker_group.new_workers_size()
             if new_workers:
                 self._last_resize = time.time()
-                self._start_workers(int(new_workers))
+                startup_success = self._start_workers(int(new_workers))
+                if not startup_success:
+                    logger.info(f"Worker startup failed. Retrying "
+                                f"{max_retries-i-1} more times.")
+                    self.worker_group.reset()
+                    continue
                 self.load_state_dict(state_dict, blocking=True)
                 if self.use_local and new_workers == 1 and old_workers > 1:
                     # Major hack. If we go from LocalDistributedRunner to a
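
The hunk above is the heart of the retry logic: each failed attempt resets the worker group and tries again, and the saved state_dict is only loaded once startup succeeds. The same control flow in isolation, with illustrative stand-in objects rather than the Ray SGD API:

def resize_with_retries(group, state_dict, target_size, max_retries=10):
    # group, state_dict and the method names are illustrative stand-ins;
    # this only mirrors the control flow of the hunk above.
    for _ in range(max_retries):
        if not group.start_workers(target_size):
            group.reset()                      # discard the partial group
            continue                           # and try again
        group.load_state_dict(state_dict)      # restore weights/optimizer state
        return True
    raise RuntimeError("Worker startup failed after max_retries attempts.")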
@@ -408,9 +417,10 @@ class TorchTrainer:
         assert isinstance(dataset, Dataset) is not None \
             or self.data_creator, \
             "Must specify either a data creator or a dataset"
+        state_dict = self.state_dict()
         if self.worker_group.should_scale_up():
             logger.info("Resize opportunity detected. Attempting to scale up.")
-            self._resize_worker_group()
+            self._resize_worker_group(state_dict)
         success, worker_stats = self.worker_group.train(
             num_steps=num_steps, profile=profile, info=info, dataset=dataset)
         # Fault handling
@@ -419,7 +429,7 @@ class TorchTrainer:
                 break
             else:
                 self._num_failures += 1
-                self._resize_worker_group()
+                self._resize_worker_group(state_dict)
                 logger.info("Retrying training step with %d workers." %
                             self.worker_group.num_workers)
                 success, worker_stats = self.worker_group.train(
@@ -21,7 +21,7 @@ class WorkerGroupInterface:
     """Manages a group of TorchRunner workers."""
 
     def start_workers(self, num_workers):
-        """Start workers for training.
+        """Start workers for training. Returns if startup is successful.
 
         This method has 4 steps.
         1. Creates `num_workers` TorchRunner objects, either all as remote
@@ -207,28 +207,33 @@ class RemoteWorkerGroup(WorkerGroupInterface):
 
     def start_workers(self, num_workers):
         logger.debug(f"start_workers: Setting %d workers." % num_workers)
-        if num_workers == 1:
-            RemoteRunner = ray.remote(
-                num_cpus=self._num_cpus_per_worker,
-                num_gpus=int(self._use_gpu))(TorchRunner)
-            self.remote_workers = [RemoteRunner.remote(**self._params)]
-            ray.get(self.remote_workers[0].setup_operator.remote())
-        else:
-            self._init_dist_workers(num_workers)
+        try:
+            if num_workers == 1:
+                RemoteRunner = ray.remote(
+                    num_cpus=self._num_cpus_per_worker,
+                    num_gpus=int(self._use_gpu))(TorchRunner)
+                self.remote_workers = [RemoteRunner.remote(**self._params)]
+                ray.get(self.remote_workers[0].setup_operator.remote())
+            else:
+                self._init_dist_workers(num_workers)
 
-            if self._initialization_hook:
-                self.apply_all_workers(self._initialization_hook)
+                if self._initialization_hook:
+                    self.apply_all_workers(self._initialization_hook)
 
-            # Make sure to get the IP address of the rank 0 worker node.
-            address = ray.get(self.remote_workers[0].setup_address.remote())
+                # Make sure to get the IP address of the rank 0 worker node.
+                address = ray.get(
+                    self.remote_workers[0].setup_address.remote())
 
-            ray.get(
-                self._setup_process_group(
-                    address=address, world_size=num_workers))
+                ray.get(
+                    self._setup_process_group(
+                        address=address, world_size=num_workers))
 
-            ray.get(self._setup_local_rank())
+                ray.get(self._setup_local_rank())
 
-            ray.get(self._setup_operator())
+                ray.get(self._setup_operator())
+            return True
+        except RayActorError:
+            return False
 
     def _apply_all_operators(self, fn):
         remote_calls = [
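
RemoteWorkerGroup.start_workers now converts a dead actor during any of the blocking ray.get setup calls into a False return value instead of an unhandled RayActorError. The same pattern in a minimal, self-contained form (a toy actor, not the SGD TorchRunner; assumes ray.init() has already been called):

import ray
from ray.exceptions import RayActorError

@ray.remote
class Worker:
    def setup(self):
        return "ready"

def try_start_worker():
    # Converts a crashed/killed actor into a boolean, like the hunk above.
    try:
        worker = Worker.remote()
        ray.get(worker.setup.remote())  # RayActorError if the actor has died
        return True
    except RayActorError:
        return False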
@@ -443,42 +448,46 @@ class LocalWorkerGroup(WorkerGroupInterface):
             if self._initialization_hook:
                 self.apply_all_workers(self._initialization_hook)
             self.local_worker.setup_operator()
+            return True
         else:
+            try:
+                # Start local worker
+                self.local_worker = LocalDistributedRunner(
+                    num_cpus=self._num_cpus_per_worker,
+                    num_gpus=int(self._use_gpu),
+                    **{
+                        **self._params,
+                        **self._dist_params
+                    })
+                self.remote_worker_group._init_dist_workers(num_workers - 1)
+                if self._initialization_hook:
+                    self.apply_all_workers(self._initialization_hook)
 
-            # Start local worker
-            self.local_worker = LocalDistributedRunner(
-                num_cpus=self._num_cpus_per_worker,
-                num_gpus=int(self._use_gpu),
-                **{
-                    **self._params,
-                    **self._dist_params
-                })
-            self.remote_worker_group._init_dist_workers(num_workers - 1)
-            if self._initialization_hook:
-                self.apply_all_workers(self._initialization_hook)
+                # Compute URL for initializing distributed PyTorch.
+                address = setup_address()
 
-            # Compute URL for initializing distributed PyTorch.
-            address = setup_address()
+                remote_pgs = self.remote_worker_group._setup_process_group(
+                    address=address, world_size=num_workers, starting_rank=1)
+                # Use the local worker as rank 0. Helps with debugging.
+                self.local_worker.setup_process_group(
+                    url=address,
+                    world_rank=0,
+                    world_size=num_workers,
+                    timeout=timedelta(seconds=self._timeout_s))
+                ray.get(remote_pgs)
 
-            remote_pgs = self.remote_worker_group._setup_process_group(
-                address=address, world_size=num_workers, starting_rank=1)
-            # Use the local worker as rank 0. This will help with debugging.
-            self.local_worker.setup_process_group(
-                url=address,
-                world_rank=0,
-                world_size=num_workers,
-                timeout=timedelta(seconds=self._timeout_s))
-            ray.get(remote_pgs)
+                local_node_ip = ray.util.get_node_ip_address()
+                rank_dict = defaultdict(int)
+                self.local_worker.set_local_rank(local_rank=0)
+                rank_dict[local_node_ip] += 1
+                self.remote_worker_group._setup_local_rank(rank_dict)
 
-            local_node_ip = ray.util.get_node_ip_address()
-            rank_dict = defaultdict(int)
-            self.local_worker.set_local_rank(local_rank=0)
-            rank_dict[local_node_ip] += 1
-            self.remote_worker_group._setup_local_rank(rank_dict)
-
-            remote_operators = self.remote_worker_group._setup_operator()
-            self.local_worker.setup_operator()
-            ray.get(remote_operators)
+                remote_operators = self.remote_worker_group._setup_operator()
+                self.local_worker.setup_operator()
+                ray.get(remote_operators)
+                return True
+            except RayActorError:
+                return False
 
     def apply_all_operators(self, fn):
         remote_calls = self.remote_worker_group._apply_all_operators(fn)