From e0e52f1871b529ba6a8d86f1547dc636c4c0f703 Mon Sep 17 00:00:00 2001 From: Peter Schafhalter Date: Wed, 12 Jun 2019 07:38:34 +0200 Subject: [PATCH] [sgd] Add non-distributed PyTorch runner (#4933) * Add non-distributed PyTorch runner * use dist.is_available() instead of checking OS * Nicer exception * Fix bug in choosing port * Refactor some code * Address comments * Address comments --- .../sgd/pytorch/distributed_pytorch_runner.py | 131 ++++++++++++++++++ .../sgd/pytorch/pytorch_runner.py | 105 +++++--------- .../sgd/pytorch/pytorch_trainer.py | 111 ++++++++------- python/ray/experimental/sgd/pytorch/utils.py | 2 +- .../experimental/sgd/tests/test_pytorch.py | 20 +-- 5 files changed, 237 insertions(+), 132 deletions(-) create mode 100644 python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py diff --git a/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py b/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py new file mode 100644 index 000000000..160544633 --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py @@ -0,0 +1,131 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import os +import torch.distributed as dist +import torch.utils.data + +from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner + +logger = logging.getLogger(__name__) + + +class DistributedPyTorchRunner(PyTorchRunner): + """Manages a distributed PyTorch model replica.""" + + def __init__(self, + model_creator, + data_creator, + optimizer_creator, + config=None, + batch_size=16, + backend="gloo"): + """Initializes the runner. + + Args: + model_creator (dict -> torch.nn.Module): see pytorch_trainer.py. + data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py. + optimizer_creator (torch.nn.Module, dict -> loss, optimizer): + see pytorch_trainer.py. + config (dict): see pytorch_trainer.py. + batch_size (int): batch size used by one replica for an update. + backend (string): see pytorch_trainer.py. + """ + super(DistributedPyTorchRunner, self).__init__( + model_creator, data_creator, optimizer_creator, config, batch_size) + self.backend = backend + + def setup(self, url, world_rank, world_size): + """Connects to the distributed PyTorch backend and initializes the model. + + Args: + url (str): the URL used to connect to distributed PyTorch. + world_rank (int): the index of the runner. + world_size (int): the total number of runners. + """ + self._setup_distributed_pytorch(url, world_rank, world_size) + self._setup_training() + + def _setup_distributed_pytorch(self, url, world_rank, world_size): + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + with self._timers["setup_proc"]: + self.world_rank = world_rank + logger.debug( + "Connecting to {} world_rank: {} world_size: {}".format( + url, world_rank, world_size)) + logger.debug("using {}".format(self.backend)) + dist.init_process_group( + backend=self.backend, + init_method=url, + rank=world_rank, + world_size=world_size) + + def _setup_training(self): + logger.debug("Creating model") + self.model = self.model_creator(self.config) + if torch.cuda.is_available(): + self.model = torch.nn.parallel.DistributedDataParallel( + self.model.cuda()) + else: + self.model = torch.nn.parallel.DistributedDataParallelCPU( + self.model) + + logger.debug("Creating optimizer") + self.criterion, self.optimizer = self.optimizer_creator( + self.model, self.config) + if torch.cuda.is_available(): + self.criterion = self.criterion.cuda() + + logger.debug("Creating dataset") + self.training_set, self.validation_set = self.data_creator(self.config) + + # TODO: make num_workers configurable + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_set) + self.train_loader = torch.utils.data.DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=(self.train_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.train_sampler) + + self.validation_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.validation_set)) + self.validation_loader = torch.utils.data.DataLoader( + self.validation_set, + batch_size=self.batch_size, + shuffle=(self.validation_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.validation_sampler) + + def step(self): + """Runs a training epoch and updates the model parameters.""" + logger.debug("Starting step") + self.train_sampler.set_epoch(self.epoch) + return super(DistributedPyTorchRunner, self).step() + + def get_state(self): + """Returns the state of the runner.""" + return { + "epoch": self.epoch, + "model": self.model.module.state_dict(), + "optimizer": self.optimizer.state_dict(), + "stats": self.stats() + } + + def set_state(self, state): + """Sets the state of the model.""" + # TODO: restore timer stats + self.model.module.load_state_dict(state["model"]) + self.optimizer.load_state_dict(state["optimizer"]) + self.epoch = state["stats"]["epoch"] + + def shutdown(self): + """Attempts to shut down the worker.""" + super(DistributedPyTorchRunner, self).shutdown() + dist.destroy_process_group() diff --git a/python/ray/experimental/sgd/pytorch/pytorch_runner.py b/python/ray/experimental/sgd/pytorch/pytorch_runner.py index 5fe4ba100..1663b2c64 100644 --- a/python/ray/experimental/sgd/pytorch/pytorch_runner.py +++ b/python/ray/experimental/sgd/pytorch/pytorch_runner.py @@ -3,9 +3,7 @@ from __future__ import division from __future__ import print_function import logging -import os import torch -import torch.distributed as dist import torch.utils.data import ray @@ -15,28 +13,23 @@ logger = logging.getLogger(__name__) class PyTorchRunner(object): - """Manages a distributed PyTorch model replica""" + """Manages a PyTorch model for training.""" def __init__(self, model_creator, data_creator, optimizer_creator, config=None, - batch_size=16, - backend="gloo"): + batch_size=16): """Initializes the runner. Args: - model_creator (dict -> torch.nn.Module): creates the model using - the config. - data_creator (dict -> Dataset, Dataset): creates the training and - validation data sets using the config. + model_creator (dict -> torch.nn.Module): see pytorch_trainer.py. + data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py. optimizer_creator (torch.nn.Module, dict -> loss, optimizer): - creates the loss and optimizer using the model and the config. - config (dict): configuration passed to 'model_creator', - 'data_creator', and 'optimizer_creator'. - batch_size (int): batch size used in an update. - backend (string): backend used by distributed PyTorch. + see pytorch_trainer.py. + config (dict): see pytorch_trainer.py. + batch_size (int): see pytorch_trainer.py. """ self.model_creator = model_creator @@ -44,7 +37,6 @@ class PyTorchRunner(object): self.optimizer_creator = optimizer_creator self.config = {} if config is None else config self.batch_size = batch_size - self.backend = backend self.verbose = True self.epoch = 0 @@ -56,82 +48,45 @@ class PyTorchRunner(object): ] } - def setup(self, url, world_rank, world_size): - """Connects to the distributed PyTorch backend and initializes the model. - - Args: - url (str): the URL used to connect to distributed PyTorch. - world_rank (int): the index of the runner. - world_size (int): the total number of runners. - """ - self._setup_distributed_pytorch(url, world_rank, world_size) - self._setup_training() - - def _setup_distributed_pytorch(self, url, world_rank, world_size): - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - with self._timers["setup_proc"]: - self.world_rank = world_rank - logger.debug( - "Connecting to {} world_rank: {} world_size: {}".format( - url, world_rank, world_size)) - logger.debug("using {}".format(self.backend)) - dist.init_process_group( - backend=self.backend, - init_method=url, - rank=world_rank, - world_size=world_size) - - def _setup_training(self): + def setup(self): + """Initializes the model.""" logger.debug("Creating model") self.model = self.model_creator(self.config) if torch.cuda.is_available(): - self.model = torch.nn.parallel.DistributedDataParallel( - self.model.cuda()) - else: - self.model = torch.nn.parallel.DistributedDataParallelCPU( - self.model) + self.model = self.model.cuda() logger.debug("Creating optimizer") self.criterion, self.optimizer = self.optimizer_creator( self.model, self.config) - if torch.cuda.is_available(): self.criterion = self.criterion.cuda() logger.debug("Creating dataset") self.training_set, self.validation_set = self.data_creator(self.config) - - # TODO: make num_workers configurable - self.train_sampler = torch.utils.data.distributed.DistributedSampler( - self.training_set) self.train_loader = torch.utils.data.DataLoader( self.training_set, batch_size=self.batch_size, - shuffle=(self.train_sampler is None), + shuffle=True, num_workers=2, - pin_memory=False, - sampler=self.train_sampler) + pin_memory=False) - self.validation_sampler = ( - torch.utils.data.distributed.DistributedSampler( - self.validation_set)) self.validation_loader = torch.utils.data.DataLoader( self.validation_set, batch_size=self.batch_size, - shuffle=(self.validation_sampler is None), + shuffle=True, num_workers=2, - pin_memory=False, - sampler=self.validation_sampler) + pin_memory=False) def get_node_ip(self): - """Returns the IP address of the current node""" + """Returns the IP address of the current node.""" return ray.services.get_node_ip_address() - def step(self): - """Runs a training epoch and updates the model parameters""" - logger.debug("Starting step") - self.train_sampler.set_epoch(self.epoch) + def find_free_port(self): + """Finds a free port on the current node.""" + return utils.find_free_port() + def step(self): + """Runs a training epoch and updates the model parameters.""" logger.debug("Begin Training Epoch {}".format(self.epoch + 1)) with self._timers["training"]: train_stats = utils.train(self.train_loader, self.model, @@ -144,7 +99,7 @@ class PyTorchRunner(object): return train_stats def validate(self): - """Evaluates the model on the validation data set""" + """Evaluates the model on the validation data set.""" with self._timers["validation"]: validation_stats = utils.validate(self.validation_loader, self.model, self.criterion) @@ -153,7 +108,7 @@ class PyTorchRunner(object): return validation_stats def stats(self): - """Returns a dictionary of statistics collected""" + """Returns a dictionary of statistics collected.""" stats = {"epoch": self.epoch} for k, t in self._timers.items(): stats[k + "_time_mean"] = t.mean @@ -162,7 +117,7 @@ class PyTorchRunner(object): return stats def get_state(self): - """Returns the state of the runner""" + """Returns the state of the runner.""" return { "epoch": self.epoch, "model": self.model.state_dict(), @@ -171,12 +126,20 @@ class PyTorchRunner(object): } def set_state(self, state): - """Sets the state of the model""" + """Sets the state of the model.""" # TODO: restore timer stats self.model.load_state_dict(state["model"]) self.optimizer.load_state_dict(state["optimizer"]) self.epoch = state["stats"]["epoch"] def shutdown(self): - """Attempts to shut down the worker""" - dist.destroy_process_group() + """Attempts to shut down the worker.""" + del self.validation_loader + del self.validation_set + del self.train_loader + del self.training_set + del self.criterion + del self.optimizer + del self.model + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/python/ray/experimental/sgd/pytorch/pytorch_trainer.py b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py index 073ad3d34..0e0c5d843 100644 --- a/python/ray/experimental/sgd/pytorch/pytorch_trainer.py +++ b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py @@ -3,13 +3,15 @@ from __future__ import division from __future__ import print_function import numpy as np -import sys import torch +import torch.distributed as dist import logging import ray from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner +from ray.experimental.sgd.pytorch.distributed_pytorch_runner import ( + DistributedPyTorchRunner) from ray.experimental.sgd.pytorch import utils logger = logging.getLogger(__name__) @@ -51,10 +53,11 @@ class PyTorchTrainer(object): """ # TODO: add support for mixed precision # TODO: add support for callbacks - if sys.platform == "darwin": - raise Exception( - ("Distributed PyTorch is not supported on macOS. For more " - "information, see " + if num_replicas > 1 and not dist.is_available(): + raise ValueError( + ("Distributed PyTorch is not supported on macOS. " + "To run without distributed PyTorch, set 'num_replicas=1'. " + "For more information, see " "https://github.com/pytorch/examples/issues/467.")) self.model_creator = model_creator @@ -68,40 +71,55 @@ class PyTorchTrainer(object): if backend == "auto": backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo" - Runner = ray.remote( - num_cpus=resources_per_replica.num_cpus, - num_gpus=resources_per_replica.num_gpus, - resources=resources_per_replica.resources)(PyTorchRunner) - - batch_size_per_replica = batch_size // num_replicas - if batch_size % num_replicas > 0: - new_batch_size = batch_size_per_replica * num_replicas - logger.warn( - ("Changing batch size from {old_batch_size} to " - "{new_batch_size} to evenly distribute batches across " - "{num_replicas} replicas.").format( - old_batch_size=batch_size, - new_batch_size=new_batch_size, - num_replicas=num_replicas)) - - self.workers = [ - Runner.remote(model_creator, data_creator, optimizer_creator, - self.config, batch_size_per_replica, backend) - for i in range(num_replicas) - ] - - ip = ray.get(self.workers[0].get_node_ip.remote()) - port = utils.find_free_port() - address = "tcp://{ip}:{port}".format(ip=ip, port=port) - - # Get setup tasks in order to throw errors on failure - ray.get([ - worker.setup.remote(address, i, len(self.workers)) - for i, worker in enumerate(self.workers) - ]) + if num_replicas == 1: + # Generate actor class + Runner = ray.remote( + num_cpus=resources_per_replica.num_cpus, + num_gpus=resources_per_replica.num_gpus, + resources=resources_per_replica.resources)(PyTorchRunner) + # Start workers + self.workers = [ + Runner.remote(model_creator, data_creator, optimizer_creator, + self.config, batch_size) + ] + # Get setup tasks in order to throw errors on failure + ray.get(self.workers[0].setup.remote()) + else: + # Geneate actor class + Runner = ray.remote( + num_cpus=resources_per_replica.num_cpus, + num_gpus=resources_per_replica.num_gpus, + resources=resources_per_replica.resources)( + DistributedPyTorchRunner) + # Compute batch size per replica + batch_size_per_replica = batch_size // num_replicas + if batch_size % num_replicas > 0: + new_batch_size = batch_size_per_replica * num_replicas + logger.warn( + ("Changing batch size from {old_batch_size} to " + "{new_batch_size} to evenly distribute batches across " + "{num_replicas} replicas.").format( + old_batch_size=batch_size, + new_batch_size=new_batch_size, + num_replicas=num_replicas)) + # Start workers + self.workers = [ + Runner.remote(model_creator, data_creator, optimizer_creator, + self.config, batch_size_per_replica, backend) + for i in range(num_replicas) + ] + # Compute URL for initializing distributed PyTorch + ip = ray.get(self.workers[0].get_node_ip.remote()) + port = ray.get(self.workers[0].find_free_port.remote()) + address = "tcp://{ip}:{port}".format(ip=ip, port=port) + # Get setup tasks in order to throw errors on failure + ray.get([ + worker.setup.remote(address, i, len(self.workers)) + for i, worker in enumerate(self.workers) + ]) def train(self): - """Runs a training epoch""" + """Runs a training epoch.""" with self.optimizer_timer: worker_stats = ray.get([w.step.remote() for w in self.workers]) @@ -111,7 +129,7 @@ class PyTorchTrainer(object): return train_stats def validate(self): - """Evaluates the model on the validation data set""" + """Evaluates the model on the validation data set.""" worker_stats = ray.get([w.validate.remote() for w in self.workers]) validation_stats = worker_stats[0].copy() validation_stats["validation_loss"] = np.mean( @@ -119,32 +137,25 @@ class PyTorchTrainer(object): return validation_stats def get_model(self): - """Returns the learned model""" + """Returns the learned model.""" model = self.model_creator(self.config) state = ray.get(self.workers[0].get_state.remote()) - - # Remove module. prefix added by distrbuted pytorch - state_dict = { - k.replace("module.", ""): v - for k, v in state["model"].items() - } - - model.load_state_dict(state_dict) + model.load_state_dict(state["model"]) return model def save(self, ckpt): - """Saves the model at the provided checkpoint""" + """Saves the model at the provided checkpoint.""" state = ray.get(self.workers[0].get_state.remote()) torch.save(state, ckpt) def restore(self, ckpt): - """Restores the model from the provided checkpoint""" + """Restores the model from the provided checkpoint.""" state = torch.load(ckpt) state_id = ray.put(state) ray.get([worker.set_state.remote(state_id) for worker in self.workers]) def shutdown(self): - """Shuts down workers and releases resources""" + """Shuts down workers and releases resources.""" for worker in self.workers: worker.shutdown.remote() worker.__ray_terminate__.remote() diff --git a/python/ray/experimental/sgd/pytorch/utils.py b/python/ray/experimental/sgd/pytorch/utils.py index f7c6e4aba..5be26b331 100644 --- a/python/ray/experimental/sgd/pytorch/utils.py +++ b/python/ray/experimental/sgd/pytorch/utils.py @@ -196,7 +196,7 @@ def find_free_port(): class AverageMeter(object): - """Computes and stores the average and current value""" + """Computes and stores the average and current value.""" def __init__(self): self.reset() diff --git a/python/ray/experimental/sgd/tests/test_pytorch.py b/python/ray/experimental/sgd/tests/test_pytorch.py index faff23f8a..aa0596aa1 100644 --- a/python/ray/experimental/sgd/tests/test_pytorch.py +++ b/python/ray/experimental/sgd/tests/test_pytorch.py @@ -4,9 +4,9 @@ from __future__ import print_function import os import pytest -import sys import tempfile import torch +import torch.distributed as dist from ray.tests.conftest import ray_start_2_cpus # noqa: F401 from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources @@ -15,14 +15,14 @@ from ray.experimental.sgd.tests.pytorch_utils import ( model_creator, optimizer_creator, data_creator) -@pytest.mark.skipif( # noqa: F811 - sys.platform == "darwin", reason="Doesn't work on macOS.") -def test_train(ray_start_2_cpus): # noqa: F811 +@pytest.mark.parametrize( # noqa: F811 + "num_replicas", [1, 2] if dist.is_available() else [1]) +def test_train(ray_start_2_cpus, num_replicas): # noqa: F811 trainer = PyTorchTrainer( model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) train_loss1 = trainer.train()["train_loss"] validation_loss1 = trainer.validate()["validation_loss"] @@ -37,14 +37,14 @@ def test_train(ray_start_2_cpus): # noqa: F811 assert validation_loss2 <= validation_loss1 -@pytest.mark.skipif( # noqa: F811 - sys.platform == "darwin", reason="Doesn't work on macOS.") -def test_save_and_restore(ray_start_2_cpus): # noqa: F811 +@pytest.mark.parametrize( # noqa: F811 + "num_replicas", [1, 2] if dist.is_available() else [1]) +def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811 trainer1 = PyTorchTrainer( model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) trainer1.train() @@ -59,7 +59,7 @@ def test_save_and_restore(ray_start_2_cpus): # noqa: F811 model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) trainer2.restore(filename)