[sgd] Add non-distributed PyTorch runner (#4933)

* Add non-distributed PyTorch runner

* use dist.is_available() instead of checking OS

* Nicer exception

* Fix bug in choosing port

* Refactor some code

* Address comments

* Address comments
Authored by Peter Schafhalter on 2019-06-12 07:38:34 +02:00; committed by Richard Liaw
parent 472c36ed1e
commit e0e52f1871
5 changed files with 237 additions and 132 deletions
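For context: this commit turns PyTorchRunner into a plain, non-distributed runner, moves the torch.distributed logic into a new DistributedPyTorchRunner, and makes PyTorchTrainer choose between the two based on num_replicas. A minimal usage sketch of the new single-replica path follows; the creator functions are illustrative stand-ins that follow the signatures documented in pytorch_trainer.py, not code from this commit.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import ray
from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources


# Illustrative creator functions (placeholders, not part of the commit).
def model_creator(config):
    return nn.Linear(1, 1)


def data_creator(config):
    x = torch.randn(64, 1)
    y = 2 * x
    dataset = TensorDataset(x, y)
    return dataset, dataset  # (training set, validation set)


def optimizer_creator(model, config):
    return nn.MSELoss(), torch.optim.SGD(model.parameters(), lr=0.01)


ray.init()

# num_replicas=1 takes the new non-distributed code path and never touches
# torch.distributed; num_replicas > 1 requires dist.is_available().
trainer = PyTorchTrainer(
    model_creator,
    data_creator,
    optimizer_creator,
    num_replicas=1,
    resources_per_replica=Resources(num_cpus=1))

print(trainer.train())     # runs one training epoch, returns stats
print(trainer.validate())
trainer.shutdown()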

ray/experimental/sgd/pytorch/distributed_pytorch_runner.py (new file)

@@ -0,0 +1,131 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
import torch.distributed as dist
import torch.utils.data

from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner

logger = logging.getLogger(__name__)


class DistributedPyTorchRunner(PyTorchRunner):
    """Manages a distributed PyTorch model replica."""

    def __init__(self,
                 model_creator,
                 data_creator,
                 optimizer_creator,
                 config=None,
                 batch_size=16,
                 backend="gloo"):
        """Initializes the runner.

        Args:
            model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
            data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
            optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
                see pytorch_trainer.py.
            config (dict): see pytorch_trainer.py.
            batch_size (int): batch size used by one replica for an update.
            backend (string): see pytorch_trainer.py.
        """
        super(DistributedPyTorchRunner, self).__init__(
            model_creator, data_creator, optimizer_creator, config, batch_size)
        self.backend = backend

    def setup(self, url, world_rank, world_size):
        """Connects to the distributed PyTorch backend and initializes the model.

        Args:
            url (str): the URL used to connect to distributed PyTorch.
            world_rank (int): the index of the runner.
            world_size (int): the total number of runners.
        """
        self._setup_distributed_pytorch(url, world_rank, world_size)
        self._setup_training()

    def _setup_distributed_pytorch(self, url, world_rank, world_size):
        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
        with self._timers["setup_proc"]:
            self.world_rank = world_rank
            logger.debug(
                "Connecting to {} world_rank: {} world_size: {}".format(
                    url, world_rank, world_size))
            logger.debug("using {}".format(self.backend))
            dist.init_process_group(
                backend=self.backend,
                init_method=url,
                rank=world_rank,
                world_size=world_size)

    def _setup_training(self):
        logger.debug("Creating model")
        self.model = self.model_creator(self.config)
        if torch.cuda.is_available():
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model.cuda())
        else:
            self.model = torch.nn.parallel.DistributedDataParallelCPU(
                self.model)

        logger.debug("Creating optimizer")
        self.criterion, self.optimizer = self.optimizer_creator(
            self.model, self.config)
        if torch.cuda.is_available():
            self.criterion = self.criterion.cuda()

        logger.debug("Creating dataset")
        self.training_set, self.validation_set = self.data_creator(self.config)

        # TODO: make num_workers configurable
        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
            self.training_set)
        self.train_loader = torch.utils.data.DataLoader(
            self.training_set,
            batch_size=self.batch_size,
            shuffle=(self.train_sampler is None),
            num_workers=2,
            pin_memory=False,
            sampler=self.train_sampler)

        self.validation_sampler = (
            torch.utils.data.distributed.DistributedSampler(
                self.validation_set))
        self.validation_loader = torch.utils.data.DataLoader(
            self.validation_set,
            batch_size=self.batch_size,
            shuffle=(self.validation_sampler is None),
            num_workers=2,
            pin_memory=False,
            sampler=self.validation_sampler)

    def step(self):
        """Runs a training epoch and updates the model parameters."""
        logger.debug("Starting step")
        self.train_sampler.set_epoch(self.epoch)
        return super(DistributedPyTorchRunner, self).step()

    def get_state(self):
        """Returns the state of the runner."""
        return {
            "epoch": self.epoch,
            "model": self.model.module.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "stats": self.stats()
        }

    def set_state(self, state):
        """Sets the state of the model."""
        # TODO: restore timer stats
        self.model.module.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])
        self.epoch = state["stats"]["epoch"]

    def shutdown(self):
        """Attempts to shut down the worker."""
        super(DistributedPyTorchRunner, self).shutdown()
        dist.destroy_process_group()
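A note on get_state above: because the distributed runner wraps the model in DistributedDataParallel (or DistributedDataParallelCPU), it returns self.model.module.state_dict() so that checkpoint keys match an unwrapped model. This is also why the trainer's get_model (further down) no longer needs to strip a "module." prefix from the keys. A rough standalone illustration of that key difference, assuming a PyTorch build where DistributedDataParallel accepts CPU modules over the gloo backend (older releases needed DistributedDataParallelCPU, as the runner above shows):

import torch.distributed as dist
import torch.nn as nn

# Single-process process group purely for illustration; the port is arbitrary.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = nn.Linear(4, 2)
wrapped = nn.parallel.DistributedDataParallel(model)

print(list(model.state_dict()))           # ['weight', 'bias']
print(list(wrapped.state_dict()))         # ['module.weight', 'module.bias']
print(list(wrapped.module.state_dict()))  # ['weight', 'bias'] -- what get_state returns

dist.destroy_process_group()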

ray/experimental/sgd/pytorch/pytorch_runner.py

@@ -3,9 +3,7 @@ from __future__ import division
 from __future__ import print_function
 
 import logging
-import os
 import torch
-import torch.distributed as dist
 import torch.utils.data
 
 import ray
@@ -15,28 +13,23 @@ logger = logging.getLogger(__name__)
 
 
 class PyTorchRunner(object):
-    """Manages a distributed PyTorch model replica"""
+    """Manages a PyTorch model for training."""
 
     def __init__(self,
                  model_creator,
                  data_creator,
                  optimizer_creator,
                  config=None,
-                 batch_size=16,
-                 backend="gloo"):
+                 batch_size=16):
        """Initializes the runner.
 
         Args:
-            model_creator (dict -> torch.nn.Module): creates the model using
-                the config.
-            data_creator (dict -> Dataset, Dataset): creates the training and
-                validation data sets using the config.
+            model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
+            data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
             optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
-                creates the loss and optimizer using the model and the config.
-            config (dict): configuration passed to 'model_creator',
-                'data_creator', and 'optimizer_creator'.
-            batch_size (int): batch size used in an update.
-            backend (string): backend used by distributed PyTorch.
+                see pytorch_trainer.py.
+            config (dict): see pytorch_trainer.py.
+            batch_size (int): see pytorch_trainer.py.
         """
 
         self.model_creator = model_creator
@@ -44,7 +37,6 @@ class PyTorchRunner(object):
         self.optimizer_creator = optimizer_creator
         self.config = {} if config is None else config
         self.batch_size = batch_size
-        self.backend = backend
         self.verbose = True
 
         self.epoch = 0
@@ -56,82 +48,45 @@ class PyTorchRunner(object):
             ]
         }
 
-    def setup(self, url, world_rank, world_size):
-        """Connects to the distributed PyTorch backend and initializes the model.
-
-        Args:
-            url (str): the URL used to connect to distributed PyTorch.
-            world_rank (int): the index of the runner.
-            world_size (int): the total number of runners.
-        """
-        self._setup_distributed_pytorch(url, world_rank, world_size)
-        self._setup_training()
-
-    def _setup_distributed_pytorch(self, url, world_rank, world_size):
-        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-        with self._timers["setup_proc"]:
-            self.world_rank = world_rank
-            logger.debug(
-                "Connecting to {} world_rank: {} world_size: {}".format(
-                    url, world_rank, world_size))
-            logger.debug("using {}".format(self.backend))
-            dist.init_process_group(
-                backend=self.backend,
-                init_method=url,
-                rank=world_rank,
-                world_size=world_size)
-
-    def _setup_training(self):
+    def setup(self):
+        """Initializes the model."""
         logger.debug("Creating model")
         self.model = self.model_creator(self.config)
         if torch.cuda.is_available():
-            self.model = torch.nn.parallel.DistributedDataParallel(
-                self.model.cuda())
-        else:
-            self.model = torch.nn.parallel.DistributedDataParallelCPU(
-                self.model)
+            self.model = self.model.cuda()
 
         logger.debug("Creating optimizer")
         self.criterion, self.optimizer = self.optimizer_creator(
             self.model, self.config)
         if torch.cuda.is_available():
             self.criterion = self.criterion.cuda()
 
         logger.debug("Creating dataset")
         self.training_set, self.validation_set = self.data_creator(self.config)
 
-        # TODO: make num_workers configurable
-        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
-            self.training_set)
         self.train_loader = torch.utils.data.DataLoader(
             self.training_set,
             batch_size=self.batch_size,
-            shuffle=(self.train_sampler is None),
+            shuffle=True,
             num_workers=2,
-            pin_memory=False,
-            sampler=self.train_sampler)
+            pin_memory=False)
 
-        self.validation_sampler = (
-            torch.utils.data.distributed.DistributedSampler(
-                self.validation_set))
         self.validation_loader = torch.utils.data.DataLoader(
             self.validation_set,
             batch_size=self.batch_size,
-            shuffle=(self.validation_sampler is None),
+            shuffle=True,
             num_workers=2,
-            pin_memory=False,
-            sampler=self.validation_sampler)
+            pin_memory=False)
 
     def get_node_ip(self):
-        """Returns the IP address of the current node"""
+        """Returns the IP address of the current node."""
         return ray.services.get_node_ip_address()
 
-    def step(self):
-        """Runs a training epoch and updates the model parameters"""
-        logger.debug("Starting step")
-        self.train_sampler.set_epoch(self.epoch)
+    def find_free_port(self):
+        """Finds a free port on the current node."""
+        return utils.find_free_port()
+
+    def step(self):
+        """Runs a training epoch and updates the model parameters."""
         logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
         with self._timers["training"]:
             train_stats = utils.train(self.train_loader, self.model,
@@ -144,7 +99,7 @@ class PyTorchRunner(object):
         return train_stats
 
     def validate(self):
-        """Evaluates the model on the validation data set"""
+        """Evaluates the model on the validation data set."""
         with self._timers["validation"]:
             validation_stats = utils.validate(self.validation_loader,
                                               self.model, self.criterion)
@@ -153,7 +108,7 @@ class PyTorchRunner(object):
         return validation_stats
 
     def stats(self):
-        """Returns a dictionary of statistics collected"""
+        """Returns a dictionary of statistics collected."""
         stats = {"epoch": self.epoch}
         for k, t in self._timers.items():
             stats[k + "_time_mean"] = t.mean
@@ -162,7 +117,7 @@ class PyTorchRunner(object):
         return stats
 
     def get_state(self):
-        """Returns the state of the runner"""
+        """Returns the state of the runner."""
         return {
             "epoch": self.epoch,
             "model": self.model.state_dict(),
@@ -171,12 +126,20 @@ class PyTorchRunner(object):
         }
 
     def set_state(self, state):
-        """Sets the state of the model"""
+        """Sets the state of the model."""
         # TODO: restore timer stats
         self.model.load_state_dict(state["model"])
         self.optimizer.load_state_dict(state["optimizer"])
         self.epoch = state["stats"]["epoch"]
 
     def shutdown(self):
-        """Attempts to shut down the worker"""
-        dist.destroy_process_group()
+        """Attempts to shut down the worker."""
+        del self.validation_loader
+        del self.validation_set
+        del self.train_loader
+        del self.training_set
+        del self.criterion
+        del self.optimizer
+        del self.model
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()

ray/experimental/sgd/pytorch/pytorch_trainer.py

@@ -3,13 +3,15 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-import sys
 import torch
+import torch.distributed as dist
 import logging
 
 import ray
 
 from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
+from ray.experimental.sgd.pytorch.distributed_pytorch_runner import (
+    DistributedPyTorchRunner)
 from ray.experimental.sgd.pytorch import utils
 
 logger = logging.getLogger(__name__)
@@ -51,10 +53,11 @@ class PyTorchTrainer(object):
         """
         # TODO: add support for mixed precision
         # TODO: add support for callbacks
-        if sys.platform == "darwin":
-            raise Exception(
-                ("Distributed PyTorch is not supported on macOS. For more "
-                 "information, see "
+        if num_replicas > 1 and not dist.is_available():
+            raise ValueError(
+                ("Distributed PyTorch is not supported on macOS. "
+                 "To run without distributed PyTorch, set 'num_replicas=1'. "
+                 "For more information, see "
                  "https://github.com/pytorch/examples/issues/467."))
 
         self.model_creator = model_creator
@@ -68,40 +71,55 @@ class PyTorchTrainer(object):
         if backend == "auto":
             backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo"
 
-        Runner = ray.remote(
-            num_cpus=resources_per_replica.num_cpus,
-            num_gpus=resources_per_replica.num_gpus,
-            resources=resources_per_replica.resources)(PyTorchRunner)
-
-        batch_size_per_replica = batch_size // num_replicas
-        if batch_size % num_replicas > 0:
-            new_batch_size = batch_size_per_replica * num_replicas
-            logger.warn(
-                ("Changing batch size from {old_batch_size} to "
-                 "{new_batch_size} to evenly distribute batches across "
-                 "{num_replicas} replicas.").format(
-                     old_batch_size=batch_size,
-                     new_batch_size=new_batch_size,
-                     num_replicas=num_replicas))
-
-        self.workers = [
-            Runner.remote(model_creator, data_creator, optimizer_creator,
-                          self.config, batch_size_per_replica, backend)
-            for i in range(num_replicas)
-        ]
-
-        ip = ray.get(self.workers[0].get_node_ip.remote())
-        port = utils.find_free_port()
-        address = "tcp://{ip}:{port}".format(ip=ip, port=port)
-
-        # Get setup tasks in order to throw errors on failure
-        ray.get([
-            worker.setup.remote(address, i, len(self.workers))
-            for i, worker in enumerate(self.workers)
-        ])
+        if num_replicas == 1:
+            # Generate actor class
+            Runner = ray.remote(
+                num_cpus=resources_per_replica.num_cpus,
+                num_gpus=resources_per_replica.num_gpus,
+                resources=resources_per_replica.resources)(PyTorchRunner)
+            # Start workers
+            self.workers = [
+                Runner.remote(model_creator, data_creator, optimizer_creator,
+                              self.config, batch_size)
+            ]
+            # Get setup tasks in order to throw errors on failure
+            ray.get(self.workers[0].setup.remote())
+        else:
+            # Generate actor class
+            Runner = ray.remote(
+                num_cpus=resources_per_replica.num_cpus,
+                num_gpus=resources_per_replica.num_gpus,
+                resources=resources_per_replica.resources)(
+                    DistributedPyTorchRunner)
+            # Compute batch size per replica
+            batch_size_per_replica = batch_size // num_replicas
+            if batch_size % num_replicas > 0:
+                new_batch_size = batch_size_per_replica * num_replicas
+                logger.warn(
+                    ("Changing batch size from {old_batch_size} to "
+                     "{new_batch_size} to evenly distribute batches across "
+                     "{num_replicas} replicas.").format(
+                         old_batch_size=batch_size,
+                         new_batch_size=new_batch_size,
+                         num_replicas=num_replicas))
+            # Start workers
+            self.workers = [
+                Runner.remote(model_creator, data_creator, optimizer_creator,
+                              self.config, batch_size_per_replica, backend)
+                for i in range(num_replicas)
+            ]
+            # Compute URL for initializing distributed PyTorch
+            ip = ray.get(self.workers[0].get_node_ip.remote())
+            port = ray.get(self.workers[0].find_free_port.remote())
+            address = "tcp://{ip}:{port}".format(ip=ip, port=port)
+            # Get setup tasks in order to throw errors on failure
+            ray.get([
+                worker.setup.remote(address, i, len(self.workers))
+                for i, worker in enumerate(self.workers)
+            ])
 
     def train(self):
-        """Runs a training epoch"""
+        """Runs a training epoch."""
         with self.optimizer_timer:
             worker_stats = ray.get([w.step.remote() for w in self.workers])
@@ -111,7 +129,7 @@ class PyTorchTrainer(object):
         return train_stats
 
     def validate(self):
-        """Evaluates the model on the validation data set"""
+        """Evaluates the model on the validation data set."""
         worker_stats = ray.get([w.validate.remote() for w in self.workers])
         validation_stats = worker_stats[0].copy()
         validation_stats["validation_loss"] = np.mean(
@@ -119,32 +137,25 @@ class PyTorchTrainer(object):
         return validation_stats
 
     def get_model(self):
-        """Returns the learned model"""
+        """Returns the learned model."""
         model = self.model_creator(self.config)
         state = ray.get(self.workers[0].get_state.remote())
-
-        # Remove module. prefix added by distrbuted pytorch
-        state_dict = {
-            k.replace("module.", ""): v
-            for k, v in state["model"].items()
-        }
-
-        model.load_state_dict(state_dict)
+        model.load_state_dict(state["model"])
         return model
 
     def save(self, ckpt):
-        """Saves the model at the provided checkpoint"""
+        """Saves the model at the provided checkpoint."""
         state = ray.get(self.workers[0].get_state.remote())
         torch.save(state, ckpt)
 
     def restore(self, ckpt):
-        """Restores the model from the provided checkpoint"""
+        """Restores the model from the provided checkpoint."""
         state = torch.load(ckpt)
         state_id = ray.put(state)
         ray.get([worker.set_state.remote(state_id) for worker in self.workers])
 
     def shutdown(self):
-        """Shuts down workers and releases resources"""
+        """Shuts down workers and releases resources."""
         for worker in self.workers:
             worker.shutdown.remote()
             worker.__ray_terminate__.remote()

ray/experimental/sgd/pytorch/utils.py

@@ -196,7 +196,7 @@ def find_free_port():
 
 
 class AverageMeter(object):
-    """Computes and stores the average and current value"""
+    """Computes and stores the average and current value."""
 
     def __init__(self):
         self.reset()
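utils.find_free_port is only referenced in this commit (its body sits just above the AverageMeter hunk) and is not shown here. A common way to implement such a helper is to bind a socket to port 0 and let the OS pick an unused port; the sketch below shows that idea and is an assumption, not necessarily the implementation in utils.py:

import socket


def find_free_port():
    """Asks the OS for an ephemeral port that is currently free."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0 -> the kernel chooses a free port
        return s.getsockname()[1]


# Example: building the init URL that the trainer passes to each runner's setup().
address = "tcp://{ip}:{port}".format(ip="127.0.0.1", port=find_free_port())
print(address)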

ray/experimental/sgd/tests/ (PyTorchTrainer test file)

@@ -4,9 +4,9 @@ from __future__ import print_function
 
 import os
 import pytest
-import sys
 import tempfile
 import torch
+import torch.distributed as dist
 
 from ray.tests.conftest import ray_start_2_cpus  # noqa: F401
 from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources
@@ -15,14 +15,14 @@ from ray.experimental.sgd.tests.pytorch_utils import (
     model_creator, optimizer_creator, data_creator)
 
 
-@pytest.mark.skipif(  # noqa: F811
-    sys.platform == "darwin", reason="Doesn't work on macOS.")
-def test_train(ray_start_2_cpus):  # noqa: F811
+@pytest.mark.parametrize(  # noqa: F811
+    "num_replicas", [1, 2] if dist.is_available() else [1])
+def test_train(ray_start_2_cpus, num_replicas):  # noqa: F811
     trainer = PyTorchTrainer(
         model_creator,
         data_creator,
         optimizer_creator,
-        num_replicas=2,
+        num_replicas=num_replicas,
         resources_per_replica=Resources(num_cpus=1))
     train_loss1 = trainer.train()["train_loss"]
     validation_loss1 = trainer.validate()["validation_loss"]
@@ -37,14 +37,14 @@ def test_train(ray_start_2_cpus):  # noqa: F811
     assert validation_loss2 <= validation_loss1
 
 
-@pytest.mark.skipif(  # noqa: F811
-    sys.platform == "darwin", reason="Doesn't work on macOS.")
-def test_save_and_restore(ray_start_2_cpus):  # noqa: F811
+@pytest.mark.parametrize(  # noqa: F811
+    "num_replicas", [1, 2] if dist.is_available() else [1])
+def test_save_and_restore(ray_start_2_cpus, num_replicas):  # noqa: F811
     trainer1 = PyTorchTrainer(
         model_creator,
         data_creator,
         optimizer_creator,
-        num_replicas=2,
+        num_replicas=num_replicas,
         resources_per_replica=Resources(num_cpus=1))
 
     trainer1.train()
@@ -59,7 +59,7 @@ def test_save_and_restore(ray_start_2_cpus):  # noqa: F811
         model_creator,
         data_creator,
         optimizer_creator,
-        num_replicas=2,
+        num_replicas=num_replicas,
         resources_per_replica=Resources(num_cpus=1))
 
     trainer2.restore(filename)