[sgd] Add non-distributed PyTorch runner (#4933)
* Add non-distributed PyTorch runner
* use dist.is_available() instead of checking OS
* Nicer exception
* Fix bug in choosing port
* Refactor some code
* Address comments
* Address comments
Parent: 472c36ed1e
Commit: e0e52f1871
5 changed files with 237 additions and 132 deletions
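
As a quick orientation (not part of the diff): with this change, a single-replica trainer runs through the new non-distributed PyTorchRunner and no longer needs torch.distributed at all, while num_replicas > 1 keeps the distributed path. A minimal usage sketch, assuming a Ray cluster has been started and the creator functions from the SGD test utilities are importable:

import ray
from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources
from ray.experimental.sgd.tests.pytorch_utils import (
    model_creator, optimizer_creator, data_creator)

ray.init()
trainer = PyTorchTrainer(
    model_creator,
    data_creator,
    optimizer_creator,
    num_replicas=1,  # single replica -> non-distributed PyTorchRunner
    resources_per_replica=Resources(num_cpus=1))
train_stats = trainer.train()     # runs one training epoch
val_stats = trainer.validate()    # evaluates on the validation set
trainer.shutdown()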
distributed_pytorch_runner.py (new file) @@ -0,0 +1,131 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
import torch.distributed as dist
import torch.utils.data

from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner

logger = logging.getLogger(__name__)


class DistributedPyTorchRunner(PyTorchRunner):
    """Manages a distributed PyTorch model replica."""

    def __init__(self,
                 model_creator,
                 data_creator,
                 optimizer_creator,
                 config=None,
                 batch_size=16,
                 backend="gloo"):
        """Initializes the runner.

        Args:
            model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
            data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
            optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
                see pytorch_trainer.py.
            config (dict): see pytorch_trainer.py.
            batch_size (int): batch size used by one replica for an update.
            backend (string): see pytorch_trainer.py.
        """
        super(DistributedPyTorchRunner, self).__init__(
            model_creator, data_creator, optimizer_creator, config, batch_size)
        self.backend = backend

    def setup(self, url, world_rank, world_size):
        """Connects to the distributed PyTorch backend and initializes the model.

        Args:
            url (str): the URL used to connect to distributed PyTorch.
            world_rank (int): the index of the runner.
            world_size (int): the total number of runners.
        """
        self._setup_distributed_pytorch(url, world_rank, world_size)
        self._setup_training()

    def _setup_distributed_pytorch(self, url, world_rank, world_size):
        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
        with self._timers["setup_proc"]:
            self.world_rank = world_rank
            logger.debug(
                "Connecting to {} world_rank: {} world_size: {}".format(
                    url, world_rank, world_size))
            logger.debug("using {}".format(self.backend))
            dist.init_process_group(
                backend=self.backend,
                init_method=url,
                rank=world_rank,
                world_size=world_size)

    def _setup_training(self):
        logger.debug("Creating model")
        self.model = self.model_creator(self.config)
        if torch.cuda.is_available():
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model.cuda())
        else:
            self.model = torch.nn.parallel.DistributedDataParallelCPU(
                self.model)

        logger.debug("Creating optimizer")
        self.criterion, self.optimizer = self.optimizer_creator(
            self.model, self.config)
        if torch.cuda.is_available():
            self.criterion = self.criterion.cuda()

        logger.debug("Creating dataset")
        self.training_set, self.validation_set = self.data_creator(self.config)

        # TODO: make num_workers configurable
        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
            self.training_set)
        self.train_loader = torch.utils.data.DataLoader(
            self.training_set,
            batch_size=self.batch_size,
            shuffle=(self.train_sampler is None),
            num_workers=2,
            pin_memory=False,
            sampler=self.train_sampler)

        self.validation_sampler = (
            torch.utils.data.distributed.DistributedSampler(
                self.validation_set))
        self.validation_loader = torch.utils.data.DataLoader(
            self.validation_set,
            batch_size=self.batch_size,
            shuffle=(self.validation_sampler is None),
            num_workers=2,
            pin_memory=False,
            sampler=self.validation_sampler)

    def step(self):
        """Runs a training epoch and updates the model parameters."""
        logger.debug("Starting step")
        self.train_sampler.set_epoch(self.epoch)
        return super(DistributedPyTorchRunner, self).step()

    def get_state(self):
        """Returns the state of the runner."""
        return {
            "epoch": self.epoch,
            "model": self.model.module.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "stats": self.stats()
        }

    def set_state(self, state):
        """Sets the state of the model."""
        # TODO: restore timer stats
        self.model.module.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])
        self.epoch = state["stats"]["epoch"]

    def shutdown(self):
        """Attempts to shut down the worker."""
        super(DistributedPyTorchRunner, self).shutdown()
        dist.destroy_process_group()
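A note on get_state()/set_state() above (an aside, not part of the diff): torch.nn.parallel.DistributedDataParallel stores the wrapped network as self.module, so its state_dict() keys carry a "module." prefix, while reading from self.model.module yields plain, unprefixed keys that a bare model can load directly. A toy illustration of that attribute layout:

import torch.nn as nn

class Wrapper(nn.Module):
    # Stand-in that mimics how DistributedDataParallel keeps the wrapped
    # model under the attribute name "module".
    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.module = model

net = nn.Linear(2, 2)
print(sorted(net.state_dict().keys()))           # ['bias', 'weight']
print(sorted(Wrapper(net).state_dict().keys()))  # ['module.bias', 'module.weight']

This is why the trainer's get_model() (see the pytorch_trainer.py diff further down) no longer has to strip a "module." prefix.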
pytorch_runner.py:

@@ -3,9 +3,7 @@ from __future__ import division
 from __future__ import print_function

 import logging
-import os
 import torch
-import torch.distributed as dist
 import torch.utils.data

 import ray

@@ -15,28 +13,23 @@ logger = logging.getLogger(__name__)


 class PyTorchRunner(object):
-    """Manages a distributed PyTorch model replica"""
+    """Manages a PyTorch model for training."""

     def __init__(self,
                  model_creator,
                  data_creator,
                  optimizer_creator,
                  config=None,
-                 batch_size=16,
-                 backend="gloo"):
+                 batch_size=16):
         """Initializes the runner.

         Args:
-            model_creator (dict -> torch.nn.Module): creates the model using
-                the config.
-            data_creator (dict -> Dataset, Dataset): creates the training and
-                validation data sets using the config.
+            model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
+            data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
             optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
-                creates the loss and optimizer using the model and the config.
-            config (dict): configuration passed to 'model_creator',
-                'data_creator', and 'optimizer_creator'.
-            batch_size (int): batch size used in an update.
-            backend (string): backend used by distributed PyTorch.
+                see pytorch_trainer.py.
+            config (dict): see pytorch_trainer.py.
+            batch_size (int): see pytorch_trainer.py.
         """

         self.model_creator = model_creator

@@ -44,7 +37,6 @@ class PyTorchRunner(object):
         self.optimizer_creator = optimizer_creator
         self.config = {} if config is None else config
         self.batch_size = batch_size
-        self.backend = backend
         self.verbose = True

         self.epoch = 0

@@ -56,82 +48,45 @@ class PyTorchRunner(object):
             ]
         }

-    def setup(self, url, world_rank, world_size):
-        """Connects to the distributed PyTorch backend and initializes the model.
-
-        Args:
-            url (str): the URL used to connect to distributed PyTorch.
-            world_rank (int): the index of the runner.
-            world_size (int): the total number of runners.
-        """
-        self._setup_distributed_pytorch(url, world_rank, world_size)
-        self._setup_training()
-
-    def _setup_distributed_pytorch(self, url, world_rank, world_size):
-        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-        with self._timers["setup_proc"]:
-            self.world_rank = world_rank
-            logger.debug(
-                "Connecting to {} world_rank: {} world_size: {}".format(
-                    url, world_rank, world_size))
-            logger.debug("using {}".format(self.backend))
-            dist.init_process_group(
-                backend=self.backend,
-                init_method=url,
-                rank=world_rank,
-                world_size=world_size)
-
-    def _setup_training(self):
+    def setup(self):
+        """Initializes the model."""
         logger.debug("Creating model")
         self.model = self.model_creator(self.config)
         if torch.cuda.is_available():
-            self.model = torch.nn.parallel.DistributedDataParallel(
-                self.model.cuda())
-        else:
-            self.model = torch.nn.parallel.DistributedDataParallelCPU(
-                self.model)
+            self.model = self.model.cuda()

         logger.debug("Creating optimizer")
         self.criterion, self.optimizer = self.optimizer_creator(
             self.model, self.config)

         if torch.cuda.is_available():
             self.criterion = self.criterion.cuda()

         logger.debug("Creating dataset")
         self.training_set, self.validation_set = self.data_creator(self.config)

-        # TODO: make num_workers configurable
-        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
-            self.training_set)
         self.train_loader = torch.utils.data.DataLoader(
             self.training_set,
             batch_size=self.batch_size,
-            shuffle=(self.train_sampler is None),
+            shuffle=True,
             num_workers=2,
-            pin_memory=False,
-            sampler=self.train_sampler)
+            pin_memory=False)

-        self.validation_sampler = (
-            torch.utils.data.distributed.DistributedSampler(
-                self.validation_set))
         self.validation_loader = torch.utils.data.DataLoader(
             self.validation_set,
             batch_size=self.batch_size,
-            shuffle=(self.validation_sampler is None),
+            shuffle=True,
             num_workers=2,
-            pin_memory=False,
-            sampler=self.validation_sampler)
+            pin_memory=False)

     def get_node_ip(self):
-        """Returns the IP address of the current node"""
+        """Returns the IP address of the current node."""
         return ray.services.get_node_ip_address()

-    def step(self):
-        """Runs a training epoch and updates the model parameters"""
-        logger.debug("Starting step")
-        self.train_sampler.set_epoch(self.epoch)
+    def find_free_port(self):
+        """Finds a free port on the current node."""
+        return utils.find_free_port()

+    def step(self):
+        """Runs a training epoch and updates the model parameters."""
         logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
         with self._timers["training"]:
             train_stats = utils.train(self.train_loader, self.model,

@@ -144,7 +99,7 @@ class PyTorchRunner(object):
         return train_stats

     def validate(self):
-        """Evaluates the model on the validation data set"""
+        """Evaluates the model on the validation data set."""
         with self._timers["validation"]:
             validation_stats = utils.validate(self.validation_loader,
                                               self.model, self.criterion)

@@ -153,7 +108,7 @@ class PyTorchRunner(object):
         return validation_stats

     def stats(self):
-        """Returns a dictionary of statistics collected"""
+        """Returns a dictionary of statistics collected."""
         stats = {"epoch": self.epoch}
         for k, t in self._timers.items():
             stats[k + "_time_mean"] = t.mean

@@ -162,7 +117,7 @@ class PyTorchRunner(object):
         return stats

     def get_state(self):
-        """Returns the state of the runner"""
+        """Returns the state of the runner."""
         return {
             "epoch": self.epoch,
             "model": self.model.state_dict(),

@@ -171,12 +126,20 @@ class PyTorchRunner(object):
         }

     def set_state(self, state):
-        """Sets the state of the model"""
+        """Sets the state of the model."""
         # TODO: restore timer stats
         self.model.load_state_dict(state["model"])
         self.optimizer.load_state_dict(state["optimizer"])
         self.epoch = state["stats"]["epoch"]

     def shutdown(self):
-        """Attempts to shut down the worker"""
-        dist.destroy_process_group()
+        """Attempts to shut down the worker."""
+        del self.validation_loader
+        del self.validation_set
+        del self.train_loader
+        del self.training_set
+        del self.criterion
+        del self.optimizer
+        del self.model
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
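For contrast with the distributed runner above, this is roughly what the trainer does internally on the new num_replicas == 1 path: it wraps PyTorchRunner in a Ray actor and calls setup() with no URL, rank, or world size. A simplified sketch, assuming ray.init() has been called and model_creator, data_creator, and optimizer_creator are defined (e.g. the test utilities):

import ray
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner

RemoteRunner = ray.remote(num_cpus=1)(PyTorchRunner)
worker = RemoteRunner.remote(model_creator, data_creator, optimizer_creator,
                             {}, 16)          # config={}, batch_size=16
ray.get(worker.setup.remote())                # no url / world_rank / world_size
train_stats = ray.get(worker.step.remote())   # one epoch in a single process
ray.get(worker.shutdown.remote())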
pytorch_trainer.py:

@@ -3,13 +3,15 @@ from __future__ import division
 from __future__ import print_function

 import numpy as np
-import sys
 import torch
+import torch.distributed as dist
 import logging

 import ray

 from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
+from ray.experimental.sgd.pytorch.distributed_pytorch_runner import (
+    DistributedPyTorchRunner)
 from ray.experimental.sgd.pytorch import utils

 logger = logging.getLogger(__name__)

@@ -51,10 +53,11 @@ class PyTorchTrainer(object):
         """
         # TODO: add support for mixed precision
         # TODO: add support for callbacks
-        if sys.platform == "darwin":
-            raise Exception(
-                ("Distributed PyTorch is not supported on macOS. For more "
-                 "information, see "
+        if num_replicas > 1 and not dist.is_available():
+            raise ValueError(
+                ("Distributed PyTorch is not supported on macOS. "
+                 "To run without distributed PyTorch, set 'num_replicas=1'. "
+                 "For more information, see "
                  "https://github.com/pytorch/examples/issues/467."))

         self.model_creator = model_creator

@@ -68,40 +71,55 @@ class PyTorchTrainer(object):
         if backend == "auto":
             backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo"

-        Runner = ray.remote(
-            num_cpus=resources_per_replica.num_cpus,
-            num_gpus=resources_per_replica.num_gpus,
-            resources=resources_per_replica.resources)(PyTorchRunner)
-
-        batch_size_per_replica = batch_size // num_replicas
-        if batch_size % num_replicas > 0:
-            new_batch_size = batch_size_per_replica * num_replicas
-            logger.warn(
-                ("Changing batch size from {old_batch_size} to "
-                 "{new_batch_size} to evenly distribute batches across "
-                 "{num_replicas} replicas.").format(
-                     old_batch_size=batch_size,
-                     new_batch_size=new_batch_size,
-                     num_replicas=num_replicas))
-
-        self.workers = [
-            Runner.remote(model_creator, data_creator, optimizer_creator,
-                          self.config, batch_size_per_replica, backend)
-            for i in range(num_replicas)
-        ]
-
-        ip = ray.get(self.workers[0].get_node_ip.remote())
-        port = utils.find_free_port()
-        address = "tcp://{ip}:{port}".format(ip=ip, port=port)
-
-        # Get setup tasks in order to throw errors on failure
-        ray.get([
-            worker.setup.remote(address, i, len(self.workers))
-            for i, worker in enumerate(self.workers)
-        ])
+        if num_replicas == 1:
+            # Generate actor class
+            Runner = ray.remote(
+                num_cpus=resources_per_replica.num_cpus,
+                num_gpus=resources_per_replica.num_gpus,
+                resources=resources_per_replica.resources)(PyTorchRunner)
+            # Start workers
+            self.workers = [
+                Runner.remote(model_creator, data_creator, optimizer_creator,
+                              self.config, batch_size)
+            ]
+            # Get setup tasks in order to throw errors on failure
+            ray.get(self.workers[0].setup.remote())
+        else:
+            # Generate actor class
+            Runner = ray.remote(
+                num_cpus=resources_per_replica.num_cpus,
+                num_gpus=resources_per_replica.num_gpus,
+                resources=resources_per_replica.resources)(
+                    DistributedPyTorchRunner)
+            # Compute batch size per replica
+            batch_size_per_replica = batch_size // num_replicas
+            if batch_size % num_replicas > 0:
+                new_batch_size = batch_size_per_replica * num_replicas
+                logger.warn(
+                    ("Changing batch size from {old_batch_size} to "
+                     "{new_batch_size} to evenly distribute batches across "
+                     "{num_replicas} replicas.").format(
+                         old_batch_size=batch_size,
+                         new_batch_size=new_batch_size,
+                         num_replicas=num_replicas))
+            # Start workers
+            self.workers = [
+                Runner.remote(model_creator, data_creator, optimizer_creator,
+                              self.config, batch_size_per_replica, backend)
+                for i in range(num_replicas)
+            ]
+            # Compute URL for initializing distributed PyTorch
+            ip = ray.get(self.workers[0].get_node_ip.remote())
+            port = ray.get(self.workers[0].find_free_port.remote())
+            address = "tcp://{ip}:{port}".format(ip=ip, port=port)
+            # Get setup tasks in order to throw errors on failure
+            ray.get([
+                worker.setup.remote(address, i, len(self.workers))
+                for i, worker in enumerate(self.workers)
+            ])

     def train(self):
-        """Runs a training epoch"""
+        """Runs a training epoch."""
         with self.optimizer_timer:
             worker_stats = ray.get([w.step.remote() for w in self.workers])

@@ -111,7 +129,7 @@ class PyTorchTrainer(object):
         return train_stats

     def validate(self):
-        """Evaluates the model on the validation data set"""
+        """Evaluates the model on the validation data set."""
         worker_stats = ray.get([w.validate.remote() for w in self.workers])
         validation_stats = worker_stats[0].copy()
         validation_stats["validation_loss"] = np.mean(

@@ -119,32 +137,25 @@ class PyTorchTrainer(object):
         return validation_stats

     def get_model(self):
-        """Returns the learned model"""
+        """Returns the learned model."""
         model = self.model_creator(self.config)
         state = ray.get(self.workers[0].get_state.remote())
-
-        # Remove module. prefix added by distributed pytorch
-        state_dict = {
-            k.replace("module.", ""): v
-            for k, v in state["model"].items()
-        }
-
-        model.load_state_dict(state_dict)
+        model.load_state_dict(state["model"])
         return model

     def save(self, ckpt):
-        """Saves the model at the provided checkpoint"""
+        """Saves the model at the provided checkpoint."""
         state = ray.get(self.workers[0].get_state.remote())
         torch.save(state, ckpt)

     def restore(self, ckpt):
-        """Restores the model from the provided checkpoint"""
+        """Restores the model from the provided checkpoint."""
         state = torch.load(ckpt)
         state_id = ray.put(state)
         ray.get([worker.set_state.remote(state_id) for worker in self.workers])

     def shutdown(self):
-        """Shuts down workers and releases resources"""
+        """Shuts down workers and releases resources."""
         for worker in self.workers:
             worker.shutdown.remote()
             worker.__ray_terminate__.remote()
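The save/restore/get_model path above round-trips worker state through an ordinary torch checkpoint: save() serializes worker 0's state with torch.save(), restore() puts the loaded state into the object store once and applies it to every worker, and get_model() now loads state["model"] directly because the keys are already unprefixed. A hypothetical sketch mirroring the tests further down, assuming a trainer built as in the earlier example:

import os
import tempfile

ckpt = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
trainer.save(ckpt)        # state from worker 0: epoch, model, optimizer, stats

trainer2 = PyTorchTrainer(
    model_creator,
    data_creator,
    optimizer_creator,
    num_replicas=1,
    resources_per_replica=Resources(num_cpus=1))
trainer2.restore(ckpt)    # state is ray.put() once, then set on every worker
model = trainer2.get_model()   # plain nn.Module with the trained weights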
utils.py:

@@ -196,7 +196,7 @@ def find_free_port():


 class AverageMeter(object):
-    """Computes and stores the average and current value"""
+    """Computes and stores the average and current value."""

     def __init__(self):
         self.reset()
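The hunk above only shows unchanged context around utils.find_free_port(); its body is not part of this diff. For readers following the port-selection change, a typical implementation of such a helper (an assumption, not necessarily Ray's exact code) binds to port 0 and lets the OS choose:

import socket
from contextlib import closing

def find_free_port():
    # Bind to port 0 so the OS assigns an unused port, then report it.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]

The "Fix bug in choosing port" item from the commit message appears to be the pytorch_trainer.py change above: the free port is now found on the node where worker 0 actually runs (self.workers[0].find_free_port.remote()) rather than on the driver.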
Test file for the PyTorch trainer (file name not shown in this view):

@@ -4,9 +4,9 @@ from __future__ import print_function

 import os
 import pytest
-import sys
 import tempfile
 import torch
+import torch.distributed as dist

 from ray.tests.conftest import ray_start_2_cpus  # noqa: F401
 from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources

@@ -15,14 +15,14 @@ from ray.experimental.sgd.tests.pytorch_utils import (
     model_creator, optimizer_creator, data_creator)


-@pytest.mark.skipif(  # noqa: F811
-    sys.platform == "darwin", reason="Doesn't work on macOS.")
-def test_train(ray_start_2_cpus):  # noqa: F811
+@pytest.mark.parametrize(  # noqa: F811
+    "num_replicas", [1, 2] if dist.is_available() else [1])
+def test_train(ray_start_2_cpus, num_replicas):  # noqa: F811
     trainer = PyTorchTrainer(
         model_creator,
         data_creator,
         optimizer_creator,
-        num_replicas=2,
+        num_replicas=num_replicas,
         resources_per_replica=Resources(num_cpus=1))
     train_loss1 = trainer.train()["train_loss"]
     validation_loss1 = trainer.validate()["validation_loss"]

@@ -37,14 +37,14 @@ def test_train(ray_start_2_cpus):  # noqa: F811
     assert validation_loss2 <= validation_loss1


-@pytest.mark.skipif(  # noqa: F811
-    sys.platform == "darwin", reason="Doesn't work on macOS.")
-def test_save_and_restore(ray_start_2_cpus):  # noqa: F811
+@pytest.mark.parametrize(  # noqa: F811
+    "num_replicas", [1, 2] if dist.is_available() else [1])
+def test_save_and_restore(ray_start_2_cpus, num_replicas):  # noqa: F811
     trainer1 = PyTorchTrainer(
         model_creator,
         data_creator,
         optimizer_creator,
-        num_replicas=2,
+        num_replicas=num_replicas,
         resources_per_replica=Resources(num_cpus=1))
     trainer1.train()

@@ -59,7 +59,7 @@ def test_save_and_restore(ray_start_2_cpus):  # noqa: F811
         model_creator,
         data_creator,
         optimizer_creator,
-        num_replicas=2,
+        num_replicas=num_replicas,
         resources_per_replica=Resources(num_cpus=1))
     trainer2.restore(filename)