Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[sgd] Semantic Segmentation Example (#7825)
* better_example * test * improve some usability things * submit * fix * making a segmentation example * segmentation_example * segmentation * device * flake * Update python/ray/util/sgd/torch/training_operator.py * uti * finished_example * block * format * locationg * fix * ok * revert * segmentation * lint_and_test * address_comments
Parent: 0b4e09da76
Commit: dd63178e91
13 changed files with 783 additions and 71 deletions
@@ -610,8 +610,8 @@ You can see more details in the `benchmarking README <https://github.com/ray-pro

DISCLAIMER: RaySGD does not provide any custom communication primitives. If you see any performance issues, you may need to file them on the PyTorch GitHub repository.


Debugging
---------
Debugging/Tips
--------------

Here are some simple tips on how to debug the TorchTrainer.


@@ -657,6 +657,32 @@ Try using a profiler. Either use:

or use `Python profiling <https://docs.python.org/3/library/debug.html>`_.

**My creator functions download data, and I don't want multiple processes downloading to the same path at once.**

Use ``filelock`` within the creator functions to create locks for critical regions. For example:

.. code-block:: python

    import os
    import tempfile
    from filelock import FileLock

    def create_dataset(config):
        dataset_path = config["dataset_path"]

        # Critical section: the first process to acquire the lock downloads
        # the data, which can take a while. Other processes block at the
        # ``with`` statement until the download finishes; afterwards this
        # block returns almost immediately.
        with FileLock(os.path.join(tempfile.gettempdir(), "download_data.lock")):
            if not os.path.exists(dataset_path):
                download_data(dataset_path)

        # load_data is assumed to safely support concurrent reads.
        data = load_data(dataset_path)
        return DataLoader(data)


**I get a 'socket timeout' error during training.**

Try increasing the length of the NCCL timeout. The current timeout is 10 seconds.
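For context, this timeout corresponds to the ``timeout`` passed to ``torch.distributed.init_process_group`` (see the ``NCCL_TIMEOUT_S`` constant in the runner diff further below); it is not a ``TorchTrainer`` argument. A minimal sketch at the raw PyTorch level, where the address, ranks, and the 30-minute value are illustrative assumptions:

.. code-block:: python

    import os
    from datetime import timedelta

    import torch.distributed as dist

    # NCCL only honors the timeout when blocking wait is enabled.
    os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")

    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",  # placeholder address
        rank=0,
        world_size=1,
        # Raise this well above the 10 second timeout described above.
        timeout=timedelta(minutes=30))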
@@ -688,6 +714,9 @@ to contribute an example, feel free to create a `pull request here <https://gith

- `TorchTrainer and RayTune example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/tune_example.py>`__:
  Simple example of hyperparameter tuning with Ray's TorchTrainer.

- `Semantic Segmentation example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/segmentation/train_segmentation.py>`__:
  Fine-tuning a ResNet50 model on VOC with Batch Norm.

- `CIFAR10 example <https://github.com/ray-project/ray/blob/master/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py>`__:
  Training a ResNet18 model on CIFAR10.

@@ -47,6 +47,19 @@ def test_single_step(ray_start_2_cpus):  # noqa: F811

    trainer.shutdown()


def test_resize(ray_start_2_cpus):  # noqa: F811
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=1)
    trainer.train(num_steps=1)
    trainer.max_replicas = 2
    results = trainer.train(num_steps=1, reduce_results=False)
    assert len(results) == 2


def test_dead_trainer(ray_start_2_cpus):  # noqa: F811
    trainer = TorchTrainer(
        model_creator=model_creator,
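``test_resize`` above checks elastic resizing: after ``max_replicas`` is raised to 2, the next ``train()`` call with ``reduce_results=False`` returns one stats dict per worker. A small usage sketch of that flag (the loop and ``print`` are illustrative, not part of the test):

```python
# Sketch: with reduce_results=False, TorchTrainer.train() returns a list of
# per-worker stats dicts rather than a single aggregated dict.
per_worker_stats = trainer.train(num_steps=1, reduce_results=False)
for rank, stats in enumerate(per_worker_stats):
    print("worker", rank, "reported", stats)
```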
@@ -1,10 +1,9 @@
from datetime import timedelta
import collections
import logging
import io
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

@@ -12,7 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S

import ray
from ray.util.sgd.torch.torch_runner import TorchRunner, _remind_gpu_usage
from ray.util.sgd.torch.torch_runner import TorchRunner

logger = logging.getLogger(__name__)

@@ -44,25 +43,23 @@ class DistributedTorchRunner(TorchRunner):
        self.backend = backend
        self.wrap_ddp = wrap_ddp
        self.add_dist_sampler = add_dist_sampler
        self.world_rank = None

    def setup(self, url, world_rank, world_size):
        """Connects to the distributed PyTorch backend and initializes the model.
    def setup(self):
        raise RuntimeError("Need to call setup commands separately.")

    def setup_process_group(self, url, world_rank, world_size):
        """Connects the distributed PyTorch backend.

        Args:
            url (str): the URL used to connect to distributed PyTorch.
            world_rank (int): the index of the runner.
            world_size (int): the total number of runners.
        """
        _remind_gpu_usage(self.use_gpu, is_chief=world_rank == 0)
        self._setup_distributed_pytorch(url, world_rank, world_size)
        self._setup_training()

    def _setup_distributed_pytorch(self, url, world_rank, world_size):
        self.world_rank = world_rank
        logger.debug("Connecting to {} world_rank: {} world_size: {}".format(
            url, world_rank, world_size))
        logger.debug("using {}".format(self.backend))

        if self.backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
            logger.debug(
                "Setting NCCL_BLOCKING_WAIT for detecting node failure. "

@@ -77,48 +74,26 @@ class DistributedTorchRunner(TorchRunner):
            world_size=world_size,
            timeout=timeout)

        self.device_ids = None
    def setup_ddp_and_operator(self):
        """Runs distributed coordination components.

        This helps avoid timeouts due to creator functions (perhaps
        downloading data or models).
        """
        device_ids = None
        if self.use_gpu and torch.cuda.is_available():
            # https://github.com/allenai/allennlp/issues/1090
            self.set_cuda_device_id()
            device_ids = self.get_device_ids()

    def set_cuda_device_id(self):
        """Needed for SyncBatchNorm, which needs 1 GPU per process."""
        self.device_ids = [0]

    def _setup_training(self):
        logger.debug("Loading data.")
        self._initialize_dataloaders()
        logger.debug("Creating model")
        self.models = self.model_creator(self.config)
        if not isinstance(self.models, collections.Iterable):
            self.models = [self.models]
        assert all(isinstance(model, nn.Module) for model in self.models), (
            "All models must be PyTorch models: {}.".format(self.models))
        if self.use_gpu and torch.cuda.is_available():
            self.models = [model.cuda() for model in self.models]

        logger.debug("Creating optimizer.")
        self.optimizers = self.optimizer_creator(self.given_models,
                                                 self.config)
        if not isinstance(self.optimizers, collections.Iterable):
            self.optimizers = [self.optimizers]

        self._create_schedulers_if_available()
        self._try_setup_apex()

        self._create_loss()
        # Wrap dataloaders
        self._wrap_dataloaders()

        training_models = self.models

        if self.wrap_ddp:
            # This needs to happen after apex
            training_models = [
                DistributedDataParallel(model, device_ids=self.device_ids)
                DistributedDataParallel(model, device_ids=device_ids)
                for model in self.models
            ]

        self.training_operator = self.training_operator_cls(
            self.config,
            models=training_models,

@@ -128,14 +103,33 @@ class DistributedTorchRunner(TorchRunner):
            validation_loader=self.validation_loader,
            world_rank=self.world_rank,
            schedulers=self.schedulers,
            device_ids=self.device_ids,
            device_ids=device_ids,
            use_gpu=self.use_gpu,
            use_fp16=self.use_fp16,
            use_tqdm=self.use_tqdm)

    def _initialize_dataloaders(self):
        super(DistributedTorchRunner, self)._initialize_dataloaders()
    def get_device_ids(self):
        """Needed for SyncBatchNorm, which needs 1 GPU per process."""
        return [0]

    def load_state_stream(self, byte_obj):
        """Loads a bytes object into the training state dict.

        This is needed because we don't want to deserialize the tensor
        onto the same device (which is from the driver process). We want to
        map it onto the actor's specific device.

        From: github.com/pytorch/pytorch/issues/10622#issuecomment-474733769
        """
        _buffer = io.BytesIO(byte_obj)
        to_gpu = self.use_gpu and torch.cuda.is_available()
        state_dict = torch.load(
            _buffer,
            map_location=("cpu" if not to_gpu else
                          lambda storage, loc: storage.cuda()))
        return self.load_state_dict(state_dict)

    def _wrap_dataloaders(self):
        def with_sampler(loader):
            # Automatically set the DistributedSampler
            data_loader_args = {

@@ -215,11 +209,12 @@ class LocalDistributedRunner(DistributedTorchRunner):
        # TODO: we should make sure this NEVER dies.
        self.local_device = "0"
        global _dummy_actor
        if not self.is_actor() and _dummy_actor is None:
            _dummy_actor = ray.remote(
                num_cpus=num_cpus,
                num_gpus=num_gpus,
                resources={"node:" + ip: 0.1})(_DummyActor).remote()
        if not self.is_actor():
            if _dummy_actor is None:
                _dummy_actor = ray.remote(
                    num_cpus=num_cpus,
                    num_gpus=num_gpus,
                    resources={"node:" + ip: 0.1})(_DummyActor).remote()

            self.local_device = ray.get(_dummy_actor.cuda_devices.remote())

@@ -235,16 +230,18 @@ class LocalDistributedRunner(DistributedTorchRunner):
        # interaction.
        _init_cuda_context()
        os.environ["CUDA_VISIBLE_DEVICES"] = self.local_device

        if self.local_device:
            try:
                torch.cuda.set_device(int(self.local_device))
            except RuntimeError:
                logger.error("This happens if cuda is not initialized.")
                raise

        super(LocalDistributedRunner, self).__init__(*args, **kwargs)

    def set_cuda_device_id(self):
        self.device_ids = [int(self.local_device)]
    def get_device_ids(self):
        return [int(self.local_device)]

    def shutdown(self, cleanup=True):
        super(LocalDistributedRunner, self).shutdown()
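``load_state_stream`` above deserializes a state dict from raw bytes with a ``map_location`` that targets the actor's own device. The sending side of that round trip is plain ``torch.save`` into an in-memory buffer; a minimal sketch (the model here is a stand-in for illustration, not part of the diff):

```python
import io

import torch
import torch.nn as nn

# Stand-in model for illustration only.
model = nn.Linear(4, 2)

# Serialize the state dict to bytes on the driver...
_buffer = io.BytesIO()
torch.save(model.state_dict(), _buffer)
byte_obj = _buffer.getvalue()

# ...and on the worker, load it back onto the desired device.
state_dict = torch.load(io.BytesIO(byte_obj), map_location="cpu")
model.load_state_dict(state_dict)
```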
@@ -73,7 +73,7 @@ if __name__ == "__main__":
        "--address",
        required=False,
        type=str,
        help="the address to use for Redis")
        help="the address to use for connecting to the Ray cluster")
    parser.add_argument(
        "--num-workers",
        "-n",
python/ray/util/sgd/torch/examples/segmentation/README.md (new file, 34 lines)

@@ -0,0 +1,34 @@
# Semantic segmentation reference -> RaySGD

Original scripts are taken from: https://github.com/pytorch/vision/tree/master/references/segmentation.

On a single node, you can leverage Distributed Data Parallelism (DDP) simply by using the `-n` parameter, which automatically parallelizes training across `n` GPUs. The standard hyperparameters from the original repository are listed below.

```bash
pip install tqdm pycocotools
```


## fcn_resnet101
```
python train_segmentation.py -n 4 --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss
```

## deeplabv3_resnet101
```
python train_segmentation.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
```

## Scaling up

This example can be executed on AWS by running
```
ray submit cluster.yaml train_segmentation.py --start --args="--lr 0.02 ..."
```

To leverage multiple GPUs (beyond a single node), be sure to add an `--address` parameter:

```
ray submit cluster.yaml train_segmentation.py --start --args="--address='auto' --lr 0.02 ..."
```
python/ray/util/sgd/torch/examples/segmentation/coco_utils.py (new file, 116 lines)
|
@ -0,0 +1,116 @@
|
|||
# flake8: noqa
|
||||
import copy
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
|
||||
import os
|
||||
|
||||
from pycocotools import mask as coco_mask
|
||||
|
||||
from ray.util.sgd.torch.examples.segmentation.transforms import Compose
|
||||
|
||||
|
||||
class FilterAndRemapCocoCategories(object):
|
||||
def __init__(self, categories, remap=True):
|
||||
self.categories = categories
|
||||
self.remap = remap
|
||||
|
||||
def __call__(self, image, anno):
|
||||
anno = [obj for obj in anno if obj["category_id"] in self.categories]
|
||||
if not self.remap:
|
||||
return image, anno
|
||||
anno = copy.deepcopy(anno)
|
||||
for obj in anno:
|
||||
obj["category_id"] = self.categories.index(obj["category_id"])
|
||||
return image, anno
|
||||
|
||||
|
||||
def convert_coco_poly_to_mask(segmentations, height, width):
|
||||
masks = []
|
||||
for polygons in segmentations:
|
||||
rles = coco_mask.frPyObjects(polygons, height, width)
|
||||
mask = coco_mask.decode(rles)
|
||||
if len(mask.shape) < 3:
|
||||
mask = mask[..., None]
|
||||
mask = torch.as_tensor(mask, dtype=torch.uint8)
|
||||
mask = mask.any(dim=2)
|
||||
masks.append(mask)
|
||||
if masks:
|
||||
masks = torch.stack(masks, dim=0)
|
||||
else:
|
||||
masks = torch.zeros((0, height, width), dtype=torch.uint8)
|
||||
return masks
|
||||
|
||||
|
||||
class ConvertCocoPolysToMask(object):
|
||||
def __call__(self, image, anno):
|
||||
w, h = image.size
|
||||
segmentations = [obj["segmentation"] for obj in anno]
|
||||
cats = [obj["category_id"] for obj in anno]
|
||||
if segmentations:
|
||||
masks = convert_coco_poly_to_mask(segmentations, h, w)
|
||||
cats = torch.as_tensor(cats, dtype=masks.dtype)
|
||||
# merge all instance masks into a single segmentation map
|
||||
# with its corresponding categories
|
||||
target, _ = (masks * cats[:, None, None]).max(dim=0)
|
||||
# discard overlapping instances
|
||||
target[masks.sum(0) > 1] = 255
|
||||
else:
|
||||
target = torch.zeros((h, w), dtype=torch.uint8)
|
||||
target = Image.fromarray(target.numpy())
|
||||
return image, target
|
||||
|
||||
|
||||
def _coco_remove_images_without_annotations(dataset, cat_list=None):
|
||||
def _has_valid_annotation(anno):
|
||||
# if it's empty, there is no annotation
|
||||
if len(anno) == 0:
|
||||
return False
|
||||
# if more than 1k pixels occupied in the image
|
||||
return sum(obj["area"] for obj in anno) > 1000
|
||||
|
||||
assert isinstance(dataset, torchvision.datasets.CocoDetection)
|
||||
ids = []
|
||||
for ds_idx, img_id in enumerate(dataset.ids):
|
||||
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
|
||||
anno = dataset.coco.loadAnns(ann_ids)
|
||||
if cat_list:
|
||||
anno = [obj for obj in anno if obj["category_id"] in cat_list]
|
||||
if _has_valid_annotation(anno):
|
||||
ids.append(ds_idx)
|
||||
|
||||
dataset = torch.utils.data.Subset(dataset, ids)
|
||||
return dataset
|
||||
|
||||
|
||||
def get_coco(root, image_set, transforms):
|
||||
PATHS = {
|
||||
"train": ("train2017",
|
||||
os.path.join("annotations", "instances_train2017.json")),
|
||||
"val": ("val2017", os.path.join("annotations",
|
||||
"instances_val2017.json")),
|
||||
# "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
|
||||
}
|
||||
CAT_LIST = [
|
||||
0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7,
|
||||
72
|
||||
]
|
||||
|
||||
transforms = Compose([
|
||||
FilterAndRemapCocoCategories(CAT_LIST, remap=True),
|
||||
ConvertCocoPolysToMask(), transforms
|
||||
])
|
||||
|
||||
img_folder, ann_file = PATHS[image_set]
|
||||
img_folder = os.path.join(root, img_folder)
|
||||
ann_file = os.path.join(root, ann_file)
|
||||
|
||||
dataset = torchvision.datasets.CocoDetection(
|
||||
img_folder, ann_file, transforms=transforms)
|
||||
|
||||
if image_set == "train":
|
||||
dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
|
||||
|
||||
return dataset
|
python/ray/util/sgd/torch/examples/segmentation/example.yaml (new file, 61 lines)
|
@ -0,0 +1,61 @@
|
|||
# An unique identifier for the head node and workers of this cluster.
|
||||
cluster_name: sgd-coco-pytorch
|
||||
|
||||
# The maximum number of workers nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers. min_workers default to 0.
|
||||
min_workers: 1
|
||||
initial_workers: 1
|
||||
max_workers: 1
|
||||
|
||||
target_utilization_fraction: 0.9
|
||||
# Cloud-provider specific configuration.
|
||||
provider:
|
||||
type: aws
|
||||
region: us-east-1
|
||||
availability_zone: us-east-1c
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
auth:
|
||||
ssh_user: ubuntu
|
||||
|
||||
|
||||
head_node:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: ami-0698bcaf8bd9ef56d
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
BlockDeviceMappings:
|
||||
- DeviceName: /dev/sda1
|
||||
Ebs:
|
||||
VolumeSize: 300
|
||||
|
||||
|
||||
worker_nodes:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: ami-0698bcaf8bd9ef56d
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
BlockDeviceMappings:
|
||||
- DeviceName: /dev/sda1
|
||||
Ebs:
|
||||
VolumeSize: 300
|
||||
|
||||
setup_commands:
|
||||
# This replaces the standard anaconda Ray installation
|
||||
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
|
||||
# Uncomment this and the filemount to update the Ray installation with your local Ray code
|
||||
# - rm -rf ./anaconda3/lib/python3.6/site-packages/ray/util/sgd/
|
||||
# - cp -rf ~/sgd ./anaconda3/lib/python3.6/site-packages/ray/util/
|
||||
- pip install -q tqdm
|
||||
|
||||
# Installing this without -U to make sure we don't replace the existing Ray installation
|
||||
- pip install ray[rllib]
|
||||
- pip install -U ipdb torch torchvision pycocotools
|
||||
# Install Apex if needed.
|
||||
- git clone https://github.com/NVIDIA/apex; cd apex && pip install -v --no-cache-dir ./ || true
|
||||
|
||||
|
||||
file_mounts: {
|
||||
# # This should point to ray/python/ray/util/sgd.
|
||||
# ~/sgd: ../../../../sgd,
|
||||
}
|
|
@ -0,0 +1,259 @@
|
|||
import datetime
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from torch import nn
|
||||
import torchvision
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch.examples.segmentation.coco_utils import get_coco
|
||||
import ray.util.sgd.torch.examples.segmentation.transforms as T
|
||||
import ray.util.sgd.torch.examples.segmentation.utils as utils
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
amp = None
|
||||
|
||||
|
||||
def get_dataset(name,
|
||||
image_set,
|
||||
transform,
|
||||
num_classes_only=False,
|
||||
download="auto"):
|
||||
def sbd(*args, **kwargs):
|
||||
return torchvision.datasets.SBDataset(
|
||||
*args, mode="segmentation", **kwargs)
|
||||
|
||||
paths = {
|
||||
"voc": (os.path.expanduser("~/datasets01/VOC/060817/"),
|
||||
torchvision.datasets.VOCSegmentation, 21),
|
||||
"voc_aug": (os.path.expanduser("~/datasets01/SBDD/072318/"), sbd, 21),
|
||||
"coco": (os.path.expanduser("~/datasets01/COCO/022719/"), get_coco, 21)
|
||||
}
|
||||
p, ds_fn, num_classes = paths[name]
|
||||
if num_classes_only:
|
||||
return None, num_classes
|
||||
|
||||
if download == "auto" and os.path.exists(p):
|
||||
download = False
|
||||
try:
|
||||
ds = ds_fn(
|
||||
p, download=download, image_set=image_set, transforms=transform)
|
||||
except RuntimeError:
|
||||
print("data loading failed. Retrying this.")
|
||||
ds = ds_fn(p, download=True, image_set=image_set, transforms=transform)
|
||||
return ds, num_classes
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
# Within a machine, this code runs synchronously.
|
||||
dataset, num_classes = get_dataset(
|
||||
args.dataset, "train", get_transform(train=True))
|
||||
config["num_classes"] = num_classes
|
||||
dataset_test, _ = get_dataset(
|
||||
args.dataset, "val", get_transform(train=False))
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.data_workers,
|
||||
collate_fn=utils.collate_fn,
|
||||
drop_last=True)
|
||||
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test,
|
||||
batch_size=1,
|
||||
num_workers=args.data_workers,
|
||||
collate_fn=utils.collate_fn)
|
||||
return data_loader, data_loader_test
|
||||
|
||||
|
||||
def get_transform(train):
|
||||
base_size = 520
|
||||
crop_size = 480
|
||||
|
||||
min_size = int((0.5 if train else 1.0) * base_size)
|
||||
max_size = int((2.0 if train else 1.0) * base_size)
|
||||
transforms = []
|
||||
transforms.append(T.RandomResize(min_size, max_size))
|
||||
if train:
|
||||
transforms.append(T.RandomHorizontalFlip(0.5))
|
||||
transforms.append(T.RandomCrop(crop_size))
|
||||
transforms.append(T.ToTensor())
|
||||
transforms.append(
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
|
||||
|
||||
return T.Compose(transforms)
|
||||
|
||||
|
||||
def criterion(inputs, target):
|
||||
losses = {}
|
||||
for name, x in inputs.items():
|
||||
losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
|
||||
|
||||
if len(losses) == 1:
|
||||
return losses["out"]
|
||||
|
||||
return losses["out"] + 0.5 * losses["aux"]
|
||||
|
||||
|
||||
class SegOperator(TrainingOperator):
|
||||
def train_batch(self, batch, batch_info):
|
||||
image, target = batch
|
||||
image, target = image.to(self.device), target.to(self.device)
|
||||
output = self.model(image)
|
||||
loss = criterion(output, target)
|
||||
self.optimizer.zero_grad()
|
||||
if self.use_fp16 and amp:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
lr = self.optimizer.param_groups[0]["lr"]
|
||||
return {"loss": loss.item(), "lr": lr, "num_samples": len(batch)}
|
||||
|
||||
def validate(self, data_loader, info=None):
|
||||
self.model.eval()
|
||||
confmat = utils.ConfusionMatrix(self.config["num_classes"])
|
||||
with torch.no_grad():
|
||||
for image, target in data_loader:
|
||||
image, target = image.to(self.device), target.to(self.device)
|
||||
output = self.model(image)
|
||||
output = output["out"]
|
||||
|
||||
confmat.update(target.flatten(), output.argmax(1).flatten())
|
||||
|
||||
confmat.reduce_from_all_processes()
|
||||
return confmat
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
args = config["args"]
|
||||
params_to_optimize = [
|
||||
{
|
||||
"params": [
|
||||
p for p in model.backbone.parameters() if p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for p in model.classifier.parameters() if p.requires_grad
|
||||
]
|
||||
},
|
||||
]
|
||||
if args.aux_loss:
|
||||
params = [
|
||||
p for p in model.aux_classifier.parameters() if p.requires_grad
|
||||
]
|
||||
params_to_optimize.append({"params": params, "lr": args.lr * 10})
|
||||
return torch.optim.SGD(
|
||||
params_to_optimize,
|
||||
lr=args.lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
args = config["args"]
|
||||
model = torchvision.models.segmentation.__dict__[args.model](
|
||||
num_classes=config["num_classes"],
|
||||
aux_loss=args.aux_loss,
|
||||
pretrained=args.pretrained)
|
||||
if config["num_workers"] > 1:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
return model
|
||||
|
||||
|
||||
def main(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
print(args)
|
||||
start_time = time.time()
|
||||
config = {"args": args, "num_workers": args.num_workers}
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
training_operator_cls=SegOperator,
|
||||
use_tqdm=True,
|
||||
use_fp16=True,
|
||||
num_workers=config["num_workers"],
|
||||
config=config,
|
||||
use_gpu=torch.cuda.is_available())
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
trainer.train()
|
||||
confmat = trainer.validate(reduce_results=False)[0]
|
||||
print(confmat)
|
||||
state_dict = trainer.state_dict()
|
||||
state_dict.update(epoch=epoch, args=args)
|
||||
torch.save(state_dict,
|
||||
os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
|
||||
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print("Training time {}".format(total_time_str))
|
||||
|
||||
|
||||
def parse_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PyTorch Segmentation Training with RaySGD")
|
||||
parser.add_argument(
|
||||
"--address",
|
||||
required=False,
|
||||
default=None,
|
||||
help="the address to use for connecting to a Ray cluster.")
|
||||
parser.add_argument("--dataset", default="voc", help="dataset")
|
||||
parser.add_argument("--model", default="fcn_resnet101", help="model")
|
||||
parser.add_argument(
|
||||
"--aux-loss", action="store_true", help="auxiliar loss")
|
||||
parser.add_argument("--device", default="cuda", help="device")
|
||||
parser.add_argument("-b", "--batch-size", default=8, type=int)
|
||||
parser.add_argument(
|
||||
"-n", "--num-workers", default=1, type=int, help="GPU parallelism")
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
default=30,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of total epochs to run")
|
||||
parser.add_argument(
|
||||
"--data-workers",
|
||||
default=16,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of data loading workers (default: 16)")
|
||||
parser.add_argument(
|
||||
"--lr", default=0.01, type=float, help="initial learning rate")
|
||||
parser.add_argument(
|
||||
"--momentum", default=0.9, type=float, metavar="M", help="momentum")
|
||||
parser.add_argument(
|
||||
"--wd",
|
||||
"--weight-decay",
|
||||
default=1e-4,
|
||||
type=float,
|
||||
metavar="W",
|
||||
help="weight decay (default: 1e-4)",
|
||||
dest="weight_decay")
|
||||
parser.add_argument("--output-dir", default=".", help="path where to save")
|
||||
parser.add_argument(
|
||||
"--pretrained",
|
||||
dest="pretrained",
|
||||
help="Use pre-trained models from the modelzoo",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
ray.init(address=args.address)
|
||||
main(args)
|
|
@ -0,0 +1,93 @@
|
|||
# flake8: noqa
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torchvision import transforms as T
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
def pad_if_smaller(img, size, fill=0):
|
||||
min_size = min(img.size)
|
||||
if min_size < size:
|
||||
ow, oh = img.size
|
||||
padh = size - oh if oh < size else 0
|
||||
padw = size - ow if ow < size else 0
|
||||
img = F.pad(img, (0, 0, padw, padh), fill=fill)
|
||||
return img
|
||||
|
||||
|
||||
class Compose(object):
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, image, target):
|
||||
for t in self.transforms:
|
||||
image, target = t(image, target)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomResize(object):
|
||||
def __init__(self, min_size, max_size=None):
|
||||
self.min_size = min_size
|
||||
if max_size is None:
|
||||
max_size = min_size
|
||||
self.max_size = max_size
|
||||
|
||||
def __call__(self, image, target):
|
||||
size = random.randint(self.min_size, self.max_size)
|
||||
image = F.resize(image, size)
|
||||
target = F.resize(target, size, interpolation=Image.NEAREST)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomHorizontalFlip(object):
|
||||
def __init__(self, flip_prob):
|
||||
self.flip_prob = flip_prob
|
||||
|
||||
def __call__(self, image, target):
|
||||
if random.random() < self.flip_prob:
|
||||
image = F.hflip(image)
|
||||
target = F.hflip(target)
|
||||
return image, target
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = pad_if_smaller(image, self.size)
|
||||
target = pad_if_smaller(target, self.size, fill=255)
|
||||
crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
|
||||
image = F.crop(image, *crop_params)
|
||||
target = F.crop(target, *crop_params)
|
||||
return image, target
|
||||
|
||||
|
||||
class CenterCrop(object):
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = F.center_crop(image, self.size)
|
||||
target = F.center_crop(target, self.size)
|
||||
return image, target
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
def __call__(self, image, target):
|
||||
image = F.to_tensor(image)
|
||||
target = torch.as_tensor(np.asarray(target), dtype=torch.int64)
|
||||
return image, target
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, image, target):
|
||||
image = F.normalize(image, mean=self.mean, std=self.std)
|
||||
return image, target
|
python/ray/util/sgd/torch/examples/segmentation/utils.py (new file, 70 lines)
|
@ -0,0 +1,70 @@
|
|||
# flake8: noqa
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
import math
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import errno
|
||||
import os
|
||||
|
||||
|
||||
class ConfusionMatrix(object):
|
||||
def __init__(self, num_classes):
|
||||
self.num_classes = num_classes
|
||||
self.mat = None
|
||||
|
||||
def update(self, a, b):
|
||||
n = self.num_classes
|
||||
if self.mat is None:
|
||||
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||
with torch.no_grad():
|
||||
k = (a >= 0) & (a < n)
|
||||
inds = n * a[k].to(torch.int64) + b[k]
|
||||
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||
|
||||
def reset(self):
|
||||
self.mat.zero_()
|
||||
|
||||
def compute(self):
|
||||
h = self.mat.float()
|
||||
acc_global = torch.diag(h).sum() / h.sum()
|
||||
acc = torch.diag(h) / h.sum(1)
|
||||
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
||||
return acc_global, acc, iu
|
||||
|
||||
def reduce_from_all_processes(self):
|
||||
if not torch.distributed.is_available():
|
||||
return
|
||||
if not torch.distributed.is_initialized():
|
||||
return
|
||||
torch.distributed.barrier()
|
||||
torch.distributed.all_reduce(self.mat)
|
||||
|
||||
def __str__(self):
|
||||
acc_global, acc, iu = self.compute()
|
||||
return ('global correct: {:.1f}\n'
|
||||
'average row correct: {}\n'
|
||||
'IoU: {}\n'
|
||||
'mean IoU: {:.1f}').format(
|
||||
acc_global.item() * 100,
|
||||
['{:.1f}'.format(i) for i in (acc * 100).tolist()],
|
||||
['{:.1f}'.format(i) for i in (iu * 100).tolist()],
|
||||
iu.mean().item() * 100)
|
||||
|
||||
|
||||
def cat_list(images, fill_value=0):
|
||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||
batch_shape = (len(images), ) + max_size
|
||||
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
|
||||
for img, pad_img in zip(images, batched_imgs):
|
||||
pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
|
||||
return batched_imgs
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
images, targets = list(zip(*batch))
|
||||
batched_imgs = cat_list(images, fill_value=0)
|
||||
batched_targets = cat_list(targets, fill_value=255)
|
||||
return batched_imgs, batched_targets
|
|
@@ -7,6 +7,7 @@ import itertools
import os
import tempfile
import torch
import torch.nn as nn

import ray
from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP, NUM_STEPS

@@ -23,14 +24,6 @@ except ImportError:
    pass


def _remind_gpu_usage(use_gpu, is_chief):
    if not is_chief:
        return
    if not use_gpu and torch.cuda.is_available():
        logger.info("GPUs detected but not using them. Set `use_gpu` to "
                    "enable GPU usage. ")


class TorchRunner:
    """Manages a PyTorch model for training.


@@ -79,6 +72,7 @@ class TorchRunner:
        self.schedulers = None
        self.train_loader = None
        self.validation_loader = None
        self.training_operator = None
        self.use_gpu = use_gpu
        self.use_fp16 = use_fp16
        self.use_tqdm = use_tqdm

@@ -149,24 +143,35 @@ class TorchRunner:
            self.models, self.optimizers, **self.apex_args)

    def setup(self):
        """Initializes the model."""
        _remind_gpu_usage(self.use_gpu, is_chief=True)
        """Merges setup_components and setup_operator in one call."""
        self.setup_components()
        self.setup_operator()

    def setup_components(self):
        """Runs the creator functions without any distributed coordination."""
        logger.debug("Loading data.")
        self._initialize_dataloaders()
        logger.debug("Creating model")
        self.models = self.model_creator(self.config)
        if not isinstance(self.models, collections.Iterable):
            self.models = [self.models]
        assert all(isinstance(model, nn.Module) for model in self.models), (
            "All models must be PyTorch models: {}.".format(self.models))
        if self.use_gpu and torch.cuda.is_available():
            self.models = [model.cuda() for model in self.models]

        logger.debug("Creating optimizer")
        logger.debug("Creating optimizer.")
        self.optimizers = self.optimizer_creator(self.given_models,
                                                 self.config)
        if not isinstance(self.optimizers, collections.Iterable):
            self.optimizers = [self.optimizers]

        self._create_schedulers_if_available()
        self._try_setup_apex()
        self._create_loss()

    def setup_operator(self):
        """Create the training operator."""
        self.training_operator = self.training_operator_cls(
            self.config,
            models=self.models,
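``setup_components()`` above is what actually invokes the user-supplied creator functions. For reference, a minimal end-to-end sketch of creator functions fed to ``TorchTrainer`` (the toy model, data, and hyperparameters are illustrative; the constructor arguments mirror the test shown earlier in this commit):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import ray
from ray.util.sgd import TorchTrainer


def model_creator(config):
    # Toy model for illustration.
    return nn.Linear(1, 1)


def data_creator(config):
    # Returns (train_loader, validation_loader).
    x = torch.randn(256, 1)
    y = 2 * x + 1
    dataset = TensorDataset(x, y)
    return DataLoader(dataset, batch_size=32), DataLoader(dataset, batch_size=32)


def optimizer_creator(model, config):
    return torch.optim.SGD(model.parameters(), lr=config.get("lr", 0.01))


if __name__ == "__main__":
    ray.init()
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=2,
        config={"lr": 0.01})
    print(trainer.train())
    trainer.shutdown()
```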
@@ -30,6 +30,12 @@ def _validate_scheduler_step_freq(scheduler_step_freq):
            VALID_SCHEDULER_STEP, scheduler_step_freq))


def _remind_gpu_usage(use_gpu):
    if not use_gpu and torch.cuda.is_available():
        logger.info("GPUs detected but not using them. Set `use_gpu` to "
                    "enable GPU usage. ")


class TorchTrainer:
    """Train a PyTorch model using distributed PyTorch.


@@ -69,6 +75,14 @@ class TorchTrainer:
        for i in range(4):
            trainer.train()

    The creator functions will execute before distributed coordination and
    training are set up. This is so that creator functions that download
    large datasets will not trigger any timeouts.

    The order of operations for creator functions is:

    ``data_creator`` -> ``model_creator`` -> ``optimizer_creator`` ->
    ``scheduler_creator`` -> ``loss_creator``.

    Args:
        model_creator (dict -> Model(s)): Constructor function that takes in

@@ -213,6 +227,8 @@ class TorchTrainer:
        if use_gpu == "auto":
            use_gpu = torch.cuda.is_available()

        _remind_gpu_usage(use_gpu)

        if backend == "auto":
            backend = "nccl" if use_gpu else "gloo"

@@ -320,13 +336,32 @@ class TorchTrainer:

        address = "tcp://{ip}:{port}".format(ip=ip, port=port)

        remote_setups = [
            worker.setup.remote(address, i + 1, num_workers)
        # Runs the creator functions.
        remote_component_setup = [
            worker.setup_components.remote()
            for i, worker in enumerate(self.remote_workers)
        ]
        self.local_worker.setup(address, 0, num_workers)
        self.local_worker.setup_components()
        # Get setup tasks in order to throw errors on failure
        ray.get(remote_setups)
        ray.get(remote_component_setup)

        # Setup the process group among all workers.
        remote_pgroup_setups = [
            worker.setup_process_group.remote(address, i + 1, num_workers)
            for i, worker in enumerate(self.remote_workers)
        ]
        self.local_worker.setup_process_group(address, 0, num_workers)
        # Get setup tasks in order to throw errors on failure
        ray.get(remote_pgroup_setups)

        # Runs code that requires all creator functions to have run.
        remote_operator_setups = [
            worker.setup_ddp_and_operator.remote()
            for worker in self.remote_workers
        ]
        self.local_worker.setup_ddp_and_operator()
        # Get setup tasks in order to throw errors on failure
        ray.get(remote_operator_setups)

    def train(self,
              num_steps=None,
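The ``_start_workers`` change above splits worker setup into three synchronized phases so that slow creator functions cannot trip the process-group timeout. A condensed sketch of that ordering, with the Ray actor plumbing stripped out (the worker objects and URL are placeholders that are assumed to expose the methods named in this commit):

```python
# Sketch of the three-phase setup order introduced in this commit.
# `local_worker` / `remote_workers` are assumed to expose the methods shown
# in the diffs above; the URL is a placeholder.
def three_phase_setup(local_worker, remote_workers, url="tcp://10.0.0.1:23456"):
    world_size = len(remote_workers) + 1

    # Phase 1: run the creator functions everywhere. This may download data
    # or models, so it happens before any collective call that could time out.
    for worker in remote_workers:
        worker.setup_components()
    local_worker.setup_components()

    # Phase 2: form the torch.distributed process group.
    for rank, worker in enumerate(remote_workers, start=1):
        worker.setup_process_group(url, rank, world_size)
    local_worker.setup_process_group(url, 0, world_size)

    # Phase 3: wrap models in DistributedDataParallel and build the
    # training operator, now that every worker has its components ready.
    for worker in remote_workers:
        worker.setup_ddp_and_operator()
    local_worker.setup_ddp_and_operator()
```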