[sgd] Semantic Segmentation Example (#7825)

* better_example

* test

* improve some usability things

* submit

* fix

* making a segmentation example

* segmentation_example

* segmentation

* device

* flake

* Update python/ray/util/sgd/torch/training_operator.py

* uti

* finished_example

* block

* format

* locationg

* fix

* ok

* revert

* segmentation

* lint_and_test

* address_comments
Richard Liaw 2020-04-10 20:35:45 -07:00 committed by GitHub
parent 0b4e09da76
commit dd63178e91
13 changed files with 783 additions and 71 deletions

View file

@ -610,8 +610,8 @@ You can see more details in the `benchmarking README <https://github.com/ray-pro
DISCLAIMER: RaySGD does not provide any custom communication primitives. If you see any performance issues, you may need to file them on the PyTorch github repository.
Debugging
---------
Debugging/Tips
--------------
Here are some simple tips for debugging the TorchTrainer.
@ -657,6 +657,32 @@ Try using a profiler. Either use:
or use `Python profiling <https://docs.python.org/3/library/debug.html>`_.
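For example, a minimal sketch of profiling a single training step with the standard library (``trainer`` below stands in for an already constructed ``TorchTrainer``):

.. code-block:: python

    import cProfile

    pr = cProfile.Profile()
    pr.enable()
    trainer.train(num_steps=1)  # profile one training pass
    pr.disable()
    pr.print_stats(sort="cumtime")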
**My creator functions download data, and I don't want multiple processes downloading to the same path at once.**
Use ``filelock`` within the creator functions to create locks for critical regions. For example:
.. code-block:: python

    import os
    import tempfile

    from filelock import FileLock

    def create_dataset(config):
        dataset_path = config["dataset_path"]
        # Critical region: the first process to acquire the lock downloads
        # the data, which can take a while. Other processes block at the
        # ``with`` statement and pass through quickly once the data is on disk.
        with FileLock(os.path.join(tempfile.gettempdir(), "download_data.lock")):
            if not os.path.exists(dataset_path):
                download_data(dataset_path)
        # load_data is assumed to safely support concurrent reads.
        data = load_data(dataset_path)
        return DataLoader(data)
**I get a 'socket timeout' error during training.**
Try increasing the length of the NCCL timeout. The current timeout is 10 seconds.
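RaySGD configures this timeout for you (it passes ``NCCL_TIMEOUT_S`` from ``ray.util.sgd.torch.constants`` when initializing the process group). For reference only, the sketch below shows the underlying PyTorch ``timeout`` argument with placeholder values; it is not something you call directly when using ``TorchTrainer``:

.. code-block:: python

    from datetime import timedelta

    import torch.distributed as dist

    # Illustration of the raw PyTorch setting that RaySGD manages internally.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:23456",  # placeholder address
        rank=0,
        world_size=1,
        timeout=timedelta(seconds=1800))  # a longer timeout than the default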
@ -688,6 +714,9 @@ to contribute an example, feel free to create a `pull request here <https://gith
- `TorchTrainer and RayTune example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/tune_example.py>`__:
Simple example of hyperparameter tuning with Ray's TorchTrainer.
- `Semantic Segmentation example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py>`__:
Fine-tuning a ResNet50 model on VOC with Batch Norm.
- `CIFAR10 example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py>`__:
Training a ResNet18 model on CIFAR10.

View file

@ -47,6 +47,19 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
trainer.shutdown()
def test_resize(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=1)
trainer.train(num_steps=1)
trainer.max_replicas = 2
results = trainer.train(num_steps=1, reduce_results=False)
assert len(results) == 2
def test_dead_trainer(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,

View file

@ -1,10 +1,9 @@
from datetime import timedelta
import collections
import logging
import io
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
@ -12,7 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
import ray
from ray.util.sgd.torch.torch_runner import TorchRunner, _remind_gpu_usage
from ray.util.sgd.torch.torch_runner import TorchRunner
logger = logging.getLogger(__name__)
@ -44,25 +43,23 @@ class DistributedTorchRunner(TorchRunner):
self.backend = backend
self.wrap_ddp = wrap_ddp
self.add_dist_sampler = add_dist_sampler
self.world_rank = None
def setup(self, url, world_rank, world_size):
"""Connects to the distributed PyTorch backend and initializes the model.
def setup(self):
raise RuntimeError("Need to call setup commands separately.")
def setup_process_group(self, url, world_rank, world_size):
"""Connects the distributed PyTorch backend.
Args:
url (str): the URL used to connect to distributed PyTorch.
world_rank (int): the index of the runner.
world_size (int): the total number of runners.
"""
_remind_gpu_usage(self.use_gpu, is_chief=world_rank == 0)
self._setup_distributed_pytorch(url, world_rank, world_size)
self._setup_training()
def _setup_distributed_pytorch(self, url, world_rank, world_size):
self.world_rank = world_rank
logger.debug("Connecting to {} world_rank: {} world_size: {}".format(
url, world_rank, world_size))
logger.debug("using {}".format(self.backend))
if self.backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
logger.debug(
"Setting NCCL_BLOCKING_WAIT for detecting node failure. "
@ -77,48 +74,26 @@ class DistributedTorchRunner(TorchRunner):
world_size=world_size,
timeout=timeout)
self.device_ids = None
def setup_ddp_and_operator(self):
"""Runs distributed coordination components.
This helps avoid timeouts due to creator functions (perhaps
downloading data or models).
"""
device_ids = None
if self.use_gpu and torch.cuda.is_available():
# https://github.com/allenai/allennlp/issues/1090
self.set_cuda_device_id()
device_ids = self.get_device_ids()
def set_cuda_device_id(self):
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
self.device_ids = [0]
def _setup_training(self):
logger.debug("Loading data.")
self._initialize_dataloaders()
logger.debug("Creating model")
self.models = self.model_creator(self.config)
if not isinstance(self.models, collections.Iterable):
self.models = [self.models]
assert all(isinstance(model, nn.Module) for model in self.models), (
"All models must be PyTorch models: {}.".format(self.models))
if self.use_gpu and torch.cuda.is_available():
self.models = [model.cuda() for model in self.models]
logger.debug("Creating optimizer.")
self.optimizers = self.optimizer_creator(self.given_models,
self.config)
if not isinstance(self.optimizers, collections.Iterable):
self.optimizers = [self.optimizers]
self._create_schedulers_if_available()
self._try_setup_apex()
self._create_loss()
# Wrap dataloaders
self._wrap_dataloaders()
training_models = self.models
if self.wrap_ddp:
# This needs to happen after apex
training_models = [
DistributedDataParallel(model, device_ids=self.device_ids)
DistributedDataParallel(model, device_ids=device_ids)
for model in self.models
]
self.training_operator = self.training_operator_cls(
self.config,
models=training_models,
@ -128,14 +103,33 @@ class DistributedTorchRunner(TorchRunner):
validation_loader=self.validation_loader,
world_rank=self.world_rank,
schedulers=self.schedulers,
device_ids=self.device_ids,
device_ids=device_ids,
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm)
def _initialize_dataloaders(self):
super(DistributedTorchRunner, self)._initialize_dataloaders()
def get_device_ids(self):
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
return [0]
def load_state_stream(self, byte_obj):
"""Loads a bytes object the training state dict.
This is needed because we don't want to deserialize the tensor
onto the same device (which is from the driver process). We want to
map it onto the actor's specific device.
From: github.com/pytorch/pytorch/issues/10622#issuecomment-474733769
"""
_buffer = io.BytesIO(byte_obj)
to_gpu = self.use_gpu and torch.cuda.is_available()
state_dict = torch.load(
_buffer,
map_location=("cpu" if not to_gpu else
lambda storage, loc: storage.cuda()))
return self.load_state_dict(state_dict)
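For context, the matching serialization on the sending side can be sketched as below; it mirrors how a state dict becomes the ``byte_obj`` consumed above (the helper name here is illustrative):

```python
import io

import torch

def state_stream(state_dict):
    # Serialize a state dict into an in-memory bytes object so it can be
    # shipped across processes and re-mapped onto the receiver's device.
    _buffer = io.BytesIO()
    torch.save(state_dict, _buffer)
    return _buffer.getvalue()
```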
def _wrap_dataloaders(self):
def with_sampler(loader):
# Automatically set the DistributedSampler
data_loader_args = {
@ -215,11 +209,12 @@ class LocalDistributedRunner(DistributedTorchRunner):
# TODO: we should make sure this NEVER dies.
self.local_device = "0"
global _dummy_actor
if not self.is_actor() and _dummy_actor is None:
_dummy_actor = ray.remote(
num_cpus=num_cpus,
num_gpus=num_gpus,
resources={"node:" + ip: 0.1})(_DummyActor).remote()
if not self.is_actor():
if _dummy_actor is None:
_dummy_actor = ray.remote(
num_cpus=num_cpus,
num_gpus=num_gpus,
resources={"node:" + ip: 0.1})(_DummyActor).remote()
self.local_device = ray.get(_dummy_actor.cuda_devices.remote())
@ -235,16 +230,18 @@ class LocalDistributedRunner(DistributedTorchRunner):
# interaction.
_init_cuda_context()
os.environ["CUDA_VISIBLE_DEVICES"] = self.local_device
if self.local_device:
try:
torch.cuda.set_device(int(self.local_device))
except RuntimeError:
logger.error("This happens if cuda is not initialized.")
raise
super(LocalDistributedRunner, self).__init__(*args, **kwargs)
def set_cuda_device_id(self):
self.device_ids = [int(self.local_device)]
def get_device_ids(self):
return [int(self.local_device)]
def shutdown(self, cleanup=True):
super(LocalDistributedRunner, self).shutdown()

View file

@ -73,7 +73,7 @@ if __name__ == "__main__":
"--address",
required=False,
type=str,
help="the address to use for Redis")
help="the address to use for connecting to the Ray cluster")
parser.add_argument(
"--num-workers",
"-n",

View file

@ -0,0 +1,34 @@
# Semantic segmentation reference -> RaySGD
Original scripts are taken from: https://github.com/pytorch/vision/tree/master/references/segmentation.
On a single node, you can leverage Distributed Data Parallelism (DDP) simply by passing the `-n` parameter, which parallelizes training across `n` GPUs. The commands below use the standard hyperparameters listed in the original repository. First, install the extra dependencies:
```bash
pip install tqdm pycocotools
```
## fcn_resnet101
```bash
python train_segmentation.py -n 4 --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss
```
## deeplabv3_resnet101
```bash
python train_segmentation.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
```
## Scaling up
This example can be executed on AWS by running:
```bash
ray submit cluster.yaml train_segmentation.py --start --args="--lr 0.02 ..."
```
To leverage GPUs across multiple nodes, be sure to pass the `--address` parameter so the script connects to the running Ray cluster:
```bash
ray submit cluster.yaml train_segmentation.py --start --args="--address='auto' --lr 0.02 ..."
```

View file

@ -0,0 +1,116 @@
# flake8: noqa
import copy
import torch
import torch.utils.data
import torchvision
from PIL import Image
import os
from pycocotools import mask as coco_mask
from ray.util.sgd.torch.examples.segmentation.transforms import Compose
class FilterAndRemapCocoCategories(object):
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap
def __call__(self, image, anno):
anno = [obj for obj in anno if obj["category_id"] in self.categories]
if not self.remap:
return image, anno
anno = copy.deepcopy(anno)
for obj in anno:
obj["category_id"] = self.categories.index(obj["category_id"])
return image, anno
def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks
class ConvertCocoPolysToMask(object):
def __call__(self, image, anno):
w, h = image.size
segmentations = [obj["segmentation"] for obj in anno]
cats = [obj["category_id"] for obj in anno]
if segmentations:
masks = convert_coco_poly_to_mask(segmentations, h, w)
cats = torch.as_tensor(cats, dtype=masks.dtype)
# merge all instance masks into a single segmentation map
# with its corresponding categories
target, _ = (masks * cats[:, None, None]).max(dim=0)
# discard overlapping instances
target[masks.sum(0) > 1] = 255
else:
target = torch.zeros((h, w), dtype=torch.uint8)
target = Image.fromarray(target.numpy())
return image, target
def _coco_remove_images_without_annotations(dataset, cat_list=None):
def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
# if more than 1k pixels occupied in the image
return sum(obj["area"] for obj in anno) > 1000
assert isinstance(dataset, torchvision.datasets.CocoDetection)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.coco.loadAnns(ann_ids)
if cat_list:
anno = [obj for obj in anno if obj["category_id"] in cat_list]
if _has_valid_annotation(anno):
ids.append(ds_idx)
dataset = torch.utils.data.Subset(dataset, ids)
return dataset
def get_coco(root, image_set, transforms):
PATHS = {
"train": ("train2017",
os.path.join("annotations", "instances_train2017.json")),
"val": ("val2017", os.path.join("annotations",
"instances_val2017.json")),
# "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
}
CAT_LIST = [
0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7,
72
]
transforms = Compose([
FilterAndRemapCocoCategories(CAT_LIST, remap=True),
ConvertCocoPolysToMask(), transforms
])
img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)
dataset = torchvision.datasets.CocoDetection(
img_folder, ann_file, transforms=transforms)
if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
return dataset

View file

@ -0,0 +1,61 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: sgd-coco-pytorch

# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers, which defaults to 0.
min_workers: 1
initial_workers: 1
max_workers: 1
target_utilization_fraction: 0.9
# Cloud-provider specific configuration.
provider:
    type: aws
    region: us-east-1
    availability_zone: us-east-1c

# How Ray will authenticate with newly launched nodes.
auth:
    ssh_user: ubuntu

head_node:
    InstanceType: p3.8xlarge
    ImageId: ami-0698bcaf8bd9ef56d
    InstanceMarketOptions:
        MarketType: spot
    BlockDeviceMappings:
        - DeviceName: /dev/sda1
          Ebs:
              VolumeSize: 300

worker_nodes:
    InstanceType: p3.8xlarge
    ImageId: ami-0698bcaf8bd9ef56d
    InstanceMarketOptions:
        MarketType: spot
    BlockDeviceMappings:
        - DeviceName: /dev/sda1
          Ebs:
              VolumeSize: 300
setup_commands:
# This replaces the standard anaconda Ray installation
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
# Uncomment this and the filemount to update the Ray installation with your local Ray code
# - rm -rf ./anaconda3/lib/python3.6/site-packages/ray/util/sgd/
# - cp -rf ~/sgd ./anaconda3/lib/python3.6/site-packages/ray/util/
- pip install -q tqdm
# Installing this without -U to make sure we don't replace the existing Ray installation
- pip install ray[rllib]
- pip install -U ipdb torch torchvision pycocotools
# Install Apex if needed.
- git clone https://github.com/NVIDIA/apex; cd apex && pip install -v --no-cache-dir ./ || true
file_mounts: {
# # This should point to ray/python/ray/util/sgd.
# ~/sgd: ../../../../sgd,
}

View file

@ -0,0 +1,259 @@
import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
import ray
from ray.util.sgd.torch.examples.segmentation.coco_utils import get_coco
import ray.util.sgd.torch.examples.segmentation.transforms as T
import ray.util.sgd.torch.examples.segmentation.utils as utils
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd import TorchTrainer
try:
from apex import amp
except ImportError:
amp = None
def get_dataset(name,
image_set,
transform,
num_classes_only=False,
download="auto"):
def sbd(*args, **kwargs):
return torchvision.datasets.SBDataset(
*args, mode="segmentation", **kwargs)
paths = {
"voc": (os.path.expanduser("~/datasets01/VOC/060817/"),
torchvision.datasets.VOCSegmentation, 21),
"voc_aug": (os.path.expanduser("~/datasets01/SBDD/072318/"), sbd, 21),
"coco": (os.path.expanduser("~/datasets01/COCO/022719/"), get_coco, 21)
}
p, ds_fn, num_classes = paths[name]
if num_classes_only:
return None, num_classes
if download == "auto" and os.path.exists(p):
download = False
try:
ds = ds_fn(
p, download=download, image_set=image_set, transforms=transform)
except RuntimeError:
print("data loading failed. Retrying this.")
ds = ds_fn(p, download=True, image_set=image_set, transforms=transform)
return ds, num_classes
def data_creator(config):
# Within a machine, this code runs synchronously.
dataset, num_classes = get_dataset(
args.dataset, "train", get_transform(train=True))
config["num_classes"] = num_classes
dataset_test, _ = get_dataset(
args.dataset, "val", get_transform(train=False))
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.data_workers,
collate_fn=utils.collate_fn,
drop_last=True)
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=1,
num_workers=args.data_workers,
collate_fn=utils.collate_fn)
return data_loader, data_loader_test
def get_transform(train):
base_size = 520
crop_size = 480
min_size = int((0.5 if train else 1.0) * base_size)
max_size = int((2.0 if train else 1.0) * base_size)
transforms = []
transforms.append(T.RandomResize(min_size, max_size))
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
transforms.append(T.RandomCrop(crop_size))
transforms.append(T.ToTensor())
transforms.append(
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
return T.Compose(transforms)
def criterion(inputs, target):
losses = {}
for name, x in inputs.items():
losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
if len(losses) == 1:
return losses["out"]
return losses["out"] + 0.5 * losses["aux"]
class SegOperator(TrainingOperator):
def train_batch(self, batch, batch_info):
image, target = batch
image, target = image.to(self.device), target.to(self.device)
output = self.model(image)
loss = criterion(output, target)
self.optimizer.zero_grad()
if self.use_fp16 and amp:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
self.optimizer.step()
lr = self.optimizer.param_groups[0]["lr"]
return {"loss": loss.item(), "lr": lr, "num_samples": len(batch)}
def validate(self, data_loader, info=None):
self.model.eval()
confmat = utils.ConfusionMatrix(self.config["num_classes"])
with torch.no_grad():
for image, target in data_loader:
image, target = image.to(self.device), target.to(self.device)
output = self.model(image)
output = output["out"]
confmat.update(target.flatten(), output.argmax(1).flatten())
confmat.reduce_from_all_processes()
return confmat
def optimizer_creator(model, config):
args = config["args"]
params_to_optimize = [
{
"params": [
p for p in model.backbone.parameters() if p.requires_grad
]
},
{
"params": [
p for p in model.classifier.parameters() if p.requires_grad
]
},
]
if args.aux_loss:
params = [
p for p in model.aux_classifier.parameters() if p.requires_grad
]
params_to_optimize.append({"params": params, "lr": args.lr * 10})
return torch.optim.SGD(
params_to_optimize,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
def model_creator(config):
args = config["args"]
model = torchvision.models.segmentation.__dict__[args.model](
num_classes=config["num_classes"],
aux_loss=args.aux_loss,
pretrained=args.pretrained)
if config["num_workers"] > 1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return model
def main(args):
os.makedirs(args.output_dir, exist_ok=True)
print(args)
start_time = time.time()
config = {"args": args, "num_workers": args.num_workers}
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
training_operator_cls=SegOperator,
use_tqdm=True,
use_fp16=True,
num_workers=config["num_workers"],
config=config,
use_gpu=torch.cuda.is_available())
for epoch in range(args.epochs):
trainer.train()
confmat = trainer.validate(reduce_results=False)[0]
print(confmat)
state_dict = trainer.state_dict()
state_dict.update(epoch=epoch, args=args)
torch.save(state_dict,
os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str))
def parse_args():
import argparse
parser = argparse.ArgumentParser(
description="PyTorch Segmentation Training with RaySGD")
parser.add_argument(
"--address",
required=False,
default=None,
help="the address to use for connecting to a Ray cluster.")
parser.add_argument("--dataset", default="voc", help="dataset")
parser.add_argument("--model", default="fcn_resnet101", help="model")
parser.add_argument(
"--aux-loss", action="store_true", help="auxiliar loss")
parser.add_argument("--device", default="cuda", help="device")
parser.add_argument("-b", "--batch-size", default=8, type=int)
parser.add_argument(
"-n", "--num-workers", default=1, type=int, help="GPU parallelism")
parser.add_argument(
"--epochs",
default=30,
type=int,
metavar="N",
help="number of total epochs to run")
parser.add_argument(
"--data-workers",
default=16,
type=int,
metavar="N",
help="number of data loading workers (default: 16)")
parser.add_argument(
"--lr", default=0.01, type=float, help="initial learning rate")
parser.add_argument(
"--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
"--wd",
"--weight-decay",
default=1e-4,
type=float,
metavar="W",
help="weight decay (default: 1e-4)",
dest="weight_decay")
parser.add_argument("--output-dir", default=".", help="path where to save")
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
ray.init(address=args.address)
main(args)

View file

@ -0,0 +1,93 @@
# flake8: noqa
import numpy as np
from PIL import Image
import random
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
def pad_if_smaller(img, size, fill=0):
min_size = min(img.size)
if min_size < size:
ow, oh = img.size
padh = size - oh if oh < size else 0
padw = size - ow if ow < size else 0
img = F.pad(img, (0, 0, padw, padh), fill=fill)
return img
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
class RandomResize(object):
def __init__(self, min_size, max_size=None):
self.min_size = min_size
if max_size is None:
max_size = min_size
self.max_size = max_size
def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = F.resize(image, size)
target = F.resize(target, size, interpolation=Image.NEAREST)
return image, target
class RandomHorizontalFlip(object):
def __init__(self, flip_prob):
self.flip_prob = flip_prob
def __call__(self, image, target):
if random.random() < self.flip_prob:
image = F.hflip(image)
target = F.hflip(target)
return image, target
class RandomCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, image, target):
image = pad_if_smaller(image, self.size)
target = pad_if_smaller(target, self.size, fill=255)
crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
image = F.crop(image, *crop_params)
target = F.crop(target, *crop_params)
return image, target
class CenterCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, image, target):
image = F.center_crop(image, self.size)
target = F.center_crop(target, self.size)
return image, target
class ToTensor(object):
def __call__(self, image, target):
image = F.to_tensor(image)
target = torch.as_tensor(np.asarray(target), dtype=torch.int64)
return image, target
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, image, target):
image = F.normalize(image, mean=self.mean, std=self.std)
return image, target

View file

@ -0,0 +1,70 @@
# flake8: noqa
from collections import defaultdict, deque
import datetime
import math
import time
import torch
import torch.distributed as dist
import errno
import os
class ConfusionMatrix(object):
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None
def update(self, a, b):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.no_grad():
k = (a >= 0) & (a < n)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
def reset(self):
self.mat.zero_()
def compute(self):
h = self.mat.float()
acc_global = torch.diag(h).sum() / h.sum()
acc = torch.diag(h) / h.sum(1)
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
return acc_global, acc, iu
def reduce_from_all_processes(self):
if not torch.distributed.is_available():
return
if not torch.distributed.is_initialized():
return
torch.distributed.barrier()
torch.distributed.all_reduce(self.mat)
def __str__(self):
acc_global, acc, iu = self.compute()
return ('global correct: {:.1f}\n'
'average row correct: {}\n'
'IoU: {}\n'
'mean IoU: {:.1f}').format(
acc_global.item() * 100,
['{:.1f}'.format(i) for i in (acc * 100).tolist()],
['{:.1f}'.format(i) for i in (iu * 100).tolist()],
iu.mean().item() * 100)
def cat_list(images, fill_value=0):
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
batch_shape = (len(images), ) + max_size
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
for img, pad_img in zip(images, batched_imgs):
pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
return batched_imgs
def collate_fn(batch):
images, targets = list(zip(*batch))
batched_imgs = cat_list(images, fill_value=0)
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets

View file

@ -7,6 +7,7 @@ import itertools
import os
import tempfile
import torch
import torch.nn as nn
import ray
from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP, NUM_STEPS
@ -23,14 +24,6 @@ except ImportError:
pass
def _remind_gpu_usage(use_gpu, is_chief):
if not is_chief:
return
if not use_gpu and torch.cuda.is_available():
logger.info("GPUs detected but not using them. Set `use_gpu` to "
"enable GPU usage. ")
class TorchRunner:
"""Manages a PyTorch model for training.
@ -79,6 +72,7 @@ class TorchRunner:
self.schedulers = None
self.train_loader = None
self.validation_loader = None
self.training_operator = None
self.use_gpu = use_gpu
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
@ -149,24 +143,35 @@ class TorchRunner:
self.models, self.optimizers, **self.apex_args)
def setup(self):
"""Initializes the model."""
_remind_gpu_usage(self.use_gpu, is_chief=True)
"""Merges setup_components and setup_operator in one call."""
self.setup_components()
self.setup_operator()
def setup_components(self):
"""Runs the creator functions without any distributed coordination."""
logger.debug("Loading data.")
self._initialize_dataloaders()
logger.debug("Creating model")
self.models = self.model_creator(self.config)
if not isinstance(self.models, collections.Iterable):
self.models = [self.models]
assert all(isinstance(model, nn.Module) for model in self.models), (
"All models must be PyTorch models: {}.".format(self.models))
if self.use_gpu and torch.cuda.is_available():
self.models = [model.cuda() for model in self.models]
logger.debug("Creating optimizer")
logger.debug("Creating optimizer.")
self.optimizers = self.optimizer_creator(self.given_models,
self.config)
if not isinstance(self.optimizers, collections.Iterable):
self.optimizers = [self.optimizers]
self._create_schedulers_if_available()
self._try_setup_apex()
self._create_loss()
def setup_operator(self):
"""Create the training operator."""
self.training_operator = self.training_operator_cls(
self.config,
models=self.models,

View file

@ -30,6 +30,12 @@ def _validate_scheduler_step_freq(scheduler_step_freq):
VALID_SCHEDULER_STEP, scheduler_step_freq))
def _remind_gpu_usage(use_gpu):
if not use_gpu and torch.cuda.is_available():
logger.info("GPUs detected but not using them. Set `use_gpu` to "
"enable GPU usage. ")
class TorchTrainer:
"""Train a PyTorch model using distributed PyTorch.
@ -69,6 +75,14 @@ class TorchTrainer:
for i in range(4):
trainer.train()
The creator functions will execute before distributed coordination and
training are set up. This is so that creator functions that download
large datasets will not trigger any timeouts.
The order of operations for the creator functions is:
``data_creator`` -> ``model_creator`` -> ``optimizer_creator`` ->
``scheduler_creator`` -> ``loss_creator``.
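A minimal sketch of creator functions following this order (the model, data, and hyperparameters below are illustrative stand-ins, not part of this change):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    def data_creator(config):
        # Tiny synthetic regression dataset purely for illustration.
        dataset = TensorDataset(torch.randn(64, 1), torch.randn(64, 1))
        return DataLoader(dataset, batch_size=config.get("batch_size", 8))

    def model_creator(config):
        return nn.Linear(1, 1)

    def optimizer_creator(model, config):
        return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))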
Args:
model_creator (dict -> Model(s)): Constructor function that takes in
@ -213,6 +227,8 @@ class TorchTrainer:
if use_gpu == "auto":
use_gpu = torch.cuda.is_available()
_remind_gpu_usage(use_gpu)
if backend == "auto":
backend = "nccl" if use_gpu else "gloo"
@ -320,13 +336,32 @@ class TorchTrainer:
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
remote_setups = [
worker.setup.remote(address, i + 1, num_workers)
# Runs the creator functions.
remote_component_setup = [
worker.setup_components.remote()
for i, worker in enumerate(self.remote_workers)
]
self.local_worker.setup(address, 0, num_workers)
self.local_worker.setup_components()
# Get setup tasks in order to throw errors on failure
ray.get(remote_setups)
ray.get(remote_component_setup)
# Setup the process group among all workers.
remote_pgroup_setups = [
worker.setup_process_group.remote(address, i + 1, num_workers)
for i, worker in enumerate(self.remote_workers)
]
self.local_worker.setup_process_group(address, 0, num_workers)
# Get setup tasks in order to throw errors on failure
ray.get(remote_pgroup_setups)
# Runs code that requires all creator functions to have run.
remote_operator_setups = [
worker.setup_ddp_and_operator.remote()
for worker in self.remote_workers
]
self.local_worker.setup_ddp_and_operator()
# Get setup tasks in order to throw errors on failure
ray.get(remote_operator_setups)
def train(self,
num_steps=None,