The ``TorchTrainer`` is a wrapper around ``torch.distributed.launch`` with a Python API to easily incorporate distributed training into a larger Python application, as opposed to needing to wrap your training code in bash scripts.
.. tip:: If you want to leverage multi-node data parallel training with PyTorch while using Ray Tune *without* using RaySGD, check out the :ref:`Tune PyTorch user guide <tune-pytorch-cifar>` and Tune's :ref:`distributed PyTorch integrations <tune-ddp-doc>`.
The :ref:`ref-torch-trainer` can be constructed from a custom :ref:`ref-torch-operator` subclass that defines training components like the model, data, optimizer, loss, and ``lr_scheduler``. These components are all automatically replicated across different machines and devices so that training can be executed in parallel.
.. warning:: You should call ``self.register(...)`` and ``self.register_data(...)`` inside the ``setup`` method of your custom ``TrainingOperator`` to register the necessary training components with Ray SGD.
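For example, a minimal operator and trainer might look like the following sketch (the toy model, dataset, and hyperparameters are placeholders, not part of the RaySGD API):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from ray.util.sgd import TorchTrainer
    from ray.util.sgd.torch import TrainingOperator


    class MyTrainingOperator(TrainingOperator):
        def setup(self, config):
            # Create the training components (toy examples for illustration).
            model = nn.Linear(1, 1)
            optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))
            criterion = nn.MSELoss()

            x = torch.randn(1024, 1)
            train_loader = DataLoader(TensorDataset(x, 2 * x), batch_size=32)
            val_loader = DataLoader(TensorDataset(x, 2 * x), batch_size=32)

            # Register the components with Ray SGD; the returned objects are the
            # (possibly wrapped) versions to use as self.model, self.optimizer, etc.
            self.model, self.optimizer, self.criterion = self.register(
                models=model, optimizers=optimizer, criterion=criterion)
            self.register_data(
                train_loader=train_loader, validation_loader=val_loader)


    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        num_workers=2,
        use_gpu=False,
        config={"lr": 1e-2})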
Each ``train`` call makes one pass over the training data (trains on 1 epoch), and each ``validate`` call runs the model on the validation data.
Override training and validation methods in your Training Operator (:ref:`raysgd-custom-training`) to calculate custom metrics or customize the training/validation process.
.. tip:: Setting the batch size: using the provided ``ray.util.sgd.utils.BATCH_SIZE`` variable, you can specify a global batch size that will be divided among all workers automatically.
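For example, the sketch below reuses the hypothetical ``MyTrainingOperator`` from above and assumes its ``setup`` reads ``config[BATCH_SIZE]`` when building its DataLoaders:

.. code-block:: python

    from ray.util.sgd import TorchTrainer
    from ray.util.sgd.utils import BATCH_SIZE

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        num_workers=2,
        config={
            # With 2 workers, each worker receives 128 / 2 = 64 as its batch size.
            BATCH_SIZE: 128,
        })

    train_stats = trainer.train()     # one pass over the training data
    val_stats = trainer.validate()    # one pass over the validation data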
If you would like to implement custom training and validation logic, you can do so by overriding the appropriate methods inside your :ref:`ref-torch-operator` subclass.
For both training and validation, there are two granularities at which you can customize behavior: per epoch and per batch. These correspond to ``train_batch``,
``train_epoch``, ``validate``, and ``validate_batch``. Other useful methods to override include ``state_dict`` and ``load_state_dict``, which you can use
to save and load additional state for your custom ``TrainingOperator``.
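For instance, a per-batch override might look like this sketch (building on the hypothetical operator above; the metric keys are illustrative):

.. code-block:: python

    from ray.util.sgd.utils import NUM_SAMPLES


    class CustomOperator(MyTrainingOperator):
        def train_batch(self, batch, batch_info):
            features, target = batch
            # self.device points at the device assigned to this worker.
            features, target = features.to(self.device), target.to(self.device)

            output = self.model(features)
            loss = self.criterion(output, target)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # The returned metrics are averaged across batches and workers.
            return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)}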
See the `DCGAN example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/dcgan.py>`__ for an end-to-end example. It constructs two models and two optimizers and uses a custom training operator to provide a non-standard training loop.
If you want to use a custom wrapper for distributed training, or if you want to wrap your models in ``DistributedDataParallel`` yourself, you can do so by setting ``TorchTrainer(wrap_ddp=False)``.
In previous versions of Ray, *creator functions* (``model_creator``, ``optimizer_creator``, etc.) were necessary to set up the training components.
These creator functions are no longer used; instead, training component setup should be specified inside the ``setup`` method of a ``TrainingOperator`` subclass.
However, if you already have these creator functions and do not want to change your code, you can still use them to create a custom ``TrainingOperator``, as shown below.
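For example, a sketch that assumes the ``TrainingOperator.from_creators`` helper is available in your version of RaySGD (the creator function names are placeholders for your existing code):

.. code-block:: python

    from ray.util.sgd import TorchTrainer
    from ray.util.sgd.torch import TrainingOperator

    # model_creator, data_creator, optimizer_creator, and loss_creator are
    # your existing creator functions.
    MyCreatorOperator = TrainingOperator.from_creators(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=loss_creator)

    trainer = TorchTrainer(
        training_operator_cls=MyCreatorOperator, num_workers=2)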
Use the ``initialization_hook`` parameter to initialize state on each worker process when they are started. This is useful when setting an environment variable:
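For example, a hook along these lines can set NCCL-related environment variables on every worker (the specific variables and values are illustrative):

.. code-block:: python

    import os

    from ray.util.sgd import TorchTrainer


    def initialization_hook():
        # Runs once on each worker process before the training components are set up.
        os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
        os.environ["NCCL_DEBUG"] = "INFO"


    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        initialization_hook=initialization_hook,
        num_workers=2,
        use_gpu=True)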
If you want to save or reload the training procedure, you can use ``trainer.save``
and ``trainer.load``, which wrap the relevant ``torch.save`` and ``torch.load`` calls. This should work across a distributed cluster even without an NFS because it takes advantage of Ray's distributed object store.
The trained torch model can be extracted for use within the same Python program with ``trainer.get_model()``. This will load the state dictionary of the model(s).
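A minimal round-trip might look like this sketch (the checkpoint path is a placeholder):

.. code-block:: python

    # Save the current training state (models, optimizers, etc.).
    checkpoint_path = "/tmp/raysgd_checkpoint"  # hypothetical path
    trainer.save(checkpoint_path)

    # ... later, restore into a Trainer constructed with the same operator.
    trainer.load(checkpoint_path)

    # Extract the trained model(s) for use within this Python program.
    model = trainer.get_model()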
The outputs of ``trainer.train()`` and ``trainer.validate()`` are first collected on a per-batch basis. These results are then averaged: first across all batches in the epoch, and then across all workers.
By default, the output of ``train`` contains the following:
.. code-block:: python

    # Total number of samples trained on in this epoch.
    num_samples
    # Current training epoch.
    epoch
    # Number of batches trained on in this epoch, averaged across all workers.
    batch_count
    # Training loss, averaged across all batches on all workers.
    train_loss
    # Training loss for the last batch in the epoch, averaged across all workers.
    last_train_loss
And for ``validate``:
.. code-block:: python

    # Total number of samples validated on.
    num_samples
    # Number of batches validated on, averaged across all workers.
    batch_count
    # Validation loss, averaged across all batches on all workers.
    val_loss
    # Validation loss for the last batch, averaged across all workers.
    last_val_loss
    # Validation accuracy, averaged across all batches on all workers.
    val_accuracy
    # Validation accuracy for the last batch, averaged across all workers.
    last_val_accuracy
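For example, using the hypothetical trainer from the earlier sketches, the returned dictionaries can be used directly:

.. code-block:: python

    train_stats = trainer.train()
    print(train_stats["train_loss"], train_stats["num_samples"])

    val_stats = trainer.validate()
    print(val_stats["val_loss"], val_stats["val_accuracy"])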
If ``train`` or ``validate`` are run with ``reduce_results=False``, results are not averaged across workers and a list of results for each worker is returned.
If run with ``profile=True``, timing stats for a single worker are returned alongside the results above.
To return additional metrics, you should implement your own custom training operator (:ref:`raysgd-custom-training`).
If overriding ``train_batch`` or ``validate_batch``, the result outputs are automatically averaged across all batches, and the results for the last batch are automatically returned.
If overriding ``train_epoch`` or ``validate``, you may find ``ray.util.sgd.utils.AverageMeterCollection`` (:ref:`ref-utils`) useful for handling this averaging, as sketched below.
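A sketch of how this might look, building on the hypothetical operator above (the exact metrics are up to you):

.. code-block:: python

    from ray.util.sgd.utils import AverageMeterCollection, NUM_SAMPLES


    class EpochOperator(MyTrainingOperator):
        def train_epoch(self, iterator, info):
            meter_collection = AverageMeterCollection()
            self.model.train()
            for batch_idx, batch in enumerate(iterator):
                # Compute per-batch metrics however you like.
                metrics = self.train_batch(batch, batch_info={"batch_idx": batch_idx})
                # Track running averages, weighted by the number of samples.
                meter_collection.update(metrics, n=metrics.get(NUM_SAMPLES, 1))
            # Return the averaged metrics for this epoch.
            return meter_collection.summary()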
You can enable mixed precision training for PyTorch with the ``use_fp16`` flag. This automatically converts the model(s) and optimizer(s) to train using mixed-precision.
By default, `native mixed precision training <https://pytorch.org/docs/stable/amp.html>`_ will be used. This requires PyTorch>=1.6. If you are using an older version of PyTorch, you can alternatively use the ``Apex`` library. ``Apex`` is a PyTorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training. It can be installed from `the NVIDIA/Apex repository <https://github.com/NVIDIA/apex#quick-start>`_.
When ``use_fp16=True`` and native mixed precision is not available, ``Apex`` will be used instead. If neither native support nor ``Apex`` are available, an exception will be raised.
You can force ``Apex`` to be used by setting ``use_fp16="apex"``.
When ``use_fp16=True``, you should not manually cast your model or data to ``.half()``.
**Native**:
The flag informs the Trainer to wrap model forward calls in ``torch.cuda.amp.autocast()`` and to scale the loss with ``torch.cuda.amp.GradScaler()``.
**Apex**:
The flag informs the Trainer to call ``amp.initialize`` on the created models and optimizers and optimize using the scaled loss: ``amp.scale_loss(loss, optimizer)``.
To specify particular parameters for ``amp.initialize``, you can use the ``apex_args`` field when calling ``self.register`` in your ``TrainingOperator``. Valid arguments can be found in the `Apex documentation <https://nvidia.github.io/apex/amp.html#apex.amp.initialize>`_:
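For example (a sketch; the ``opt_level`` value is illustrative):

.. code-block:: python

    import torch.nn as nn
    import torch.optim as optim

    from ray.util.sgd import TorchTrainer
    from ray.util.sgd.torch import TrainingOperator


    class ApexOperator(TrainingOperator):
        def setup(self, config):
            model = nn.Linear(1, 1)
            optimizer = optim.SGD(model.parameters(), lr=1e-2)
            # apex_args is forwarded to amp.initialize when Apex is used.
            self.model, self.optimizer = self.register(
                models=model,
                optimizers=optimizer,
                apex_args={"opt_level": "O2"})


    trainer = TorchTrainer(
        training_operator_cls=ApexOperator,
        num_workers=2,
        use_gpu=True,
        use_fp16="apex")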
For distributed deep learning, jobs are often run on infrastructure where nodes can be pre-empted frequently (i.e., spot instances in the cloud). To overcome this, RaySGD provides **fault tolerance** features that enable training to continue regardless of node failures.
During each ``train`` method, each parallel worker iterates through the iterable, synchronizing gradients and parameters at each batch. These synchronization primitives can hang when one or more of the parallel workers becomes unresponsive (i.e., when a node is lost). To address this, we've implemented the following protocol.
1. If any worker node is lost, Ray will mark the training task as complete (``ray.wait`` will return).
2. The Trainer class will call ``ray.get`` on the "finished" training task, and Ray will throw a ``RayActorError`` when fetching the result for the lost worker.
3. Upon catching this exception, the Trainer class will kill all of its workers.
4. The Trainer will then detect the quantity of available resources (either CPUs or GPUs). It will then restart as many workers as it can, each resuming from the last checkpoint. Note that this may result in fewer workers than initially specified.
5. If there are no available resources, the Trainer will apply an exponential backoff before retrying to create workers.
6. If there are available resources and the Trainer has fewer workers than initially specified, then it will scale up its worker pool until it reaches the initially specified ``num_workers``.
Note that we assume the Trainer itself is not on a pre-emptible node. To allow the entire Trainer to recover from failure, you must use Tune to execute the training.
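For instance, a sketch using RaySGD's Tune integration, assuming the ``TorchTrainer.as_trainable`` helper is available in your version:

.. code-block:: python

    from ray import tune
    from ray.util.sgd import TorchTrainer

    TorchTrainable = TorchTrainer.as_trainable(
        training_operator_cls=MyTrainingOperator,
        num_workers=2,
        use_gpu=True)

    analysis = tune.run(
        TorchTrainable,
        config={"lr": tune.grid_search([1e-3, 1e-2])},
        stop={"training_iteration": 10})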
In certain scenarios, such as training GANs, you may want to use multiple models in the training loop. You can do this by registering multiple models, optimizers, or schedulers in the ``setup`` method of ``TrainingOperator``. You must implement custom training and validation (:ref:`raysgd-custom-training`) to train across multiple models.
You can see the `DCGAN script <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/dcgan.py>`_ for an end-to-end example.
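As a hedged sketch, registering two models and two optimizers might look like this (``Discriminator`` and ``Generator`` are placeholders for your own model classes):

.. code-block:: python

    import torch.optim as optim

    from ray.util.sgd.torch import TrainingOperator


    class GANOperator(TrainingOperator):
        def setup(self, config):
            discriminator = Discriminator()  # hypothetical model class
            generator = Generator()          # hypothetical model class
            d_opt = optim.Adam(discriminator.parameters(), lr=config.get("lr", 2e-4))
            g_opt = optim.Adam(generator.parameters(), lr=config.get("lr", 2e-4))

            # Lists of models/optimizers come back registered in the same order.
            self.models, self.optimizers = self.register(
                models=[discriminator, generator],
                optimizers=[d_opt, g_opt])

        def train_batch(self, batch, batch_info):
            discriminator, generator = self.models
            d_opt, g_opt = self.optimizers
            # Custom GAN training logic goes here; see the DCGAN script above.
            ...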
You can see more details in the `benchmarking README <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/benchmarks/README.rst>`_.
DISCLAIMER: RaySGD does not provide any custom communication primitives. If you see any performance issues, you may need to file them on the PyTorch GitHub repository.
- `Semantic Segmentation example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py>`__
- `Huggingface Transformer GLUE fine tuning example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/transformers/transformers_example.py>`__
  Fine-tunes a pre-trained Transformer model on GLUE tasks. Based on the `huggingface/transformers <https://github.com/huggingface/transformers/blob/master/examples/>`_ ``run_glue.py`` example.
- `ImageNet Models example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/image_models/train.py>`__