From 47300d5a535fdb2c8b8c5363775ebea2741e5c70 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Thu, 18 Mar 2021 22:53:56 -0700 Subject: [PATCH] [SGD] Worker Startup Fault Tolerance (#14724) --- .../ray/util/sgd/tests/test_torch_failure.py | 118 +++++++++++++++++- python/ray/util/sgd/torch/torch_trainer.py | 24 ++-- python/ray/util/sgd/torch/worker_group.py | 107 ++++++++-------- 3 files changed, 188 insertions(+), 61 deletions(-) diff --git a/python/ray/util/sgd/tests/test_torch_failure.py b/python/ray/util/sgd/tests/test_torch_failure.py index 4f48daaec..00b91cd27 100644 --- a/python/ray/util/sgd/tests/test_torch_failure.py +++ b/python/ray/util/sgd/tests/test_torch_failure.py @@ -1,6 +1,8 @@ from unittest.mock import patch import pytest import time + +import torch import torch.nn as nn import torch.distributed as dist from torch.utils.data import DataLoader @@ -59,12 +61,12 @@ start_workers = TorchTrainer._start_workers def gen_start_with_fail(num_fails): def start_with_fail(self, *args, **kwargs): - start_workers(self, *args, **kwargs) + success = start_workers(self, *args, **kwargs) fail = self._num_failures < num_fails - if self.use_local: - self.worker_group.remote_worker_group.should_fail = fail - else: - self.worker_group.should_fail = fail + fail_worker_group = self.worker_group.remote_worker_group if \ + self.use_local else self.worker_group + fail_worker_group.should_fail = fail + return success return start_with_fail @@ -178,6 +180,112 @@ def test_fail_with_recover(ray_start_2_cpus, use_local): # noqa: F811 trainer1.shutdown(force=True) +@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail) +def test_fail_state(ray_start_2_cpus): # noqa: F811 + """Tests if state of training with failure is same as training without.""" + if not dist.is_available(): + return + + torch.manual_seed(0) + + def single_loader(config): + dataset = LinearDataset(2, 5, size=1000000) + return DataLoader(dataset, batch_size=config.get("batch_size", 32)) + + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=lambda config: nn.MSELoss()) + + def init_hook(): + torch.manual_seed(0) + + trainer1 = TorchTrainer( + training_operator_cls=TestOperator, + config={"batch_size": 100000}, + timeout_s=5, + initialization_hook=init_hook, + num_workers=2) + initial_state = trainer1.state_dict() + trainer1.train() + trainer1_state = trainer1.state_dict() + assert trainer1_state != initial_state + trainer1.shutdown() + + trainer2 = TorchTrainer( + training_operator_cls=TestOperator, + config={"batch_size": 100000}, + timeout_s=5, + initialization_hook=init_hook, + num_workers=2) + trainer2.load_state_dict(initial_state) + trainer2.train() + assert trainer2.state_dict() == trainer1_state + trainer2.shutdown() + + start_with_fail = gen_start_with_fail(1) + with patch.object(TorchTrainer, "_start_workers", start_with_fail): + trainer3 = TorchTrainer( + training_operator_cls=TestOperator, + config={"batch_size": 100000}, + timeout_s=5, + initialization_hook=init_hook, + num_workers=2) + trainer3.load_state_dict(initial_state) + trainer3.train() + assert trainer3.state_dict() == trainer1_state + trainer3.shutdown() + + +def gen_start_with_startup_fail(num_fails): + fail_start = gen_start_with_fail(num_fails) + + def start_with_fail(self, *args, **kwargs): + if hasattr(self, "worker_group"): + # Fail during worker start just during the first training attempt. 
+ def _raise(): + import sys + sys.exit(1) + + if not self.initialization_hook: + self.initialization_hook = _raise + else: + self.initialization_hook = None + return fail_start(self, *args, **kwargs) + + return start_with_fail + + +@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail) +def test_failure_during_resize(ray_start_2_cpus): # noqa: F811 + """Tests if training succeeds even with failures during worker resizing.""" + if not dist.is_available(): + return + + def single_loader(config): + dataset = LinearDataset(2, 5, size=1000000) + return DataLoader(dataset, batch_size=config.get("batch_size", 32)) + + TestOperator = TrainingOperator.from_creators( + model_creator, + optimizer_creator, + single_loader, + loss_creator=lambda config: nn.MSELoss()) + + start_with_fail = gen_start_with_startup_fail(1) + with patch.object(TorchTrainer, "_start_workers", start_with_fail): + trainer1 = TorchTrainer( + training_operator_cls=TestOperator, + config={"batch_size": 100000}, + timeout_s=5, + use_local=False, + num_workers=2) + trainer1.train() + + trainer1.shutdown() + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 149c510ae..3b270a46e 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -263,7 +263,11 @@ class TorchTrainer: "multi-node training, be sure to run `ray.init(" "address='auto')` before instantiating the Trainer.") ray.init() - self._start_workers(self.max_replicas) + startup_success = self._start_workers(self.max_replicas) + if not startup_success: + raise RuntimeError("Worker startup failed. " + "Are you sure you have enough resources to " + "start the specified number of workers?") def _configure_and_split_batch(self, num_workers): """If sgd.utils.BATCH_SIZE is provided, split among workers.""" @@ -323,17 +327,17 @@ class TorchTrainer: # num_workers workers, this command will hang. Instead, # start_workers should take into account available resources when # determining how many workers to create. - self.worker_group.start_workers(num_workers) + return self.worker_group.start_workers(num_workers) - def _resize_worker_group(self, max_retries=10): + def _resize_worker_group(self, state_dict, max_retries=10): """Resizes the number of remote workers based on available resources. Total number of workers will never exceed `num_workers` amount. Args: + state_dict (dict): The state dict to load to all workers. max_retries (int): How many times to attempt to resize workers before failing. """ - state_dict = self.state_dict() old_workers = self.worker_group.num_workers self.worker_group.reset() @@ -342,7 +346,12 @@ class TorchTrainer: new_workers = self.worker_group.new_workers_size() if new_workers: self._last_resize = time.time() - self._start_workers(int(new_workers)) + startup_success = self._start_workers(int(new_workers)) + if not startup_success: + logger.info(f"Worker startup failed. Retrying " + f"{max_retries-i-1} more times.") + self.worker_group.reset() + continue self.load_state_dict(state_dict, blocking=True) if self.use_local and new_workers == 1 and old_workers > 1: # Major hack. 
If we go from LocalDistributedRunner to a @@ -408,9 +417,10 @@ class TorchTrainer: assert isinstance(dataset, Dataset) is not None \ or self.data_creator, \ "Must specify either a data creator or a dataset" + state_dict = self.state_dict() if self.worker_group.should_scale_up(): logger.info("Resize opportunity detected. Attempting to scale up.") - self._resize_worker_group() + self._resize_worker_group(state_dict) success, worker_stats = self.worker_group.train( num_steps=num_steps, profile=profile, info=info, dataset=dataset) # Fault handling @@ -419,7 +429,7 @@ class TorchTrainer: break else: self._num_failures += 1 - self._resize_worker_group() + self._resize_worker_group(state_dict) logger.info("Retrying training step with %d workers." % self.worker_group.num_workers) success, worker_stats = self.worker_group.train( diff --git a/python/ray/util/sgd/torch/worker_group.py b/python/ray/util/sgd/torch/worker_group.py index ee8d94b64..c733f6149 100644 --- a/python/ray/util/sgd/torch/worker_group.py +++ b/python/ray/util/sgd/torch/worker_group.py @@ -21,7 +21,7 @@ class WorkerGroupInterface: """Manages a group of TorchRunner workers.""" def start_workers(self, num_workers): - """Start workers for training. + """Start workers for training. Returns if startup is successful. This method has 4 steps. 1. Creates `num_workers` TorchRunner objects, either all as remote @@ -207,28 +207,33 @@ class RemoteWorkerGroup(WorkerGroupInterface): def start_workers(self, num_workers): logger.debug(f"start_workers: Setting %d workers." % num_workers) - if num_workers == 1: - RemoteRunner = ray.remote( - num_cpus=self._num_cpus_per_worker, - num_gpus=int(self._use_gpu))(TorchRunner) - self.remote_workers = [RemoteRunner.remote(**self._params)] - ray.get(self.remote_workers[0].setup_operator.remote()) - else: - self._init_dist_workers(num_workers) + try: + if num_workers == 1: + RemoteRunner = ray.remote( + num_cpus=self._num_cpus_per_worker, + num_gpus=int(self._use_gpu))(TorchRunner) + self.remote_workers = [RemoteRunner.remote(**self._params)] + ray.get(self.remote_workers[0].setup_operator.remote()) + else: + self._init_dist_workers(num_workers) - if self._initialization_hook: - self.apply_all_workers(self._initialization_hook) + if self._initialization_hook: + self.apply_all_workers(self._initialization_hook) - # Make sure to get the IP address of the rank 0 worker node. - address = ray.get(self.remote_workers[0].setup_address.remote()) + # Make sure to get the IP address of the rank 0 worker node. 
+ address = ray.get( + self.remote_workers[0].setup_address.remote()) - ray.get( - self._setup_process_group( - address=address, world_size=num_workers)) + ray.get( + self._setup_process_group( + address=address, world_size=num_workers)) - ray.get(self._setup_local_rank()) + ray.get(self._setup_local_rank()) - ray.get(self._setup_operator()) + ray.get(self._setup_operator()) + return True + except RayActorError: + return False def _apply_all_operators(self, fn): remote_calls = [ @@ -443,42 +448,46 @@ class LocalWorkerGroup(WorkerGroupInterface): if self._initialization_hook: self.apply_all_workers(self._initialization_hook) self.local_worker.setup_operator() + return True else: + try: + # Start local worker + self.local_worker = LocalDistributedRunner( + num_cpus=self._num_cpus_per_worker, + num_gpus=int(self._use_gpu), + **{ + **self._params, + **self._dist_params + }) + self.remote_worker_group._init_dist_workers(num_workers - 1) + if self._initialization_hook: + self.apply_all_workers(self._initialization_hook) - # Start local worker - self.local_worker = LocalDistributedRunner( - num_cpus=self._num_cpus_per_worker, - num_gpus=int(self._use_gpu), - **{ - **self._params, - **self._dist_params - }) - self.remote_worker_group._init_dist_workers(num_workers - 1) - if self._initialization_hook: - self.apply_all_workers(self._initialization_hook) + # Compute URL for initializing distributed PyTorch. + address = setup_address() - # Compute URL for initializing distributed PyTorch. - address = setup_address() + remote_pgs = self.remote_worker_group._setup_process_group( + address=address, world_size=num_workers, starting_rank=1) + # Use the local worker as rank 0. Helps with debugging. + self.local_worker.setup_process_group( + url=address, + world_rank=0, + world_size=num_workers, + timeout=timedelta(seconds=self._timeout_s)) + ray.get(remote_pgs) - remote_pgs = self.remote_worker_group._setup_process_group( - address=address, world_size=num_workers, starting_rank=1) - # Use the local worker as rank 0. This will help with debugging. - self.local_worker.setup_process_group( - url=address, - world_rank=0, - world_size=num_workers, - timeout=timedelta(seconds=self._timeout_s)) - ray.get(remote_pgs) + local_node_ip = ray.util.get_node_ip_address() + rank_dict = defaultdict(int) + self.local_worker.set_local_rank(local_rank=0) + rank_dict[local_node_ip] += 1 + self.remote_worker_group._setup_local_rank(rank_dict) - local_node_ip = ray.util.get_node_ip_address() - rank_dict = defaultdict(int) - self.local_worker.set_local_rank(local_rank=0) - rank_dict[local_node_ip] += 1 - self.remote_worker_group._setup_local_rank(rank_dict) - - remote_operators = self.remote_worker_group._setup_operator() - self.local_worker.setup_operator() - ray.get(remote_operators) + remote_operators = self.remote_worker_group._setup_operator() + self.local_worker.setup_operator() + ray.get(remote_operators) + return True + except RayActorError: + return False def apply_all_operators(self, fn): remote_calls = self.remote_worker_group._apply_all_operators(fn)
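
Note on the fault-injection technique used throughout test_torch_failure.py above: the tests wrap TorchTrainer._start_workers with unittest.mock.patch.object so that startup itself succeeds but a should_fail flag is armed for the first few training attempts. The stand-alone sketch below shows only that wrapping pattern on a toy class; the Trainer class and its methods here are illustrative stand-ins, not the real TorchTrainer API.

# Sketch of the patch.object fault-injection pattern from the tests above.
# `Trainer` is a hypothetical stand-in, not ray.util.sgd.TorchTrainer.
from unittest.mock import patch


class Trainer:
    def __init__(self):
        self._num_failures = 0
        self.should_fail = False

    def _start_workers(self, num_workers):
        return True  # pretend worker startup always succeeds

    def train(self):
        if self.should_fail:
            # Simulate one mid-training failure followed by a restart.
            self._num_failures += 1
            self._start_workers(1)
        return {"loss": 0.0}


original_start = Trainer._start_workers


def gen_start_with_fail(num_fails):
    """Wrap _start_workers so the next `num_fails` training calls fail."""

    def start_with_fail(self, *args, **kwargs):
        success = original_start(self, *args, **kwargs)
        # Same idea as the real test: keep the failure flag armed until the
        # trainer has already recovered `num_fails` times.
        self.should_fail = self._num_failures < num_fails
        return success

    return start_with_fail


if __name__ == "__main__":
    with patch.object(Trainer, "_start_workers", gen_start_with_fail(1)):
        trainer = Trainer()
        trainer._start_workers(2)   # arms exactly one failure
        trainer.train()             # "fails" once, restarts, clears the flag
        assert trainer._num_failures == 1
        assert not trainer.should_fail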
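
gen_start_with_startup_fail relies on an initialization hook that calls sys.exit(1) inside the worker actor, which the driver then observes as a RayActorError, the exact condition the new start_workers converts into a False return value. A minimal sketch of that mechanism, assuming a throwaway Runner actor with an apply method in place of TorchRunner (both names are illustrative, not Ray SGD internals):

# Sketch: simulating a worker death during startup, as
# gen_start_with_startup_fail does with its sys.exit(1) hook.
# `Runner` and `apply` are hypothetical stand-ins for TorchRunner.
import sys

import ray
from ray.exceptions import RayActorError


@ray.remote
class Runner:
    def apply(self, fn):
        # Run a user-supplied initialization hook inside the actor process.
        return fn()


def failing_hook():
    sys.exit(1)  # kill this worker, mimicking a crash during startup


if __name__ == "__main__":
    ray.init(num_cpus=2)
    runner = Runner.remote()
    try:
        ray.get(runner.apply.remote(failing_hook))
    except RayActorError:
        # This is the condition that start_workers now catches and turns
        # into `return False` so the trainer can retry instead of hanging.
        print("worker died during startup; trainer would retry")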
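
The trainer-side contract after this change: start_workers returns a boolean, the initial startup raises RuntimeError if it fails, and _resize_worker_group retries up to max_retries while reloading the state dict that train() captured before any resize. Below is a condensed, stand-alone sketch of that retry loop written against plain Ray actors; the Worker actor, start_workers helper, and retry counts are illustrative assumptions, not the SGD implementation.

# Condensed sketch of the retry-on-startup-failure flow added in this PR,
# using plain Ray actors. Names (Worker, start_workers, start_with_retries)
# are illustrative only.
import ray
from ray.exceptions import RayActorError


@ray.remote
class Worker:
    def ping(self):
        return "ok"


def start_workers(num_workers):
    """Return (success, actors); report a dead worker instead of raising."""
    try:
        actors = [Worker.remote() for _ in range(num_workers)]
        # Block until every actor is actually up; a crashed actor surfaces
        # here as RayActorError instead of a hang later in training.
        ray.get([actor.ping.remote() for actor in actors])
        return True, actors
    except RayActorError:
        return False, []


def start_with_retries(num_workers, max_retries=10):
    for i in range(max_retries):
        success, actors = start_workers(num_workers)
        if success:
            return actors
        # Mirror the PR's behavior: log and retry rather than crash.
        print(f"Worker startup failed. Retrying {max_retries - i - 1} more times.")
    raise RuntimeError("Worker startup failed. Are you sure you have "
                       "enough resources to start the specified number "
                       "of workers?")


if __name__ == "__main__":
    ray.init(num_cpus=2)
    workers = start_with_retries(num_workers=2)
    print(ray.get([w.ping.remote() for w in workers]))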