#!/usr/bin/env python3 -u
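#
# Fault-tolerant, elastic distributed training of fairseq models on a Ray
# cluster. Each RayDistributedActor below runs fairseq's regular training
# loop, and fairseq's save_checkpoint is monkey-patched so that, whenever
# new CPUs or GPUs join the Ray cluster, the workers abort and
# run_fault_tolerant_loop() relaunches training from the latest checkpoint
# with the larger world size.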

import math
import copy
import socket
import time

import ray

import fairseq
from fairseq import options
from fairseq_cli.train import main
from contextlib import closing

_original_save_checkpoint = fairseq.checkpoint_utils.save_checkpoint


class RayDistributedActor:
    """Actor to perform distributed training."""

    def run(self, url, world_rank, args):
        """Runs the fairseq training.

        We set args for different ray actors for communication,
        add a checkpoint hook, and call the main function of fairseq.
        """

        # Set the init_method and rank of the process for distributed
        # training.
        print("Ray worker at {url} rank {rank}".format(
            url=url, rank=world_rank))
        self.url = url
        self.world_rank = world_rank
        args.distributed_rank = world_rank
        args.distributed_init_method = url

        # Add a checkpoint hook to make use of new resources.
        self.add_checkpoint_hook(args)

        # Call the original main function of fairseq.
        main(args, init_distributed=(args.distributed_world_size > 1))

    def add_checkpoint_hook(self, args):
        """Add a hook to the original save_checkpoint function.

        This checks if there are new computational resources available.
        If so, raise an exception to restart the training process and
        make use of the new resources.
        """

        if args.cpu:
            original_n_cpus = args.distributed_world_size

            def _new_save_checkpoint(*args, **kwargs):
                _original_save_checkpoint(*args, **kwargs)
                n_cpus = int(ray.cluster_resources()["CPU"])
                if n_cpus > original_n_cpus:
                    raise Exception(
                        "New CPUs found (original %d CPUs, now %d CPUs)" %
                        (original_n_cpus, n_cpus))
        else:
            original_n_gpus = args.distributed_world_size

            def _new_save_checkpoint(*args, **kwargs):
                _original_save_checkpoint(*args, **kwargs)
                n_gpus = int(ray.cluster_resources().get("GPU", 0))
                if n_gpus > original_n_gpus:
                    raise Exception(
                        "New GPUs found (original %d GPUs, now %d GPUs)" %
                        (original_n_gpus, n_gpus))
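        # Note: replacing fairseq's save_checkpoint below means every
        # checkpoint save also polls the cluster for newly added resources.
        # The exception raised by the hook aborts this worker's run() task;
        # it surfaces on the driver through ray.get() inside
        # run_fault_tolerant_loop(), which then shuts Ray down and restarts
        # training from the checkpoint that was just written, using the
        # larger world size.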
        fairseq.checkpoint_utils.save_checkpoint = _new_save_checkpoint

    def get_node_ip(self):
        """Returns the IP address of the current node."""
        return ray.util.get_node_ip_address()

    def find_free_port(self):
        """Finds a free port on the current node."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", 0))
            return s.getsockname()[1]


def run_fault_tolerant_loop():
    """Entry point into fairseq training, with fault tolerance added."""

    # Parse the command line arguments.
    parser = options.get_training_parser()
    add_ray_args(parser)
    args = options.parse_args_and_arch(parser)
    original_args = copy.deepcopy(args)

    # Main loop for fault-tolerant training.
    retry = True
    while retry:
        args = copy.deepcopy(original_args)

        # Initialize Ray.
        ray.init(address=args.ray_address)

        set_num_resources(args)
        set_batch_size(args)

        # Set up Ray distributed actors.
        Actor = ray.remote(
            num_cpus=1, num_gpus=int(not args.cpu))(RayDistributedActor)
        workers = [Actor.remote() for _ in range(args.distributed_world_size)]

        # Get the IP address and a free port of actor 0, which is used for
        # fairseq distributed training.
        ip = ray.get(workers[0].get_node_ip.remote())
        port = ray.get(workers[0].find_free_port.remote())
        address = "tcp://{ip}:{port}".format(ip=ip, port=port)

        # Start the remote processes, and check whether any of them fails.
        # If so, restart all the processes.
        unfinished = [
            worker.run.remote(address, i, args)
            for i, worker in enumerate(workers)
        ]
        try:
            while len(unfinished) > 0:
                finished, unfinished = ray.wait(unfinished)
                finished = ray.get(finished)
            retry = False
        except Exception as inst:
            print("Restarting Ray because the following error occurred:")
            print(inst)
            retry = True
        ray.shutdown()


def add_ray_args(parser):
    """Add ray and fault-tolerance related parser arguments to the parser."""
    group = parser.add_argument_group("Ray related arguments")
    group.add_argument(
        "--ray-address",
        default="auto",
        type=str,
        help="address for ray initialization")
    group.add_argument(
        "--fix-batch-size",
        default=None,
        metavar="B1,B2,...,B_N",
        type=lambda uf: options.eval_str_list(uf, type=int),
        help="fix the effective batch size (max_sentences * update_freq "
        "* n_GPUs) to the given values by adjusting update_freq "
        "according to the actual number of GPUs; the batch size is fixed "
        "to B_i for epoch i, and all epochs > N use B_N")
    return group


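# Example invocation (illustrative only: the script name, the data path, and
# the fairseq model flags are placeholders, not defined in this file):
#
#   python ray_train.py data-bin/wmt14_en_de \
#       --arch transformer --max-sentences 8 \
#       --ray-address auto --fix-batch-size 2048,4096
#
# --ray-address and --fix-batch-size are the arguments added above; all other
# flags are handled by fairseq's own training parser.

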
def set_num_resources(args):
    """Get the number of resources and set the corresponding fields."""
    if args.cpu:
        args.distributed_world_size = int(ray.cluster_resources()["CPU"])
    else:
        n_gpus = int(ray.cluster_resources().get("GPU", 0))
        while n_gpus == 0:
            print("No GPUs available, waiting 10 seconds")
            time.sleep(10)
            n_gpus = int(ray.cluster_resources().get("GPU", 0))
        args.distributed_world_size = n_gpus


def set_batch_size(args):
    """Fixes the total batch_size to be agnostic to the GPU count."""
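    # Worked example (hypothetical numbers): with --fix-batch-size 4096,
    # --max-sentences 8, and 8 GPUs currently available, update_freq becomes
    # ceil(4096 / (8 * 8)) = 64, so the effective batch size
    # (max_sentences * update_freq * n_GPUs) is 8 * 64 * 8 = 4096 sentences;
    # with a different GPU count, update_freq is rescaled so the effective
    # batch stays (approximately) the same.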
    if args.fix_batch_size is not None:
        args.update_freq = [
            math.ceil(batch_size /
                      (args.max_sentences * args.distributed_world_size))
            for batch_size in args.fix_batch_size
        ]
        print("Training on %d GPUs, max_sentences=%d, update_freq=%s" %
              (args.distributed_world_size, args.max_sentences,
               repr(args.update_freq)))


if __name__ == "__main__":
    run_fault_tolerant_loop()