ray/doc/examples/doc_code/torch_example.py

180 lines
5 KiB
Python

# flake8: noqa
"""
This file holds code for the Torch best-practices guide in the documentation.
It disables yapf because yapf doesn't allow comments right after code blocks,
but we put comments right after code blocks to avoid large blank gaps
in the rendered documentation.
"""
# yapf: disable
# __torch_model_start__
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Small two-conv / two-linear CNN producing log-probabilities over 10 classes.

    The flatten size (4 * 4 * 50) matches single-channel 28x28 inputs such as
    MNIST: 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4.
    """

    def __init__(self):
        super().__init__()
        # Registration order determines state_dict key order; keep it stable.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # Two feature-extraction stages: conv -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        # Flatten the 50 x 4 x 4 feature maps for the fully connected head.
        x = x.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(hidden), dim=1)
# __torch_model_end__
# yapf: enable
# yapf: disable
# __torch_helper_start__
from filelock import FileLock
from torchvision import datasets, transforms
def train(model, device, train_loader, optimizer):
    """Update ``model`` in place with a short pass over ``train_loader``.

    The pass stops as soon as ``batch index * batch length`` exceeds 1024,
    so only a handful of batches are used. Returns None.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        # This break is for speeding up the tutorial.
        if step * len(inputs) > 1024:
            break
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(
output, target, reduction="sum").item()
pred = output.argmax(
dim=1,
keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
return {
"loss": test_loss,
"accuracy": 100. * correct / len(test_loader.dataset)
}
def dataset_creator(use_cuda):
    """Build the MNIST train and test ``DataLoader``s (batch size 128).

    Args:
        use_cuda: If True, the loaders are created with ``num_workers=1``
            and ``pin_memory=True``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    import os

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    # Several processes (e.g. Ray actors) may create the dataset under
    # ~/data at once. The lock must sit at a fixed location next to the data:
    # a cwd-relative "./data.lock" is not shared by processes started in
    # different working directories, so they could race on the download.
    with FileLock(os.path.expanduser("~/data.lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data",
                train=True,
                download=True,
                transform=transform),
            batch_size=128,
            shuffle=True,
            **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data",
                train=False,
                transform=transform),
            batch_size=128,
            # Evaluation sums loss/correct counts, so order doesn't matter.
            shuffle=False,
            **kwargs)
    return train_loader, test_loader
# __torch_helper_end__
# yapf: enable
# yapf: disable
# __torch_net_start__
import torch.optim as optim
class Network:
    """Model, MNIST loaders and SGD optimizer bundled as one trainable unit.

    Designed to be wrapped with ``ray.remote`` so each actor owns its own
    copy of the whole training setup.

    Args:
        lr: SGD learning rate.
        momentum: SGD momentum.
    """

    def __init__(self, lr=0.01, momentum=0.5):
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.train_loader, self.test_loader = dataset_creator(use_cuda)
        self.model = Model().to(self.device)
        self.optimizer = optim.SGD(
            self.model.parameters(), lr=lr, momentum=momentum)

    def train(self):
        """Run the short training pass, then evaluate.

        Returns:
            The metrics dict produced by the module-level ``test`` helper.
        """
        # Inside this method, ``train``/``test`` resolve to the module-level
        # helper functions, not to this method.
        train(self.model, self.device, self.train_loader, self.optimizer)
        return test(self.model, self.device, self.test_loader)

    def get_weights(self):
        """Return the model's ``state_dict``."""
        return self.model.state_dict()

    def set_weights(self, weights):
        """Load a ``state_dict`` (e.g. one from ``get_weights``) into the model."""
        self.model.load_state_dict(weights)

    def save(self, path="mnist_cnn.pt"):
        """Save the model's ``state_dict`` to ``path``.

        Give each instance its own ``path`` when several of them share a
        working directory, otherwise they overwrite the same file.
        """
        torch.save(self.model.state_dict(), path)
# Plain local run (no Ray): build the network with its default
# hyperparameters, then do one short train + evaluate cycle.
net = Network(lr=0.01, momentum=0.5)
net.train()
# __torch_net_end__
# yapf: enable
# yapf: disable
# __torch_ray_start__
import ray
# Start Ray and turn the Network class into an actor class.
ray.init()
RemoteNetwork = ray.remote(Network)
# Use the below instead of `ray.remote(Network)` to leverage the GPU.
# RemoteNetwork = ray.remote(num_gpus=1)(Network)
# __torch_ray_end__
# yapf: enable
# yapf: disable
# __torch_actor_start__
# Two independent actors, each with its own model, data loaders and optimizer.
NetworkActor = RemoteNetwork.remote()
NetworkActor2 = RemoteNetwork.remote()
# Submit both training calls before blocking, so the actors can work in
# parallel; ray.get then waits for both results.
ray.get([actor.train.remote() for actor in (NetworkActor, NetworkActor2)])
# __torch_actor_end__
# yapf: enable
# yapf: disable
# __weight_average_start__
weights = ray.get(
    [NetworkActor.get_weights.remote(),
     NetworkActor2.get_weights.remote()])

from collections import OrderedDict
# Element-wise mean of every state_dict entry across all collected weights
# (works for any number of actors, not just two).
averaged_weights = OrderedDict(
    (k, sum(w[k] for w in weights) / len(weights)) for k in weights[0])

# Store the averaged weights once and hand the same reference to every actor.
weight_id = ray.put(averaged_weights)
actors = [NetworkActor, NetworkActor2]
# Wait for every actor to load the averaged weights. Waiting also surfaces any
# set_weights failure, which a discarded result would hide. Then start the
# next training round.
ray.get([actor.set_weights.remote(weight_id) for actor in actors])
ray.get([actor.train.remote() for actor in actors])