2017-03-07 01:07:32 -08:00
|
|
|
"""ResNet training script, with some code from
|
|
|
|
https://github.com/tensorflow/models/tree/master/resnet.
|
|
|
|
"""
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
2017-05-15 22:40:41 -07:00
|
|
|
import argparse
|
2017-03-07 01:07:32 -08:00
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
import ray
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
|
|
import cifar_input
|
|
|
|
import resnet_model
|
|
|
|
|
2017-05-15 22:40:41 -07:00
|
|
|
# Tensorflow must be at least version 1.0.0 for the example to work. The
# example also relies on Tensorflow 1.x-only top-level APIs (tf.Session,
# tf.ConfigProto, tf.set_random_seed, tf.train.start_queue_runners,
# tf.logging), which are not available under `import tensorflow` in 2.x, so
# reject 2.x up front instead of failing later inside a Ray worker.
if int(tf.__version__.split(".")[0]) < 1:
    raise Exception("Your Tensorflow version is less than 1.0.0. Please update "
                    "Tensorflow to a 1.x release.")
if int(tf.__version__.split(".")[0]) >= 2:
    raise Exception("This example requires Tensorflow 1.x, but version {} "
                    "is installed.".format(tf.__version__))
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="Run the ResNet example.")
parser.add_argument("--dataset", default="cifar10", type=str,
                    choices=["cifar10", "cifar100"],
                    help="Dataset to use: cifar10 or cifar100.")
parser.add_argument("--train_data_path",
                    default="cifar-10-batches-bin/data_batch*", type=str,
                    help="Data path for the training data.")
parser.add_argument("--eval_data_path",
                    default="cifar-10-batches-bin/test_batch.bin", type=str,
                    help="Data path for the testing data.")
parser.add_argument("--eval_dir", default="/tmp/resnet-model/eval", type=str,
                    help="Data path for the tensorboard logs.")
parser.add_argument("--eval_batch_count", default=50, type=int,
                    help="Number of batches to evaluate over.")
parser.add_argument("--num_gpus", default=0, type=int,
                    help="Number of GPUs to use for training.")

FLAGS = parser.parse_args()
# Reject values that would otherwise fail far from the command line: a
# negative GPU count is passed straight to ray.init, and evaluating zero
# batches makes the accuracy computation divide by zero.
if FLAGS.num_gpus < 0:
    parser.error("--num_gpus must be non-negative.")
if FLAGS.eval_batch_count < 1:
    parser.error("--eval_batch_count must be at least 1.")

# Determines if the actors require a gpu or not (1 GPU per training actor).
# FLAGS.num_gpus is already an int because of type=int above.
use_gpu = 1 if FLAGS.num_gpus > 0 else 0
|
2017-03-11 15:30:31 -08:00
|
|
|
|
2017-05-16 14:12:18 -07:00
|
|
|
|
2017-05-15 22:40:41 -07:00
|
|
|
@ray.remote
def get_data(path, size, dataset):
    """Load and preprocess a CIFAR split with a TensorFlow queue on the CPU.

    Builds the input pipeline with ``cifar_input.build_data`` and evaluates it
    once.

    Args:
        path: Data path (file or glob pattern) of the CIFAR binary files.
        size: Number of examples requested from ``cifar_input.build_data``.
        dataset: Dataset name ("cifar10" or "cifar100").

    Returns:
        The ``(images, labels)`` pair produced by a single ``sess.run`` of the
        pipeline built by ``cifar_input.build_data``.
    """
    # Hide every GPU from this task so that loading only uses the CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    with tf.device("/cpu:0"):
        queue = cifar_input.build_data(path, size, dataset)
        # The context manager closes the session even if reading fails.
        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord=coord)
            try:
                images, labels = sess.run(queue)
            finally:
                # Stop the queue-runner threads and wait for them to exit
                # before the session is closed underneath them.
                coord.request_stop()
                coord.join(threads)
    return images, labels
|
2017-03-07 01:07:32 -08:00
|
|
|
|
2017-05-16 14:12:18 -07:00
|
|
|
|
2017-05-14 00:01:20 -07:00
|
|
|
@ray.remote(num_gpus=use_gpu)
class ResNetTrainActor(object):
    """Ray actor that owns one ResNet training replica.

    Each actor builds its own TensorFlow graph and session. The driver pushes
    weights in through ``compute_steps`` and reads them back out through
    ``compute_steps``/``get_weights``; actors never share state directly.
    """

    def __init__(self, data, dataset, num_gpus):
        """Build the training graph and start its session.

        Args:
            data: ``(images, labels)`` pair, e.g. the result of ``get_data``.
            dataset: "cifar100" selects 100 classes; anything else uses 10.
            num_gpus: GPU count requested by the driver; 0 means train on the
                CPU.
        """
        if num_gpus > 0:
            gpu_ids = ray.get_gpu_ids()
            # Restrict TensorFlow to the GPU(s) Ray assigned to this actor, so
            # "/gpu:0" below refers to the assigned device.
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(gpu_id) for gpu_id in gpu_ids)
            # Seed each actor differently: all actors receive the same data,
            # so distinct seeds are what make their random input pipelines
            # (presumably shuffling/augmentation in cifar_input) differ.
            seed = gpu_ids[0] + 1
        else:
            # Only a single actor in this case.
            seed = 1
        tf.set_random_seed(seed)

        hps = resnet_model.HParams(
            batch_size=128,
            num_classes=100 if dataset == "cifar100" else 10,
            min_lrn_rate=0.0001,
            lrn_rate=0.1,
            num_residual_units=5,
            use_bottleneck=False,
            weight_decay_rate=0.0002,
            relu_leakiness=0.1,
            optimizer="mom",
            num_gpus=num_gpus)

        input_images = data[0]
        input_labels = data[1]
        with tf.device("/gpu:0" if num_gpus > 0 else "/cpu:0"):
            # Build the model.
            images, labels = cifar_input.build_input(
                [input_images, input_labels], hps.batch_size, dataset, False)
            self.model = resnet_model.ResNet(hps, images, labels, "train")
            self.model.build_graph()
            config = tf.ConfigProto(allow_soft_placement=True)
            sess = tf.Session(config=config)
            self.model.variables.set_session(sess)
            # Initialize variables before the queue-runner threads start (the
            # documented TF queue pattern), so no background op can run
            # against uninitialized state.
            sess.run(tf.global_variables_initializer())
            self.coord = tf.train.Coordinator()
            tf.train.start_queue_runners(sess, coord=self.coord)
        # Training steps per compute_steps call; the driver's step counter
        # assumes this value.
        self.steps = 10

    def compute_steps(self, weights):
        """Load ``weights``, run ``self.steps`` training steps, and return the
        resulting weights."""
        self.model.variables.set_weights(weights)
        for _ in range(self.steps):
            self.model.variables.sess.run(self.model.train_op)
        return self.model.variables.get_weights()

    def get_weights(self):
        """Return the current model weights.

        The driver cannot read actor fields directly, so it goes through this
        helper method.
        """
        return self.model.variables.get_weights()
|
|
|
|
|
2017-05-16 14:12:18 -07:00
|
|
|
|
2017-05-14 00:01:20 -07:00
|
|
|
@ray.remote
class ResNetTestActor(object):
    """Ray actor that evaluates ResNet weights on the test data.

    Metrics are written as TensorBoard summaries to ``eval_dir`` on the node
    hosting this actor (its IP is available via ``get_ip_addr``).
    """

    def __init__(self, data, dataset, eval_batch_count, eval_dir):
        """Build the CPU evaluation graph and the TensorBoard writer.

        Args:
            data: ``(images, labels)`` pair, e.g. the result of ``get_data``.
            dataset: "cifar100" selects 100 classes; anything else uses 10.
            eval_batch_count: Number of batches evaluated per ``accuracy``
                call. Must be at least 1.
            eval_dir: Directory the TensorBoard summaries are written to.

        Raises:
            ValueError: If ``eval_batch_count`` is less than 1.
        """
        # With no batches, accuracy() would divide by zero and use loop
        # results that were never assigned, so fail early instead.
        if eval_batch_count < 1:
            raise ValueError("eval_batch_count must be at least 1, got {}."
                             .format(eval_batch_count))
        hps = resnet_model.HParams(
            batch_size=100,
            num_classes=100 if dataset == "cifar100" else 10,
            min_lrn_rate=0.0001,
            lrn_rate=0.1,
            num_residual_units=5,
            use_bottleneck=False,
            weight_decay_rate=0.0002,
            relu_leakiness=0.1,
            optimizer="mom",
            num_gpus=0)
        input_images = data[0]
        input_labels = data[1]
        with tf.device("/cpu:0"):
            # Builds the testing network.
            images, labels = cifar_input.build_input(
                [input_images, input_labels], hps.batch_size, dataset, False)
            self.model = resnet_model.ResNet(hps, images, labels, "eval")
            self.model.build_graph()
            config = tf.ConfigProto(allow_soft_placement=True)
            sess = tf.Session(config=config)
            self.model.variables.set_session(sess)
            # Initialize variables before the queue-runner threads start (the
            # documented TF queue pattern).
            sess.run(tf.global_variables_initializer())
            self.coord = tf.train.Coordinator()
            tf.train.start_queue_runners(sess, coord=self.coord)

        # Initializing parameters for tensorboard.
        self.best_precision = 0.0
        self.eval_batch_count = eval_batch_count
        self.summary_writer = tf.summary.FileWriter(eval_dir, sess.graph)
        # The IP address where tensorboard logs will be on.
        self.ip_addr = ray.services.get_node_ip_address()

    def accuracy(self, weights, train_step):
        """Evaluate ``weights`` and log the metrics to TensorBoard.

        Runs ``self.eval_batch_count`` evaluation batches, records the
        precision (fraction of correct arg-max predictions) and the best
        precision seen so far at ``train_step``, and flushes the writer.

        Args:
            weights: Weights to load into the model, e.g. as returned by a
                training actor's ``get_weights``.
            train_step: Global step used to label the summaries.

        Returns:
            The precision over all evaluated examples.
        """
        model = self.model
        model.variables.set_weights(weights)
        sess = model.variables.sess
        total_prediction, correct_prediction = 0, 0
        for _ in range(self.eval_batch_count):
            # Only the last batch's summaries and loss are kept for logging.
            summaries, loss, predictions, truth = sess.run(
                [model.summaries, model.cost, model.predictions,
                 model.labels])
            # Compare arg-max class indices of predictions and labels.
            truth = np.argmax(truth, axis=1)
            predictions = np.argmax(predictions, axis=1)
            correct_prediction += np.sum(truth == predictions)
            total_prediction += predictions.shape[0]

        precision = 1.0 * correct_prediction / total_prediction
        self.best_precision = max(precision, self.best_precision)

        precision_summ = tf.Summary()
        precision_summ.value.add(
            tag="Precision", simple_value=precision)
        self.summary_writer.add_summary(precision_summ, train_step)
        best_precision_summ = tf.Summary()
        best_precision_summ.value.add(
            tag="Best Precision", simple_value=self.best_precision)
        self.summary_writer.add_summary(best_precision_summ, train_step)
        self.summary_writer.add_summary(summaries, train_step)
        # Pass the values as logging args so formatting is deferred to the
        # logger.
        tf.logging.info("loss: %.3f, precision: %.3f, best precision: %.3f",
                        loss, precision, self.best_precision)
        self.summary_writer.flush()
        return precision

    def get_ip_addr(self):
        """Return the IP address of the node holding the TensorBoard logs.

        As with the training actor, the driver needs a helper method to read
        this field.
        """
        return self.ip_addr
|
|
|
|
|
2017-05-16 14:12:18 -07:00
|
|
|
|
2017-03-07 01:07:32 -08:00
|
|
|
def train():
    """Run synchronous data-parallel ResNet training until Ctrl-C.

    Starts Ray, loads the train/test splits, creates one training actor per
    GPU (or a single CPU actor), and one test actor. Each round, every
    training actor starts from the same weights and runs its local steps. The
    driver then averages the returned weights. Every ``eval_interval`` steps,
    the previously launched evaluation is collected and printed, and a new
    one is launched on the current weights.
    """
    num_gpus = FLAGS.num_gpus
    # Must match ResNetTrainActor's steps per compute_steps call (10).
    steps_per_round = 10
    eval_interval = 200

    ray.init(num_gpus=num_gpus, redirect_output=True)
    train_data = get_data.remote(FLAGS.train_data_path, 50000, FLAGS.dataset)
    test_data = get_data.remote(FLAGS.eval_data_path, 10000, FLAGS.dataset)
    # Creates an actor for each gpu, or one if only using the cpu. Each actor
    # has access to the whole dataset.
    if num_gpus > 0:
        train_actors = [ResNetTrainActor.remote(train_data, FLAGS.dataset,
                                                num_gpus)
                        for _ in range(num_gpus)]
    else:
        train_actors = [ResNetTrainActor.remote(train_data, FLAGS.dataset, 0)]
    test_actor = ResNetTestActor.remote(test_data, FLAGS.dataset,
                                        FLAGS.eval_batch_count, FLAGS.eval_dir)
    print("The log files for tensorboard are stored at ip {}."
          .format(ray.get(test_actor.get_ip_addr.remote())))

    step = 0
    weight_id = train_actors[0].get_weights.remote()
    acc_id = test_actor.accuracy.remote(weight_id, step)
    print("Starting training loop. Use Ctrl-C to exit.")
    try:
        while True:
            all_weights = ray.get([actor.compute_steps.remote(weight_id)
                                   for actor in train_actors])
            # Average over the weight sets actually returned (one per actor);
            # presumably each value is a numpy array keyed by variable name.
            num_replicas = len(all_weights)
            mean_weights = {k: sum(weights[k] for weights in all_weights) /
                            num_replicas
                            for k in all_weights[0]}
            weight_id = ray.put(mean_weights)
            step += steps_per_round
            if step % eval_interval == 0:
                # Retrieves the previously computed accuracy (for the weights
                # from eval_interval steps ago) and launches a new testing
                # task with the current weights.
                acc = ray.get(acc_id)
                acc_id = test_actor.accuracy.remote(weight_id, step)
                print("Step {0}: {1:.6f}".format(step - eval_interval, acc))
    except KeyboardInterrupt:
        # Ctrl-C is the intended way to stop the otherwise endless loop.
        pass
|
2017-03-07 01:07:32 -08:00
|
|
|
|
2017-05-16 14:12:18 -07:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Only start training when run as a script, never on import.
    train()
|