mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
180 lines
5 KiB
Python
180 lines
5 KiB
Python
# flake8: noqa
|
|
"""
|
|
This file holds code for the Torch best-practices guide in the documentation.
|
|
|
|
It ignores yapf because yapf doesn't allow comments right after code blocks,
|
|
but we put comments right after code blocks to prevent large white spaces
|
|
in the documentation.
|
|
"""
|
|
# yapf: disable
|
|
# __torch_model_start__
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class Model(nn.Module):
    """Small two-stage CNN classifier for 28x28 single-channel images.

    Two conv + max-pool stages feed two fully connected layers. ``forward``
    returns per-class log-probabilities (``F.log_softmax``), which pairs with
    ``F.nll_loss`` as the training loss.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Conv2d(in_channels, out_channels, kernel_size, stride).
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # A 28x28 input is 24 -> 12 -> 8 -> 4 pixels wide after the two
        # conv/pool stages, leaving 50 feature maps of 4x4 per sample.
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 10).

        ``x`` is expected to be (batch, 1, 28, 28). Inputs with a different
        spatial size raise in ``fc1`` instead of being reshaped into a
        different batch size.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten per sample, keeping the batch dimension fixed; view(-1, 800)
        # would silently fold a mismatched feature size into the batch axis.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
|
|
# __torch_model_end__
|
|
# yapf: enable
|
|
|
|
# yapf: disable
|
|
# __torch_helper_start__
|
|
from filelock import FileLock
|
|
from torchvision import datasets, transforms
|
|
|
|
|
|
def train(model, device, train_loader, optimizer, max_samples=1024):
    """Run (a prefix of) one training epoch over ``train_loader``.

    Args:
        model: Module whose output is fed to ``F.nll_loss`` (i.e. it should
            return log-probabilities).
        device: Device each ``(data, target)`` batch is moved to.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        max_samples: Stop early once ``batch_idx * len(data)`` exceeds this
            value. With a constant batch size, that is the number of samples
            in the batches already processed. Defaults to 1024 to keep the
            tutorial fast; pass ``None`` to iterate the whole loader.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # This break is for speeding up the tutorial.
        if max_samples is not None and batch_idx * len(data) > max_samples:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
|
|
|
|
|
|
def test(model, device, test_loader):
|
|
model.eval()
|
|
test_loss = 0
|
|
correct = 0
|
|
with torch.no_grad():
|
|
for data, target in test_loader:
|
|
data, target = data.to(device), target.to(device)
|
|
output = model(data)
|
|
|
|
# sum up batch loss
|
|
test_loss += F.nll_loss(
|
|
output, target, reduction="sum").item()
|
|
pred = output.argmax(
|
|
dim=1,
|
|
keepdim=True)
|
|
correct += pred.eq(target.view_as(pred)).sum().item()
|
|
|
|
test_loss /= len(test_loader.dataset)
|
|
return {
|
|
"loss": test_loss,
|
|
"accuracy": 100. * correct / len(test_loader.dataset)
|
|
}
|
|
|
|
|
|
def dataset_creator(use_cuda):
    """Create the MNIST train and test ``DataLoader`` objects.

    Both datasets are built while holding a ``FileLock`` on ``./data.lock``
    (relative to the current working directory), presumably so that several
    processes do not download into ``~/data`` at the same time.

    Args:
        use_cuda: If true, both loaders use ``num_workers=1`` and
            ``pin_memory=True``.

    Returns:
        ``(train_loader, test_loader)``, each with batch size 128 and
        ``shuffle=True``.
    """
    if use_cuda:
        loader_options = {"num_workers": 1, "pin_memory": True}
    else:
        loader_options = {}

    def _pipeline():
        # Tensor conversion followed by normalization with fixed
        # single-channel mean/std values.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])

    with FileLock("./data.lock"):
        train_set = datasets.MNIST(
            "~/data", train=True, download=True, transform=_pipeline())
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=128, shuffle=True, **loader_options)

        test_set = datasets.MNIST(
            "~/data", train=False, transform=_pipeline())
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=128, shuffle=True, **loader_options)

    return train_loader, test_loader
|
|
# __torch_helper_end__
|
|
# yapf: enable
|
|
|
|
# yapf: disable
|
|
# __torch_net_start__
|
|
import torch.optim as optim
|
|
|
|
|
|
class Network(object):
    """Bundle of the CNN, its SGD optimizer, and MNIST data loaders.

    The same class is used directly for local training and wrapped with
    ``ray.remote`` to run as an actor.
    """

    def __init__(self, lr=0.01, momentum=0.5):
        """Build loaders, model, and optimizer, using CUDA when available.

        Args:
            lr: SGD learning rate.
            momentum: SGD momentum.
        """
        use_cuda = torch.cuda.is_available()
        self.device = device = torch.device("cuda" if use_cuda else "cpu")
        self.train_loader, self.test_loader = dataset_creator(use_cuda)

        self.model = Model().to(device)
        self.optimizer = optim.SGD(
            self.model.parameters(), lr=lr, momentum=momentum)

    def train(self):
        """Run the module-level ``train`` helper, then return ``test`` metrics."""
        # Inside a method, ``train``/``test`` resolve to the module-level
        # helper functions, not to this method.
        train(self.model, self.device, self.train_loader, self.optimizer)
        return test(self.model, self.device, self.test_loader)

    def get_weights(self):
        """Return the model's ``state_dict`` with every tensor on the CPU.

        Keeping the returned tensors on the CPU lets a process without a GPU
        (e.g. the driver averaging weights) deserialize them even when this
        instance trains on CUDA. For CPU tensors ``.cpu()`` returns the same
        tensor, so nothing is copied.
        """
        state = self.model.state_dict()
        for key in list(state):
            state[key] = state[key].cpu()
        return state

    def set_weights(self, weights):
        """Load ``weights`` (a state dict) into the model."""
        self.model.load_state_dict(weights)

    def save(self, path="mnist_cnn.pt"):
        """Save the model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)
|
|
|
|
|
|
# Local run without Ray: build one Network in this process and train it once.
# Network.train() returns the test metrics dict, which is discarded here.
net = Network()
net.train()
|
|
# __torch_net_end__
|
|
# yapf: enable
|
|
|
|
# yapf: disable
|
|
# __torch_ray_start__
|
|
import ray
|
|
# Initialize Ray in this driver process.
ray.init()

# Wrap the unchanged Network class so it can be instantiated as a Ray actor.
RemoteNetwork = ray.remote(Network)
# Use the below instead of `ray.remote(Network)` to leverage the GPU.
# RemoteNetwork = ray.remote(num_gpus=1)(Network)
|
|
# __torch_ray_end__
|
|
# yapf: enable
|
|
|
|
# yapf: disable
|
|
# __torch_actor_start__
|
|
# Create two independent actors, each holding its own Network (model,
# optimizer and data loaders).
NetworkActor = RemoteNetwork.remote()
NetworkActor2 = RemoteNetwork.remote()

# Launch training on both actors and block until both calls have finished.
ray.get([NetworkActor.train.remote(), NetworkActor2.train.remote()])
|
|
# __torch_actor_end__
|
|
# yapf: enable
|
|
|
|
# yapf: disable
|
|
# __weight_average_start__
|
|
actors = [NetworkActor, NetworkActor2]

# Fetch one state dict per actor.
weights = ray.get([actor.get_weights.remote() for actor in actors])

from collections import OrderedDict
# Element-wise mean of every entry across all actors, in the first
# actor's key order.
averaged_weights = OrderedDict(
    [(k, sum(w[k] for w in weights) / len(weights)) for k in weights[0]])

# Store the averaged weights once and pass the same reference to every actor.
weight_id = ray.put(averaged_weights)
# Wait for every actor to load the averaged weights, so a failure in
# set_weights is raised here instead of being silently dropped.
ray.get([actor.set_weights.remote(weight_id) for actor in actors])
ray.get([actor.train.remote() for actor in actors])
|