[sgd] fault tolerance for pytorch + revamp documentation (#6465)
parent e5ad4e6f8d
commit 232be5a058
12 changed files with 365 additions and 60 deletions
@@ -17,12 +17,13 @@ Ray is packaged with the following libraries for accelerating machine learning w

- `Tune`_: Scalable Hyperparameter Tuning
- `RLlib`_: Scalable Reinforcement Learning
- `Distributed Training <distributed_training.html>`__
- `RaySGD`_: Distributed Training

Star us on `on GitHub`_. You can also get started by visiting our `Tutorials <https://github.com/ray-project/tutorial>`_. For the latest wheels (nightlies), see the `installation page <installation.html>`__.

.. _`on GitHub`: https://github.com/ray-project/ray
.. _`RaySGD`: raysgd/raysgd.html

Quick Start
@@ -272,12 +273,16 @@ Getting Involved

   rllib-dev.rst
   rllib-package-ref.rst

.. toctree::
   :maxdepth: -1
   :caption: RaySGD

   raysgd/raysgd.rst

.. toctree::
   :maxdepth: -1
   :caption: Experimental

   distributed_training.rst
   tf_distributed_training.rst
   pandas_on_ray.rst
   projects.rst
   signals.rst
doc/source/raysgd/raysgd.rst (new file, 22 lines)

@@ -0,0 +1,22 @@
RaySGD: Distributed Deep Learning
=================================

.. image:: raysgdlogo.png
   :scale: 20%
   :align: center

RaySGD is a lightweight library for distributed deep learning, providing thin wrappers around framework-native modules for data parallel training.

The main features are:

- Ease of use: Scale PyTorch's native ``DistributedDataParallel`` and TensorFlow's ``tf.distribute.MirroredStrategy`` without needing to monitor individual nodes.
- Composability: RaySGD is built on top of the Ray Actor API, enabling seamless integration with existing Ray applications such as RLlib, Tune, and Ray Serve.
- Scale up and down: Start on a single CPU, then scale up to multi-node, multi-GPU training by changing two lines of code, as sketched below.
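A minimal sketch of that workflow, assuming the ``PyTorchTrainer`` API and the creator-function signatures used by ``train_example.py`` elsewhere in this diff; the toy model, dataset, and learning rate are purely illustrative:

.. code-block:: python

    import ray
    import torch
    import torch.nn as nn

    from ray.experimental.sgd.pytorch import PyTorchTrainer


    def model_creator(config):
        # Returns the torch.nn.Module to replicate on each worker.
        return nn.Linear(1, 1)


    def optimizer_creator(model, config):
        # Returns the optimizer for the given model.
        return torch.optim.SGD(model.parameters(), lr=1e-2)


    def data_creator(batch_size, config):
        # Returns a DataLoader over a toy regression dataset.
        x = torch.rand(1000, 1)
        dataset = torch.utils.data.TensorDataset(x, 2 * x + 1)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size)


    ray.init()
    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_replicas=1,   # the "two lines": raise num_replicas ...
        use_gpu=False)    # ... and flip use_gpu to scale out
    print(trainer.train())
    trainer.shutdown()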

.. toctree::

   raysgd_pytorch.rst
   raysgd_tensorflow.rst
   raysgd_ft.rst
doc/source/raysgd/raysgd_ft.rst (new file, 33 lines)

@@ -0,0 +1,33 @@
RaySGD Fault Tolerance
======================

.. note:: Fault tolerance is currently only enabled for the PyTorchTrainer.

For distributed deep learning, jobs are often run on infrastructure where nodes can be pre-empted frequently (e.g., spot instances in the cloud). To overcome this, RaySGD provides **fault tolerance** features that enable training to continue regardless of node failures.

.. code-block:: python

    trainer.train(max_retries=N)
How does it work?
-----------------

During each ``train`` call, each parallel worker iterates through the dataset, synchronizing gradients and parameters at each batch. These synchronization primitives can hang when one or more of the parallel workers becomes unresponsive (e.g., when a node is lost). To address this, we've implemented the following protocol; a condensed sketch of the resulting control flow follows the list.

1. If any worker node is lost, Ray will mark the training task as complete (``ray.wait`` will return).
2. When the Trainer class then calls ``ray.get`` on the "finished" training task, Ray will raise a ``RayActorError`` for any lost worker.
3. Upon catching this exception, the Trainer class will kill all of its workers.
4. The Trainer will then detect the quantity of available resources (either CPUs or GPUs). It will then restart as many workers as it can, each resuming from the last checkpoint. Note that this may result in fewer workers than initially specified.
5. If there are no available resources, the Trainer will apply an exponential backoff before retrying to create workers.
6. If there are available resources and the Trainer has fewer workers than initially specified, it will scale its worker pool back up to the initially specified ``num_replicas``.
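Condensed, this is roughly the control flow of ``PyTorchTrainer.train`` as it appears later in this diff (a simplified sketch; timing, stats averaging, and checkpoint validation are omitted):

.. code-block:: python

    def train(self, max_retries=10, checkpoint="auto"):
        # Save a checkpoint up front so restarted workers have state to resume from.
        if max_retries and checkpoint == "auto":
            checkpoint = self.save(os.path.join(self.temp_dir, "tmp_checkpoint"))

        success, worker_stats = self._train_step()
        for _ in range(max_retries):
            if success:
                break
            # A worker died: kill the rest, re-launch as many replicas as the
            # cluster currently allows, restore the checkpoint, and retry.
            self._num_failures += 1
            self._resize_workers(checkpoint=checkpoint)
            success, worker_stats = self._train_step()
        if not success:
            raise RuntimeError("Training run failed.")
        return ray.get(worker_stats)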
Note that we assume the Trainer itself is not on a pre-emptible node. It is currently not possible to recover from a failure of the Trainer node itself.

Users can set ``checkpoint="auto"`` to always checkpoint the current model before executing a pass over the training dataset.

.. code-block:: python

    trainer.train(max_retries=N, checkpoint="auto")
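As a usage sketch, retries can also be combined with an explicit checkpoint path; ``save`` and ``train`` are used here as they appear in the Trainer code in this diff, while the temporary path and epoch count are illustrative assumptions:

.. code-block:: python

    import os
    import tempfile

    checkpoint_path = os.path.join(
        tempfile.mkdtemp(prefix="raysgd_ckpt"), "model.checkpoint")

    trainer.save(checkpoint_path)      # seed the checkpoint before the first epoch
    for epoch in range(10):
        stats = trainer.train(max_retries=3, checkpoint=checkpoint_path)
        trainer.save(checkpoint_path)  # refresh it after each successful epoch
        print(stats)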
@@ -1,5 +1,7 @@

Distributed Training (Experimental)
===================================
RaySGD PyTorch
==============

.. warning:: This is still an experimental API and is subject to change in the near future.

Ray's ``PyTorchTrainer`` simplifies distributed model training for PyTorch. The ``PyTorchTrainer`` is a wrapper around ``torch.distributed.launch`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to needing to execute training outside of Python.
@@ -84,7 +86,7 @@ PyTorchTrainer Example

Below is an example of using Ray's PyTorchTrainer. Under the hood, ``PyTorchTrainer`` will create *replicas* of your model (controlled by ``num_replicas``), each of which is managed by a worker.

.. literalinclude:: ../../python/ray/experimental/sgd/examples/train_example.py
.. literalinclude:: ../../../python/ray/experimental/sgd/examples/train_example.py
   :language: python
   :start-after: __torch_train_example__
@@ -94,7 +96,7 @@ Hyperparameter Optimization on Distributed Pytorch

``PyTorchTrainer`` naturally integrates with Tune via the ``PyTorchTrainable`` interface. The same arguments passed to ``PyTorchTrainer`` should be passed into ``tune.run(config=...)`` as shown below.

.. literalinclude:: ../../python/ray/experimental/sgd/examples/tune_example.py
.. literalinclude:: ../../../python/ray/experimental/sgd/examples/tune_example.py
   :language: python
   :start-after: __torch_tune_example__
@@ -1,7 +1,7 @@

TF Distributed Training
=======================
RaySGD TensorFlow
=================

Ray's ``TFTrainer`` simplifies distributed model training for TensorFlow. The ``TFTrainer`` is a wrapper around ``MultiWorkerMirroredStrategy`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to writing custom logic for setting up environments and starting separate processes.
RaySGD's ``TFTrainer`` simplifies distributed model training for TensorFlow. The ``TFTrainer`` is a wrapper around ``MultiWorkerMirroredStrategy`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to writing custom logic for setting up environments and starting separate processes.

.. important:: This API has only been tested with TensorFlow 2.0rc and is still highly experimental. Please file bug reports if you run into any issues - thanks!
@@ -67,7 +67,7 @@ TFTrainer Example

Below is an example of using Ray's TFTrainer. Under the hood, ``TFTrainer`` will create *replicas* of your model (controlled by ``num_replicas``), each of which is managed by a worker.

.. literalinclude:: ../../python/ray/experimental/sgd/examples/tensorflow_train_example.py
.. literalinclude:: ../../../python/ray/experimental/sgd/examples/tensorflow_train_example.py
   :language: python
doc/source/raysgd/raysgdlogo.png (new binary file, 278 KiB; not shown)
@@ -88,11 +88,18 @@ class DistributedPyTorchRunner(PyTorchRunner):

    def get_state(self):
        """Returns the state of the runner."""
        # Copy the weights onto the CPU (rather than moving the model weights
        # entirely off the GPU) so that we can resume training while saving
        # intermediate checkpoints.
        cpu_state_dicts = []
        for model in self.models:
            state_dict = model.module.state_dict()
            for k, v in state_dict.items():
                state_dict[k] = v.cpu()
            cpu_state_dicts += [state_dict]
        return {
            "epoch": self.epoch,
            "models": [
                model.module.cpu().state_dict() for model in self.models
            ],
            "models": cpu_state_dicts,
            "optimizers": [opt.state_dict() for opt in self.optimizers],
            "stats": self.stats()
        }
@@ -45,6 +45,7 @@ def data_creator(batch_size, config):
        ]))

    # Create the dataloader
    train_sampler = None
    if distributed.is_initialized():
        train_sampler = DistributedSampler(dataset)
    dataloader = torch.utils.data.DataLoader(
@@ -238,8 +239,8 @@ def train_example(num_replicas=1, use_gpu=False, test_mode=False):
        use_gpu=use_gpu,
        batch_size=16 if test_mode else 512,
        backend="nccl" if use_gpu else "gloo")
    for i in range(5):
        stats = trainer.train()
    for i in range(10):
        stats = trainer.train(max_retries=3)
        print(stats)

    return trainer
@@ -4,6 +4,8 @@ import torch
import torch.distributed as dist
import logging
import numbers
import tempfile
import time

import ray

@@ -15,6 +17,7 @@ from ray.experimental.sgd import utils
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner

logger = logging.getLogger(__name__)
RESIZE_COOLDOWN_S = 10


class PyTorchTrainer:
@@ -74,8 +77,12 @@ class PyTorchTrainer:
                "https://github.com/pytorch/examples/issues/467."))

        self.model_creator = model_creator
        self.data_creator = data_creator
        self.train_function = train_function
        self.optimizer_creator = optimizer_creator
        self.loss_creator = loss_creator
        self.validation_function = validation_function
        self.initialization_hook = initialization_hook
        self.config = {} if config is None else config
        self.optimizer_timer = utils.TimerStat(window_size=1)
@@ -83,58 +90,69 @@ class PyTorchTrainer:
        backend = "nccl" if use_gpu else "gloo"

        logger.info("Using {} as backend.".format(backend))
        self.backend = backend
        self.use_gpu = use_gpu
        self.batch_size = batch_size
        self.max_replicas = num_replicas
        self.temp_dir = tempfile.mkdtemp(prefix="raysgd")
        self._num_failures = 0
        self._last_resize = float("-inf")
        self._start_workers(self.max_replicas)

    def _start_workers(self, num_replicas):
        logger.info("start_workers: Setting %d replicas." % num_replicas)
        if num_replicas == 1:
            # Generate actor class
            Runner = ray.remote(
                num_cpus=1, num_gpus=int(use_gpu))(PyTorchRunner)
                num_cpus=1, num_gpus=int(self.use_gpu))(PyTorchRunner)
            # Start workers
            self.workers = [
                Runner.remote(
                    model_creator,
                    data_creator,
                    optimizer_creator,
                    loss_creator,
                    train_function=train_function,
                    validation_function=validation_function,
                    self.model_creator,
                    self.data_creator,
                    self.optimizer_creator,
                    self.loss_creator,
                    train_function=self.train_function,
                    validation_function=self.validation_function,
                    config=self.config,
                    batch_size=batch_size)
                    batch_size=self.batch_size)
            ]
            if initialization_hook:
                self.apply_all_workers(initialization_hook)
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)
            # Get setup tasks in order to throw errors on failure
            ray.get(self.workers[0].setup.remote())
        else:
            # Generate actor class
            Runner = ray.remote(
                num_cpus=1, num_gpus=int(use_gpu))(DistributedPyTorchRunner)
                num_cpus=1,
                num_gpus=int(self.use_gpu))(DistributedPyTorchRunner)
            # Compute batch size per replica
            batch_size_per_replica = batch_size // num_replicas
            if batch_size % num_replicas > 0:
            batch_size_per_replica = self.batch_size // num_replicas
            if self.batch_size % num_replicas > 0:
                new_batch_size = batch_size_per_replica * num_replicas
                logger.warning(
                    ("Changing batch size from {old_batch_size} to "
                     "{new_batch_size} to evenly distribute batches across "
                     "{num_replicas} replicas.").format(
                        old_batch_size=batch_size,
                        old_batch_size=self.batch_size,
                        new_batch_size=new_batch_size,
                        num_replicas=num_replicas))
            # Start workers
            self.workers = [
                Runner.remote(
                    model_creator,
                    data_creator,
                    optimizer_creator,
                    loss_creator,
                    backend=backend,
                    train_function=train_function,
                    validation_function=validation_function,
                    self.model_creator,
                    self.data_creator,
                    self.optimizer_creator,
                    self.loss_creator,
                    backend=self.backend,
                    train_function=self.train_function,
                    validation_function=self.validation_function,
                    config=self.config,
                    batch_size=batch_size_per_replica)
                for i in range(num_replicas)
            ]
            if initialization_hook:
                self.apply_all_workers(initialization_hook)
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)

            # Compute URL for initializing distributed PyTorch
            ip = ray.get(self.workers[0].get_node_ip.remote())
@@ -146,13 +164,51 @@ class PyTorchTrainer:
            for i, worker in enumerate(self.workers)
        ])

    def train(self):
    def train(self, max_retries=10, checkpoint="auto"):
        """Runs a training epoch.

        Runs an average over all values returned from workers.
        Runs an average over all values returned from workers. Set
        `max_retries` to enable fault handling in case of instance preemption.

        Args:
            max_retries (int): Must be non-negative. If set to N, will
                kill all current workers, query the Ray global state for
                total available resources, and re-launch up to the
                available resources. Behavior is not well-defined
                in case of shared cluster usage.
            checkpoint (str): Path to checkpoint to restore from if retrying.
                If max_retries is set and checkpoint == "auto", PyTorchTrainer
                will save a checkpoint before starting to train.
        """
        assert max_retries >= 0, "`max_retries` must be non-negative."
        if max_retries:
            if checkpoint == "auto":
                logger.debug("Retrying detected. Automatically checkpointing.")
                checkpoint = self.save(
                    os.path.join(self.temp_dir, "tmp_checkpoint"))
            elif not checkpoint:
                raise ValueError("Cannot retry from empty checkpoint.")

        if checkpoint and self._should_resize():
            logger.info("Resize opportunity detected. Attempting to scale up.")
            self._resize_workers(checkpoint=checkpoint)

        with self.optimizer_timer:
            worker_stats = ray.get([w.step.remote() for w in self.workers])
            success, worker_stats = self._train_step()
        # Fault handling
        for i in range(max_retries):
            if success:
                break
            else:
                self._num_failures += 1
                self._resize_workers(checkpoint=checkpoint)
                logger.info("Retrying training step with %d workers." % len(
                    self.workers))
                success, worker_stats = self._train_step()
        if not success:
            raise RuntimeError("Training run failed.")

        worker_stats = ray.get(worker_stats)

        train_stats = {}
        for stat_key in worker_stats[0]:
@@ -163,6 +219,11 @@ class PyTorchTrainer:
                train_stats[stat_key] = worker_stats[0][stat_key]
        return train_stats

    def _train_step(self):
        worker_stats = [w.step.remote() for w in self.workers]
        success = utils.check_for_failure(worker_stats)
        return success, worker_stats

    def apply_all_workers(self, fn):
        return ray.get([w.apply_fn.remote(fn) for w in self.workers])
@@ -211,11 +272,54 @@ class PyTorchTrainer:
        state_id = ray.put(state)
        ray.get([worker.set_state.remote(state_id) for worker in self.workers])

    def shutdown(self):
    def shutdown(self, force=False):
        """Shuts down workers and releases resources."""
        for worker in self.workers:
            worker.shutdown.remote()
            worker.__ray_terminate__.remote()
            if not force:
                worker.shutdown.remote()
                worker.__ray_terminate__.remote()
            else:
                logger.warning("Killing worker {}.".format(worker))
                worker.__ray_kill__()

        self.workers = []

    def _resize_workers(self, checkpoint, max_retries=10):
        # Check available resources after forcefully shutting down old workers.
        self.shutdown(force=True)
        assert checkpoint, "Cannot restore without checkpoint."

        time.sleep(1)
        for i in range(max_retries):
            resources = ray.available_resources()
            new_workers = min(resources.get("CPU", 0), self.max_replicas)
            if self.use_gpu:
                new_workers = min(resources.get("GPU", 0), new_workers)
            if new_workers:
                self._last_resize = time.time()
                self._start_workers(int(new_workers))
                self.restore(checkpoint)
                return
            else:
                delay = 2**i
                logger.info("Resources: {}".format(resources))
                logger.warning(
                    "No new workers found. Retrying in %d sec." % delay)
                time.sleep(delay)
        raise RuntimeError("Exceeded max_retries for relaunching workers.")

    def _should_resize(self):
        """Returns True if past the cooldown and resources exist to scale up."""
        worker_gap = self.max_replicas - len(self.workers)
        past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
        if past_cooldown and worker_gap:
            resources = ray.available_resources()
            potential_workers = min(resources.get("CPU", 0), self.max_replicas)
            if self.use_gpu:
                potential_workers = min(
                    resources.get("GPU", 0), potential_workers)
            return potential_workers > 0
        return False


class PyTorchTrainable(Trainable):
@@ -1,21 +1,26 @@
import os
import pytest
import tempfile
from unittest.mock import patch

import pytest
import time
import torch
import torch.nn as nn
import torch.distributed as dist

import ray
from ray import tune
from ray.tests.conftest import ray_start_2_cpus  # noqa: F401
from ray.experimental.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
from ray.experimental.sgd.pytorch.utils import train
from ray.experimental.sgd.utils import check_for_failure

from ray.experimental.sgd.examples.train_example import (
    model_creator, optimizer_creator, data_creator)
    model_creator, optimizer_creator, data_creator, LinearDataset)


@pytest.mark.parametrize(  # noqa: F811
    "num_replicas", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("num_replicas", [1, 2]
                         if dist.is_available() else [1])
def test_train(ray_start_2_cpus, num_replicas):  # noqa: F811
    trainer = PyTorchTrainer(
        model_creator,
@@ -36,8 +41,8 @@ def test_train(ray_start_2_cpus, num_replicas):  # noqa: F811
    assert validation_loss2 <= validation_loss1


@pytest.mark.parametrize(  # noqa: F811
    "num_replicas", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("num_replicas", [1, 2]
                         if dist.is_available() else [1])
def test_multi_model(ray_start_2_cpus, num_replicas):  # noqa: F811
    def custom_train(models, dataloader, criterion, optimizers, config):
        result = {}
@@ -94,15 +99,15 @@ def test_multi_model(ray_start_2_cpus, num_replicas):  # noqa: F811
        assert torch.equal(model1_state_dict[k], model2_state_dict[k])


@pytest.mark.parametrize(  # noqa: F811
    "num_replicas", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("num_replicas", [1, 2]
                         if dist.is_available() else [1])
def test_tune_train(ray_start_2_cpus, num_replicas):  # noqa: F811

    config = {
        "model_creator": tune.function(model_creator),
        "data_creator": tune.function(data_creator),
        "optimizer_creator": tune.function(optimizer_creator),
        "loss_creator": tune.function(lambda config: nn.MSELoss()),
        "model_creator": model_creator,
        "data_creator": data_creator,
        "optimizer_creator": optimizer_creator,
        "loss_creator": lambda config: nn.MSELoss(),
        "num_replicas": num_replicas,
        "use_gpu": False,
        "batch_size": 512,
@@ -127,8 +132,8 @@ def test_tune_train(ray_start_2_cpus, num_replicas):  # noqa: F811
    assert validation_loss2 <= validation_loss1


@pytest.mark.parametrize(  # noqa: F811
    "num_replicas", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("num_replicas", [1, 2]
                         if dist.is_available() else [1])
def test_save_and_restore(ray_start_2_cpus, num_replicas):  # noqa: F811
    trainer1 = PyTorchTrainer(
        model_creator,
@@ -164,3 +169,101 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas):  # noqa: F811

    for k in model1_state_dict:
        assert torch.equal(model1_state_dict[k], model2_state_dict[k])


def test_fail_with_recover(ray_start_2_cpus):  # noqa: F811
    if not dist.is_available():
        return

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 3:
            time.sleep(1)  # Make sure the batch will fail correctly.
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        with pytest.raises(RuntimeError):
            trainer1.train(max_retries=1)


def test_resize(ray_start_2_cpus):  # noqa: F811
    if not dist.is_available():
        return

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 1:
            time.sleep(1)  # Make sure the batch will fail correctly.
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        @ray.remote
        def try_test():
            import time
            time.sleep(100)

        try_test.remote()
        trainer1.train(max_retries=1)
        assert len(trainer1.workers) == 1


def test_fail_twice(ray_start_2_cpus):  # noqa: F811
    if not dist.is_available():
        return

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 2:
            time.sleep(1)
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        trainer1.train(max_retries=2)
@@ -1,8 +1,14 @@
from contextlib import closing
import logging
import numpy as np
import socket
import time

import ray
from ray.exceptions import RayActorError

logger = logging.getLogger(__name__)


class TimerStat:
    """A running stat for conveniently logging the duration of a code block.
@@ -121,3 +127,25 @@ class AverageMeter:
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def check_for_failure(remote_values):
    """Checks remote values for any that returned and failed.

    Args:
        remote_values (list): List of object IDs representing functions
            that may fail in the middle of execution. For example, running
            an SGD training loop in multiple parallel actor calls.

    Returns:
        Bool for success in executing given remote tasks.
    """
    unfinished = remote_values
    try:
        while len(unfinished) > 0:
            finished, unfinished = ray.wait(unfinished)
            finished = ray.get(finished)
        return True
    except RayActorError as exc:
        logger.exception(str(exc))
        return False
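As a usage sketch of ``check_for_failure`` (the ``Worker`` actor below is hypothetical; killing it with ``__ray_kill__`` mirrors what the tests in this diff do):

import ray
from ray.experimental.sgd.utils import check_for_failure

ray.init(num_cpus=2)


@ray.remote
class Worker:
    def step(self):
        return {"loss": 0.0}


workers = [Worker.remote() for _ in range(2)]
assert check_for_failure([w.step.remote() for w in workers])  # all steps returned

workers[0].__ray_kill__()  # simulate a preempted node
print(check_for_failure([w.step.remote() for w in workers]))  # False: RayActorError caught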
@@ -279,7 +279,7 @@ if __name__ == "__main__":
        MemNNModel,
        name="pbt_babi_memnn",
        scheduler=pbt,
        stop={"training_iteration": 20 if args.smoke_test else 100},
        stop={"training_iteration": 10 if args.smoke_test else 100},
        num_samples=4,
        config={
            "batch_size": 32,