The RaySGD ``PyTorchTrainer`` simplifies distributed model training for PyTorch. It is a wrapper around ``torch.distributed.launch`` with a Python API, so distributed training can be incorporated into a larger Python application instead of being executed outside of Python. For end-to-end examples, see :ref:`raysgd-pytorch-example`.
The ``PyTorchTrainer`` can be constructed with functions that wrap components of the training script. Specifically, it needs constructors for the Model, Data, Optimizer, and Loss to create replicated copies across different devices and machines.
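
The four constructors can be sketched as follows. This is a minimal illustration rather than the library's reference example: the signatures (``model_creator(config)``, ``data_creator(config)``, ``optimizer_creator(model, config)``) and the convention that ``data_creator`` returns a training and a validation dataset are assumptions, and the toy linear-regression data is purely hypothetical.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import TensorDataset


    def model_creator(config):
        """Build a fresh model; assumed to be called once per replica."""
        return nn.Linear(1, 1)


    def data_creator(config):
        """Return (train_dataset, validation_dataset) for the toy task y = 2x + 1."""
        x = torch.linspace(0, 1, steps=100).unsqueeze(1)
        y = 2 * x + 1
        return TensorDataset(x[:80], y[:80]), TensorDataset(x[80:], y[80:])


    def optimizer_creator(model, config):
        """Build an optimizer for the given model, reading ``lr`` from config."""
        return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))


    # The loss is passed as a constructor, for example ``nn.MSELoss``.
    loss_creator = nn.MSELoss

Each function builds its component from scratch, which is what allows the trainer to create an independent copy on every device.
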
Now that the trainer is constructed, you'll naturally want to train the model.

.. code-block:: python

    trainer.train()

This takes one pass over the training data.
To run the model on the validation data passed in by the ``data_creator``, you can simply call:

.. code-block:: python

    trainer.validate()

You can customize the exact function that is called by using a customized training function (see :ref:`raysgd-custom-training`).
Shutting down training
----------------------
After training, you may want to reuse the Ray cluster for other work. To release the Ray resources obtained by the trainer:

.. code-block:: python

    trainer.shutdown()

.. note:: Be sure to call ``save`` or ``get_model`` before shutting down.

Initialization Functions
------------------------
You may want to run some initializers on each worker when it is started, such as setting an environment variable or downloading some data. You can do this via the ``initialization_hook`` parameter:

.. code-block:: python

    import os

    import torch.nn as nn

    # PyTorchTrainer and the creator functions are assumed to be in scope.


    def initialization_hook(runner):
        # The ``runner`` argument is an assumed signature for the hook.
        # Need this for avoiding a connection restart issue
        os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
        os.environ["NCCL_LL_THRESHOLD"] = "0"
        os.environ["NCCL_DEBUG"] = "INFO"


    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        loss_creator=nn.MSELoss,
        initialization_hook=initialization_hook,
        config={"lr": 0.001},
        num_replicas=100,
        use_gpu=True)

Save and Load
-------------
If you want to save or reload the training procedure, you can use ``trainer.save``
and ``trainer.load``, which wrap the relevant ``torch.save`` and ``torch.load`` calls. This should work across a distributed cluster even without an NFS because it takes advantage of Ray's distributed object store.
The trained torch model can be extracted for use within the same Python program with ``trainer.get_model()``. This will load the state dictionary of the model(s).

.. code-block:: python

    trainer.train()
    model = trainer.get_model()

Distributed Multi-node Training
-------------------------------
You can scale out your training onto multiple nodes without making any modifications to your training code. To train across a cluster, simply make sure that the Ray cluster is started.
You can start a Ray cluster `via the Ray cluster launcher <autoscaling.html>`_ or `manually <using-ray-on-a-cluster.html>`_.

.. code-block:: bash

    ray up CLUSTER.yaml
    python train.py --address="auto"

Then you can scale the number of workers across multiple nodes by increasing ``num_replicas`` when constructing the trainer.

Fault Tolerance
---------------
For distributed deep learning, jobs are often run on infrastructure where nodes can be pre-empted frequently (e.g., spot instances in the cloud). To overcome this, RaySGD provides **fault tolerance** features that enable training to continue regardless of node failures.
During each ``train`` call, each parallel worker iterates through the dataset, synchronizing gradients and parameters at each batch. These synchronization primitives can hang when one or more of the parallel workers becomes unresponsive (e.g., when a node is lost). To address this, we've implemented the following protocol.
1. If any worker node is lost, Ray will mark the training task as complete (``ray.wait`` will return).
2. Fetching the result of a lost worker raises ``RayActorError``, so the Trainer class calls ``ray.get`` on the "finished" training task to detect the failure.
3. Upon catching this exception, the Trainer class will kill all of its workers.
4. The Trainer will then detect the quantity of available resources (either CPUs or GPUs). It will then restart as many workers as it can, each resuming from the last checkpoint. Note that this may result in fewer workers than initially specified.
5. If there are no available resources, the Trainer will apply an exponential backoff before retrying to create workers.
6. If there are available resources and the Trainer has fewer workers than initially specified, then it will scale up its worker pool until it reaches the initially specified ``num_workers``.
Note that we assume the Trainer itself is not on a pre-emptible node. It is currently not possible to recover from a Trainer node failure.
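
The restart behaviour in steps 4 and 5 can be sketched in plain Python. This is a simplified illustration under stated assumptions, not RaySGD's implementation: ``resize_workers``, ``get_available``, ``start_worker``, ``max_retries``, and ``base_delay`` are hypothetical names, and doubling the delay on each attempt is just one common form of exponential backoff.

.. code-block:: python

    import time


    def resize_workers(target_workers, get_available, start_worker,
                       max_retries=5, base_delay=1.0, sleep=time.sleep):
        """Restart up to ``target_workers`` workers, backing off when none fit.

        ``get_available`` and ``start_worker`` are hypothetical callables that
        stand in for resource detection and worker creation.
        """
        workers = []
        for attempt in range(max_retries):
            available = get_available()
            if available > 0:
                # Start as many workers as resources allow, capped at the target.
                for _ in range(min(available, target_workers)):
                    workers.append(start_worker())
                return workers
            # No resources: wait base_delay * 2**attempt seconds, then retry.
            sleep(base_delay * 2 ** attempt)
        raise RuntimeError("no resources became available")

The function may return fewer workers than requested, which mirrors the note in step 4. Growing the pool back to its original size (step 6) would be a separate, later call.
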
Users can set ``checkpoint="auto"`` to always checkpoint the current model before executing a pass over the training dataset.

Integration with Tune
---------------------
``PyTorchTrainer`` integrates with Tune via the ``PyTorchTrainable`` interface. The same arguments that you would pass to ``PyTorchTrainer`` should be passed through ``tune.run(config=...)``.

Multi-model Training
--------------------
In certain scenarios, such as training GANs, you may want to use multiple models in the training loop. You can do this in the ``PyTorchTrainer`` by having the ``model_creator`` and the ``optimizer_creator`` return multiple values.
If multiple models are returned, you will need to provide a custom training function (and custom validation function if you plan to call ``validate``).
You can see the `DCGAN script <https://github.com/ray-project/ray/blob/master/python/ray/experimental/sgd/pytorch/examples/dcgan.py>`_ for an end-to-end example.

Custom Training and Validation Functions
----------------------------------------
``PyTorchTrainer`` allows you to run a custom training and validation step in parallel on each worker, providing flexibility similar to using PyTorch natively. This is done via the ``train_function`` and ``validation_function`` parameters.
Note that custom functions are required if the model creator returns multiple models.
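
A custom training function might look like the sketch below. The argument list ``(model, train_iterator, criterion, optimizer, config)`` and the returned dictionary of statistics are assumptions about what ``train_function`` receives and returns, so check the API reference for the exact signature. With multiple models, ``model`` and ``optimizer`` would hold whatever the creator functions returned.

.. code-block:: python

    def train(model, train_iterator, criterion, optimizer, config):
        """Run one pass over ``train_iterator`` and return summary statistics.

        The signature is an assumption made for illustration.
        """
        model.train()
        total_loss, num_batches = 0.0, 0
        for features, target in train_iterator:
            output = model(features)
            loss = criterion(output, target)

            # Standard PyTorch update: clear old gradients, backpropagate, step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
        # Average the per-batch loss; guard against an empty iterator.
        return {"train_loss": total_loss / max(num_batches, 1)}

The function would then be passed as ``train_function=train`` when constructing the trainer.
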