This document provides a walkthrough of adapting the `Fairseq library <https://github.com/pytorch/fairseq>`__ to perform fault-tolerant distributed training on AWS.
As an example, we use the WikiText-103 dataset to pretrain the RoBERTa model following `this tutorial <https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md>`__. The pipeline and configurations in this document will work for other models supported by Fairseq, such as sequence-to-sequence machine translation models.
To use Ray cluster launcher on AWS, install boto (``pip install boto3``) and configure your AWS credentials in ``~/.aws/credentials`` as described on the :ref:`Automatic Cluster Setup page <cluster-cloud>`.
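For reference, a minimal ``~/.aws/credentials`` setup looks like the following; the key values are placeholders, and any other credential configuration method supported by boto3 works as well:

.. code-block:: bash

  pip install boto3

  # Write a minimal ~/.aws/credentials file (replace the placeholders with your own keys).
  mkdir -p ~/.aws
  cat > ~/.aws/credentials <<'EOF'
  [default]
  aws_access_key_id = YOUR_ACCESS_KEY_ID
  aws_secret_access_key = YOUR_SECRET_ACCESS_KEY
  EOF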
In the example config file, we use an ``m5.xlarge`` on-demand instance as the head node, and use ``p3.2xlarge`` GPU spot instances as the worker nodes. We set the minimal number of workers to 1 and maximum workers to 2 in the config, which can be modified according to your own demand.
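With the config in place, start the cluster from your local machine using the standard cluster launcher command:

.. code-block:: bash

  # Launch the head node; worker nodes are added by the autoscaler as needed.
  ray up lm-cluster.yaml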
The rest of this walkthrough assumes that the ``lm/`` example files are available under ``~/efs/lm`` on the cluster; you can upload them from a local path with:

.. code-block:: bash

  ray rsync-up lm-cluster.yaml PATH/TO/LM '~/efs/lm'
Preprocessing Data
------------------
Once the cluster is started, you can SSH into the head node using ``ray attach lm-cluster.yaml`` and download and preprocess the data on EFS for training. Running ``preprocess.sh`` (`code <https://github.com/ray-project/ray/tree/master/doc/examples/lm/preprocess.sh>`_) does this, adapting the instructions from `the RoBERTa tutorial <https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md>`__, as shown below.
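For example, assuming the files were uploaded to ``~/efs/lm`` as described above:

.. code-block:: bash

  # On your local machine: open a shell on the head node.
  ray attach lm-cluster.yaml

  # On the head node: download and preprocess WikiText-103 onto EFS.
  cd ~/efs/lm
  bash preprocess.sh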
Training
--------
We provide ``ray_train.py`` (`code <https://github.com/ray-project/ray/tree/master/doc/examples/lm/ray_train.py>`__) as an entrypoint to the Fairseq library. Since we are training the model on spot instances, ``ray_train.py`` provides fault tolerance by checkpointing and restarting when a node fails. After checkpointing, the code also checks whether new resources have become available; if so, it restarts and resizes the training job to make use of them.
The two main components of ``ray_train.py`` are a ``RayDistributedActor`` class and a ``run_fault_tolerant_loop()`` function. The ``RayDistributedActor`` sets the proper arguments for each Ray actor process, adds a checkpoint hook that lets the process make use of newly available GPUs, and calls Fairseq's ``main``. ``run_fault_tolerant_loop()`` launches the actors, waits for them, and restarts all processes if any of them fails:
.. code-block:: python

  # Start the remote processes, and check whether any of them fails.
  # If so, restart all the processes.
  unfinished = [
      worker.run.remote(address, i, args)
      for i, worker in enumerate(workers)
  ]
  try:
      while len(unfinished) > 0:
          finished, unfinished = ray.wait(unfinished)
          finished = ray.get(finished)
      retry = False
  except Exception as inst:
      print("Restarting Ray because the following error occurred:")
      print(inst)
      retry = True
  ray.shutdown()
In ``ray_train.py``, we also define a set of helper functions. ``add_ray_args()`` adds the Ray- and fault-tolerance-related arguments to the argument parser:
.. code-block:: python

  def add_ray_args(parser):
      """Add ray and fault-tolerance related parser arguments to the parser."""
      group = parser.add_argument_group("Ray related arguments")
      group.add_argument(
          "--fix-batch-size",  # backs the FIX_BATCH_SIZE setting used below
          default=None,
          help="fix the actual batch size (max_sentences * update_freq "
               "* n_GPUs) to be the fixed input values by adjusting update_freq "
               "according to actual n_GPUs; the batch size is fixed to B_i for "
               "epoch i; all epochs >N are fixed to B_N")
      return group
``set_num_resources()`` sets the distributed world size to the number of available resources. If GPUs are requested but none are currently available, the function waits until a GPU becomes available:
.. code-block:: python

  def set_num_resources(args):
      """Get the number of resources and set the corresponding fields."""
      ...
      print("Training on %d GPUs, max_sentences=%d, update_freq=%s" %
            (args.distributed_world_size, args.max_sentences,
             repr(args.update_freq)))
To start training, run the `following commands <https://github.com/ray-project/ray/tree/master/doc/examples/lm/ray_train.sh>`__ (``ray_train.sh``) on the head node:
.. code-block:: bash

  cd ~/efs/lm

  TOTAL_UPDATES=125000    # Total number of training steps
  WARMUP_UPDATES=10000    # Warmup the learning rate over this many updates
  PEAK_LR=0.0005          # Peak learning rate, adjust as needed
  # ... the remaining settings and the training invocation are in ray_train.sh (linked above).
``SAVE_INTERVAL_UPDATES`` controls how often a checkpoint is saved and can be tuned based on the `stability of the chosen instances <https://aws.amazon.com/ec2/spot/instance-advisor/>`__. ``FIX_BATCH_SIZE`` keeps the total batch size (``max_sentences * update_freq * n_GPUs``) roughly fixed.
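For illustration, the two settings could look like this in ``ray_train.sh`` (the values below are examples, not the ones shipped with the script):

.. code-block:: bash

  SAVE_INTERVAL_UPDATES=1000    # Checkpoint every 1000 updates; lower this for less stable spot pools
  FIX_BATCH_SIZE=2048           # Target total batch size (max_sentences * update_freq * n_GPUs)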
Helpful Ray Commands
--------------------
To let Ray automatically stop the cluster after training finishes, make sure ``ray_train.sh`` is available under ``~/efs/lm`` on the remote machine, then run the following command on your local machine:
.. code-block:: bash

  ray exec --stop lm-cluster.yaml 'bash $HOME/efs/lm/ray_train.sh'
or run the following command on the remote head node:
.. code-block:: bash

  ray exec --stop ~/ray_bootstrap_config.yaml 'bash $HOME/efs/lm/ray_train.sh'
To test the fault tolerance, you can run the following command on your local machine to randomly kill one node:
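.. code-block:: bash

  # Hard-kill a random node in the cluster; training should resume from the
  # last checkpoint (``ray kill-random-node`` is a Ray cluster launcher utility).
  ray kill-random-node lm-cluster.yaml --hard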