ray/doc/examples/lm/ray_train.py

#!/usr/bin/env python3 -u

import math
import copy
import socket
import time
from contextlib import closing

import ray

import fairseq
from fairseq import options
from fairseq_cli.train import main

_original_save_checkpoint = fairseq.checkpoint_utils.save_checkpoint


class RayDistributedActor:
    """Actor to perform distributed training."""

    def run(self, url, world_rank, args):
        """Runs the fairseq training.

        We set args for different ray actors for communication,
        add a checkpoint hook, and call the main function of fairseq.
        """
        # Set the init_method and rank of the process for distributed
        # training.
        print("Ray worker at {url} rank {rank}".format(
            url=url, rank=world_rank))
        self.url = url
        self.world_rank = world_rank
        args.distributed_rank = world_rank
        args.distributed_init_method = url

        # Add a checkpoint hook to make use of new resources.
        self.add_checkpoint_hook(args)

        # Call the original main function of fairseq.
        main(args, init_distributed=(args.distributed_world_size > 1))

    def add_checkpoint_hook(self, args):
        """Add a hook to the original save_checkpoint function.

        This checks if there are new computational resources available.
        If so, raise an exception to restart the training process and
        make use of the new resources.
        """
        if args.cpu:
            original_n_cpus = args.distributed_world_size

            def _new_save_checkpoint(*args, **kwargs):
                _original_save_checkpoint(*args, **kwargs)
                n_cpus = int(ray.cluster_resources()["CPU"])
                if n_cpus > original_n_cpus:
                    raise Exception(
                        "New CPUs found (original %d CPUs, now %d CPUs)" %
                        (original_n_cpus, n_cpus))
        else:
            original_n_gpus = args.distributed_world_size

            def _new_save_checkpoint(*args, **kwargs):
                _original_save_checkpoint(*args, **kwargs)
                n_gpus = int(ray.cluster_resources().get("GPU", 0))
                if n_gpus > original_n_gpus:
                    raise Exception(
                        "New GPUs found (original %d GPUs, now %d GPUs)" %
                        (original_n_gpus, n_gpus))

        fairseq.checkpoint_utils.save_checkpoint = _new_save_checkpoint
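
        # The exception raised by the patched save_checkpoint is deliberately
        # left unhandled inside fairseq: it propagates out of main() and is
        # caught in run_fault_tolerant_loop() below, which shuts Ray down and
        # relaunches training with the larger world size.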

    def get_node_ip(self):
        """Returns the IP address of the current node."""
        return ray.util.get_node_ip_address()

    def find_free_port(self):
        """Finds a free port on the current node."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.bind(("", 0))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            return s.getsockname()[1]
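

# Illustrative sketch (not executed as part of this script): how a single
# RayDistributedActor could be created and queried by hand, mirroring what
# run_fault_tolerant_loop() does below.
#
#     ray.init(address="auto")
#     Actor = ray.remote(num_cpus=1)(RayDistributedActor)
#     worker = Actor.remote()
#     ip = ray.get(worker.get_node_ip.remote())
#     port = ray.get(worker.find_free_port.remote())
#     print("tcp://{ip}:{port}".format(ip=ip, port=port))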


def run_fault_tolerant_loop():
    """Entry point to fairseq training, providing fault tolerance."""
    # Parse the command line arguments.
    parser = options.get_training_parser()
    add_ray_args(parser)
    args = options.parse_args_and_arch(parser)
    original_args = copy.deepcopy(args)

    # Main loop for fault-tolerant training.
    retry = True
    while retry:
        args = copy.deepcopy(original_args)

        # Initialize Ray.
        ray.init(address=args.ray_address)

        set_num_resources(args)
        set_batch_size(args)

        # Set up Ray distributed actors.
        Actor = ray.remote(
            num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
        workers = [Actor.remote() for i in range(args.distributed_world_size)]

        # Get the IP address and a free port of actor 0, which is used for
        # fairseq distributed training.
        ip = ray.get(workers[0].get_node_ip.remote())
        port = ray.get(workers[0].find_free_port.remote())
        address = "tcp://{ip}:{port}".format(ip=ip, port=port)

        # Start the remote processes, and check whether any of them failed.
        # If so, restart all of them.
        unfinished = [
            worker.run.remote(address, i, args)
            for i, worker in enumerate(workers)
        ]
        try:
            while len(unfinished) > 0:
                finished, unfinished = ray.wait(unfinished)
                finished = ray.get(finished)
            retry = False
        except Exception as inst:
            print("Ray is restarting because the following error occurred:")
            print(inst)
            retry = True
        ray.shutdown()
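        # Note that ray.shutdown() runs on both the success and the failure
        # path, so the next iteration (if retry is True) starts from a fresh
        # ray.init() against the current cluster resources.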


def add_ray_args(parser):
    """Add ray and fault-tolerance related parser arguments to the parser."""
    group = parser.add_argument_group("Ray related arguments")
    group.add_argument(
        "--ray-address",
        default="auto",
        type=str,
        help="address for ray initialization")
    group.add_argument(
        "--fix-batch-size",
        default=None,
        metavar="B1,B2,...,B_N",
        type=lambda uf: options.eval_str_list(uf, type=int),
        help="fix the actual batch size (max_sentences * update_freq "
        "* n_GPUs) to be the fixed input values by adjusting update_freq "
        "according to the actual n_GPUs; the batch size is fixed to B_i for "
        "epoch i; all epochs > N are fixed to B_N")
    return group


def set_num_resources(args):
    """Get the number of resources and set the corresponding fields."""
    if args.cpu:
        args.distributed_world_size = int(ray.cluster_resources()["CPU"])
    else:
        n_gpus = int(ray.cluster_resources().get("GPU", 0))
        while n_gpus == 0:
            print("No GPUs available, waiting 10 seconds")
            time.sleep(10)
            n_gpus = int(ray.cluster_resources().get("GPU", 0))
        args.distributed_world_size = n_gpus


def set_batch_size(args):
    """Fixes the total batch_size to be agnostic to the GPU count."""
    if args.fix_batch_size is not None:
        args.update_freq = [
            math.ceil(batch_size /
                      (args.max_sentences * args.distributed_world_size))
            for batch_size in args.fix_batch_size
        ]
        print("Training on %d GPUs, max_sentences=%d, update_freq=%s" %
              (args.distributed_world_size, args.max_sentences,
               repr(args.update_freq)))
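

# Worked example of the update_freq arithmetic above (hypothetical numbers):
# with --fix-batch-size 256, --max-sentences 8 and a distributed world size
# of 4, update_freq becomes ceil(256 / (8 * 4)) = 8, so the effective batch
# size stays at 256 no matter how many workers happen to be available.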


if __name__ == "__main__":
    run_fault_tolerant_loop()
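
# Example launch (hypothetical paths and fairseq flags; only --ray-address and
# --fix-batch-size are defined in this script, everything else comes from
# fairseq's own training parser):
#
#     python ray_train.py DATA_DIR \
#         --ray-address auto \
#         --fix-batch-size 256 \
#         --max-sentences 8 \
#         --task language_modeling --arch transformer_lm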