[tune] Add PyTorch MNIST Example + Misc. Tweaks (#2708)

Richard Liaw 2018-08-30 16:18:56 -07:00 committed by GitHub
parent 224d38cbb2
commit 0347e6418b
8 changed files with 492 additions and 17 deletions

View file

@@ -114,8 +114,7 @@ def make_parser(parser_creator=None, **kwargs):
"A value of 0 (default) disables checkpointing.")
parser.add_argument(
"--checkpoint-at-end",
-        default=False,
-        type=bool,
+        action="store_true",
help="Whether to checkpoint at the end of the experiment. "
"Default is False.")
parser.add_argument(
@@ -152,11 +151,12 @@ def to_argv(config):
for k, v in config.items():
if "-" in k:
raise ValueError("Use '_' instead of '-' in `{}`".format(k))
argv.append("--{}".format(k.replace("_", "-")))
if not isinstance(v, bool) or v: # for argparse flags
argv.append("--{}".format(k.replace("_", "-")))
if isinstance(v, string_types):
argv.append(v)
elif isinstance(v, bool):
-            argv.append(v)
+            pass
else:
argv.append(json.dumps(v, cls=_SafeFallbackEncoder))
return argv
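
For illustration only (not part of the diff): the revised to_argv treats booleans as presence/absence flags. A minimal sketch, assuming to_argv is importable from ray.tune.config_parser:

from ray.tune.config_parser import to_argv

# True booleans emit just the flag; False booleans emit nothing at all.
print(to_argv({"checkpoint_at_end": True}))   # ['--checkpoint-at-end']
print(to_argv({"checkpoint_at_end": False}))  # []
# Non-boolean values still produce the flag followed by its value.
print(to_argv({"num_samples": 2}))            # ['--num-samples', '2']

This matches the switch of --checkpoint-at-end to a store_true option above: such a flag is expressed by its presence alone rather than by a trailing value.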

View file

@@ -0,0 +1,191 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument(
'--batch-size',
type=int,
default=64,
metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument(
'--test-batch-size',
type=int,
default=1000,
metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument(
'--epochs',
type=int,
default=10,
metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument(
'--lr',
type=float,
default=0.01,
metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument(
'--momentum',
type=float,
default=0.5,
metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument(
'--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument(
'--seed',
type=int,
default=1,
metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--smoke-test', action="store_true", help="Finish quickly for testing")
def train_mnist(args, config, reporter):
vars(args).update(config)
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'~/data',
train=True,
download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])),
batch_size=args.batch_size,
shuffle=True,
**kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'~/data',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
model = Net()
if args.cuda:
model.cuda()
optimizer = optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum)
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
def test():
model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
            test_loss += F.nll_loss(
                output, target,
                size_average=False).item()  # sum up batch loss
pred = output.data.max(
1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
        test_loss = test_loss / len(test_loader.dataset)
accuracy = correct.item() / len(test_loader.dataset)
reporter(mean_loss=test_loss, mean_accuracy=accuracy)
for epoch in range(1, args.epochs + 1):
train(epoch)
test()
if __name__ == '__main__':
datasets.MNIST('~/data', train=True, download=True)
args = parser.parse_args()
import numpy as np
import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
ray.init()
sched = AsyncHyperBandScheduler(
time_attr="training_iteration",
reward_attr="neg_mean_loss",
max_t=400,
grace_period=20)
tune.register_trainable("train_mnist",
lambda cfg, rprtr: train_mnist(args, cfg, rprtr))
tune.run_experiments(
{
"exp": {
"stop": {
"mean_accuracy": 0.98,
"training_iteration": 1 if args.smoke_test else 20
},
"trial_resources": {
"cpu": 3
},
"run": "train_mnist",
"num_samples": 1 if args.smoke_test else 10,
"config": {
"lr": lambda spec: np.random.uniform(0.001, 0.1),
"momentum": lambda spec: np.random.uniform(0.1, 0.9),
}
}
},
verbose=0,
scheduler=sched)
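
A standalone sketch (not part of the commit) of the args/config merging used by train_mnist above: vars(args).update(config) lets the hyperparameters Tune samples for each trial override the argparse defaults.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--momentum", type=float, default=0.5)

args = parser.parse_args([])            # CLI defaults
config = {"lr": 0.05, "momentum": 0.9}  # values Tune sampled for one trial
vars(args).update(config)               # the trial config overrides the defaults
assert args.lr == 0.05 and args.momentum == 0.9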

View file

@@ -0,0 +1,203 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
from __future__ import print_function
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from ray.tune import Trainable
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument(
'--batch-size',
type=int,
default=64,
metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument(
'--test-batch-size',
type=int,
default=1000,
metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument(
'--epochs',
type=int,
default=10,
metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument(
'--lr',
type=float,
default=0.01,
metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument(
'--momentum',
type=float,
default=0.5,
metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument(
'--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument(
'--seed',
type=int,
default=1,
metavar='S',
help='random seed (default: 1)')
parser.add_argument(
'--smoke-test', action="store_true", help="Finish quickly for testing")
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
class TrainMNIST(Trainable):
def _setup(self):
args = self.config.pop("args")
vars(args).update(self.config)
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
self.train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'~/data',
train=True,
download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])),
batch_size=args.batch_size,
shuffle=True,
**kwargs)
self.test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
'~/data',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs)
self.model = Net()
if args.cuda:
self.model.cuda()
self.optimizer = optim.SGD(
self.model.parameters(), lr=args.lr, momentum=args.momentum)
self.args = args
def _train_iteration(self):
self.model.train()
for batch_idx, (data, target) in enumerate(self.train_loader):
if self.args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
self.optimizer.zero_grad()
output = self.model(data)
loss = F.nll_loss(output, target)
loss.backward()
self.optimizer.step()
def _test(self):
self.model.eval()
test_loss = 0
correct = 0
for data, target in self.test_loader:
if self.args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = self.model(data)
# sum up batch loss
            test_loss += F.nll_loss(output, target, size_average=False).item()
# get the index of the max log-probability
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
        test_loss = test_loss / len(self.test_loader.dataset)
accuracy = correct.item() / len(self.test_loader.dataset)
return {"mean_loss": test_loss, "mean_accuracy": accuracy}
def _train(self):
self._train_iteration()
return self._test()
def _save(self, path):
torch.save(self.model.state_dict(), os.path.join(path, "model.pth"))
return path
def _restore(self, path):
        self.model.load_state_dict(
            torch.load(os.path.join(path, "model.pth")))
if __name__ == '__main__':
datasets.MNIST('~/data', train=True, download=True)
args = parser.parse_args()
import numpy as np
import ray
from ray import tune
from ray.tune.schedulers import HyperBandScheduler
ray.init()
sched = HyperBandScheduler(
time_attr="training_iteration", reward_attr="neg_mean_loss")
tune.run_experiments(
{
"exp": {
"stop": {
"mean_accuracy": 0.95,
"training_iteration": 1 if args.smoke_test else 20,
},
"trial_resources": {
"cpu": 3
},
"run": TrainMNIST,
"num_samples": 1 if args.smoke_test else 20,
"checkpoint_at_end": True,
"config": {
"args": args,
"lr": lambda spec: np.random.uniform(0.001, 0.1),
"momentum": lambda spec: np.random.uniform(0.1, 0.9),
}
}
},
verbose=0,
scheduler=sched)
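
Illustrative only (not part of the commit): roughly how Tune drives a class-based Trainable such as TrainMNIST. train(), save(), and restore() are the public wrappers around the _train/_save/_restore methods defined above; exact constructor arguments may differ across Ray versions.

datasets.MNIST('~/data', train=True, download=True)  # make sure the data exists
args = parser.parse_args([])  # argparse defaults from the parser defined above
trainable = TrainMNIST(config={"args": args, "lr": 0.01, "momentum": 0.5})
for _ in range(3):
    print(trainable.train())   # one _train_iteration() followed by _test()
checkpoint = trainable.save()  # hands a checkpoint directory to _save()
trainable.restore(checkpoint)  # later passes the returned path to _restore()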

View file

@@ -115,11 +115,12 @@ class FunctionRunner(Trainable):
time.sleep(1)
result = self._status_reporter._get_and_clear_status()
-        curr_ts_total = result.get(TIMESTEPS_TOTAL,
-                                   self._last_reported_timestep)
-        result.update(
-            timesteps_this_iter=(curr_ts_total - self._last_reported_timestep))
-        self._last_reported_timestep = curr_ts_total
+        curr_ts_total = result.get(TIMESTEPS_TOTAL)
+        if curr_ts_total is not None:
+            result.update(
+                timesteps_this_iter=(
+                    curr_ts_total - self._last_reported_timestep))
+            self._last_reported_timestep = curr_ts_total
return result

View file

@@ -383,6 +383,36 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result['mean_accuracy'], float('inf'))
def testReportTimeStep(self):
def train(config, reporter):
for i in range(100):
reporter(mean_accuracy=5)
[trial] = run_experiments({
"foo": {
"run": train,
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertIsNone(trial.last_result[TIMESTEPS_TOTAL])
def train3(config, reporter):
for i in range(10):
reporter(timesteps_total=5)
[trial3] = run_experiments({
"foo": {
"run": train3,
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 5)
self.assertEqual(trial3.last_result["timesteps_this_iter"], 0)
class RunExperimentTest(unittest.TestCase):
def setUp(self):
@@ -505,6 +535,24 @@ class RunExperimentTest(unittest.TestCase):
for trial in trials:
self.assertEqual(trial.status, Trial.TERMINATED)
def testCheckpointAtEnd(self):
class train(Trainable):
def _train(self):
return dict(timesteps_this_iter=1, done=True)
def _save(self, path):
return path
trials = run_experiments({
"foo": {
"run": train,
"checkpoint_at_end": True
}
})
for trial in trials:
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertTrue(trial.has_checkpoint())
class VariantGeneratorTest(unittest.TestCase):
def setUp(self):

View file

@@ -74,7 +74,7 @@ class Trainable(object):
self._iteration = 0
self._time_total = 0.0
-        self._timesteps_total = 0
+        self._timesteps_total = None
self._setup()
self._initialize_ok = True
self._local_ip = ray.services.get_node_ip_address()
@@ -150,9 +150,15 @@ class Trainable(object):
time_this_iter = time.time() - start
self._time_total += time_this_iter
-        self._timesteps_total += result.get(TIMESTEPS_THIS_ITER, 0)
result.setdefault(DONE, False)
+        # self._timesteps_total should only be tracked if increments provided
+        if result.get(TIMESTEPS_THIS_ITER):
+            if self._timesteps_total is None:
+                self._timesteps_total = 0
+            self._timesteps_total += result[TIMESTEPS_THIS_ITER]
# self._timesteps_total should not override user-provided total
result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
# Provides auto-filled neg_mean_loss for avoiding regressions
@@ -278,12 +284,26 @@ class Trainable(object):
raise NotImplementedError
def _save(self, checkpoint_dir):
"""Subclasses should override this to implement save()."""
"""Subclasses should override this to implement save().
Args:
checkpoint_dir (str): The directory where the checkpoint
can be stored.
Returns:
Checkpoint path that may be passed to restore(). Typically
would default to `checkpoint_dir`.
"""
raise NotImplementedError
def _restore(self, checkpoint_path):
"""Subclasses should override this to implement restore()."""
"""Subclasses should override this to implement restore().
Args:
checkpoint_path (str): The directory where the checkpoint
is stored.
"""
raise NotImplementedError
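
To make the documented save/restore contract concrete, here is a minimal sketch (not from the diff); the class name and pickle format are illustrative only.

import os
import pickle

from ray.tune import Trainable


class MyTrainable(Trainable):
    def _train(self):
        return {"mean_accuracy": 0.5, "timesteps_this_iter": 1}

    def _save(self, checkpoint_dir):
        # Persist state under checkpoint_dir; whatever is returned here is
        # later handed back to _restore().
        with open(os.path.join(checkpoint_dir, "state.pkl"), "wb") as f:
            pickle.dump({"iteration": self._iteration}, f)
        return checkpoint_dir

    def _restore(self, checkpoint_path):
        with open(os.path.join(checkpoint_path, "state.pkl"), "rb") as f:
            self.state = pickle.load(f)

With checkpoint_freq or checkpoint_at_end set on an experiment, Tune calls these hooks automatically, as the new testCheckpointAtEnd above exercises.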

View file

@@ -16,7 +16,7 @@ from ray.tune.logger import pretty_print, UnifiedLogger
# have been defined yet. See https://github.com/ray-project/ray/issues/1716.
import ray.tune.registry
from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
-                             TIME_TOTAL_S, TRAINING_ITERATION)
+                             TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL)
from ray.utils import random_string, binary_to_hex
DEBUG_PRINT_INTERVAL = 5
@@ -237,8 +237,12 @@ class Trial(object):
int(self.last_result.get(TIME_TOTAL_S)))
]
if self.last_result.get("timesteps_total") is not None:
pieces.append('{} ts'.format(self.last_result["timesteps_total"]))
if self.last_result.get(TRAINING_ITERATION) is not None:
pieces.append('{} iter'.format(
self.last_result[TRAINING_ITERATION]))
if self.last_result.get(TIMESTEPS_TOTAL) is not None:
pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL]))
if self.last_result.get("episode_reward_mean") is not None:
pieces.append('{} rew'.format(

View file

@@ -266,6 +266,14 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/tune/examples/tune_mnist_keras.py \
--smoke-test
docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/tune/examples/mnist_pytorch.py \
--smoke-test
docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \
--smoke-test
docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py