"""ResNet model with most of the code taken from
|
|
https://github.com/tensorflow/models/tree/master/resnet.
|
|
|
|
Related papers:
|
|
https://arxiv.org/pdf/1603.05027v2.pdf
|
|
https://arxiv.org/pdf/1512.03385v1.pdf
|
|
https://arxiv.org/pdf/1605.07146v1.pdf
|
|
"""
|
|
|
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple

import numpy as np

import tensorflow as tf
from tensorflow.python.training import moving_averages

import ray
import ray.experimental.tf_utils

HParams = namedtuple(
    "HParams", "batch_size, num_classes, min_lrn_rate, lrn_rate, "
    "num_residual_units, use_bottleneck, weight_decay_rate, "
    "relu_leakiness, optimizer, num_gpus")


class ResNet(object):
    """ResNet model."""

    def __init__(self, hps, images, labels, mode):
        """ResNet constructor.

        Args:
            hps: Hyperparameters.
            images: Batches of images of size [batch_size, image_size,
                image_size, 3].
            labels: Batches of labels of size [batch_size, num_classes].
            mode: One of 'train' and 'eval'.
        """
        self.hps = hps
        self._images = images
        self.labels = labels
        self.mode = mode

        self._extra_train_ops = []

    def build_graph(self):
        """Build a whole graph for the model."""
        self.global_step = tf.Variable(0, trainable=False)
        self._build_model()
        if self.mode == "train":
            self._build_train_op()
        else:
            # Additional initialization for the test network.
            self.variables = ray.experimental.tf_utils.TensorFlowVariables(
                self.cost)
        self.summaries = tf.summary.merge_all()

    def _stride_arr(self, stride):
        """Map a stride scalar to the stride array for tf.nn.conv2d."""
        return [1, stride, stride, 1]

    def _build_model(self):
        """Build the core model within the graph."""
        with tf.variable_scope("init"):
            x = self._conv("init_conv", self._images, 3, 3, 16,
                           self._stride_arr(1))

        strides = [1, 2, 2]
        activate_before_residual = [True, False, False]
        if self.hps.use_bottleneck:
            res_func = self._bottleneck_residual
            filters = [16, 64, 128, 256]
        else:
            res_func = self._residual
            filters = [16, 16, 32, 64]
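
        # With use_bottleneck=False and num_residual_units=n, the three
        # stages below stack 6n convolutional layers; together with the
        # initial conv and the final fully connected layer this is the
        # standard 6n + 2 layer CIFAR ResNet (e.g. n = 5 gives 32 layers).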
        with tf.variable_scope("unit_1_0"):
            x = res_func(x, filters[0], filters[1], self._stride_arr(
                strides[0]), activate_before_residual[0])
        for i in range(1, self.hps.num_residual_units):
            with tf.variable_scope("unit_1_%d" % i):
                x = res_func(x, filters[1], filters[1], self._stride_arr(1),
                             False)

        with tf.variable_scope("unit_2_0"):
            x = res_func(x, filters[1], filters[2], self._stride_arr(
                strides[1]), activate_before_residual[1])
        for i in range(1, self.hps.num_residual_units):
            with tf.variable_scope("unit_2_%d" % i):
                x = res_func(x, filters[2], filters[2], self._stride_arr(1),
                             False)

        with tf.variable_scope("unit_3_0"):
            x = res_func(x, filters[2], filters[3], self._stride_arr(
                strides[2]), activate_before_residual[2])
        for i in range(1, self.hps.num_residual_units):
            with tf.variable_scope("unit_3_%d" % i):
                x = res_func(x, filters[3], filters[3], self._stride_arr(1),
                             False)

        with tf.variable_scope("unit_last"):
            x = self._batch_norm("final_bn", x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._global_avg_pool(x)

        with tf.variable_scope("logit"):
            logits = self._fully_connected(x, self.hps.num_classes)
            self.predictions = tf.nn.softmax(logits)

        with tf.variable_scope("costs"):
            xent = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=self.labels)
            self.cost = tf.reduce_mean(xent, name="xent")
            self.cost += self._decay()

            if self.mode == "eval":
                tf.summary.scalar("cost", self.cost)

    def _build_train_op(self):
        """Build training specific ops for the graph."""
        num_gpus = self.hps.num_gpus if self.hps.num_gpus != 0 else 1
        # The learning rate schedule is dependent on the number of gpus.
        boundaries = [int(20000 * i / np.sqrt(num_gpus)) for i in range(2, 5)]
        values = [0.1, 0.01, 0.001, 0.0001]
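        # For example, with num_gpus == 1 the boundaries above evaluate to
        # [40000, 60000, 80000], so the learning rate steps down
        # 0.1 -> 0.01 -> 0.001 -> 0.0001 over the course of training.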
        self.lrn_rate = tf.train.piecewise_constant(self.global_step,
                                                    boundaries, values)
        tf.summary.scalar("learning rate", self.lrn_rate)

        if self.hps.optimizer == "sgd":
            optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
        elif self.hps.optimizer == "mom":
            optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)

        apply_op = optimizer.minimize(self.cost, global_step=self.global_step)
        train_ops = [apply_op] + self._extra_train_ops
        self.train_op = tf.group(*train_ops)
        self.variables = ray.experimental.tf_utils.TensorFlowVariables(
            self.train_op)

    def _batch_norm(self, name, x):
        """Batch normalization."""
        with tf.variable_scope(name):
            params_shape = [x.get_shape()[-1]]

            beta = tf.get_variable(
                "beta",
                params_shape,
                tf.float32,
                initializer=tf.constant_initializer(0.0, tf.float32))
            gamma = tf.get_variable(
                "gamma",
                params_shape,
                tf.float32,
                initializer=tf.constant_initializer(1.0, tf.float32))

            if self.mode == "train":
                mean, variance = tf.nn.moments(x, [0, 1, 2], name="moments")

                moving_mean = tf.get_variable(
                    "moving_mean",
                    params_shape,
                    tf.float32,
                    initializer=tf.constant_initializer(0.0, tf.float32),
                    trainable=False)
                moving_variance = tf.get_variable(
                    "moving_variance",
                    params_shape,
                    tf.float32,
                    initializer=tf.constant_initializer(1.0, tf.float32),
                    trainable=False)

                self._extra_train_ops.append(
                    moving_averages.assign_moving_average(
                        moving_mean, mean, 0.9))
                self._extra_train_ops.append(
                    moving_averages.assign_moving_average(
                        moving_variance, variance, 0.9))
            else:
                mean = tf.get_variable(
                    "moving_mean",
                    params_shape,
                    tf.float32,
                    initializer=tf.constant_initializer(0.0, tf.float32),
                    trainable=False)
                variance = tf.get_variable(
                    "moving_variance",
                    params_shape,
                    tf.float32,
                    initializer=tf.constant_initializer(1.0, tf.float32),
                    trainable=False)
                tf.summary.histogram(mean.op.name, mean)
                tf.summary.histogram(variance.op.name, variance)
            # The epsilon used to be 1e-5. Maybe 0.001 solves the NaN problem
            # in deeper nets.
            y = tf.nn.batch_normalization(x, mean, variance, beta, gamma,
                                          0.001)
            y.set_shape(x.get_shape())
            return y

    def _residual(self,
                  x,
                  in_filter,
                  out_filter,
                  stride,
                  activate_before_residual=False):
        """Residual unit with 2 sub layers."""
        if activate_before_residual:
            with tf.variable_scope("shared_activation"):
                x = self._batch_norm("init_bn", x)
                x = self._relu(x, self.hps.relu_leakiness)
                orig_x = x
        else:
            with tf.variable_scope("residual_only_activation"):
                orig_x = x
                x = self._batch_norm("init_bn", x)
                x = self._relu(x, self.hps.relu_leakiness)

        with tf.variable_scope("sub1"):
            x = self._conv("conv1", x, 3, in_filter, out_filter, stride)

        with tf.variable_scope("sub2"):
            x = self._batch_norm("bn2", x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv("conv2", x, 3, out_filter, out_filter, [1, 1, 1, 1])

        with tf.variable_scope("sub_add"):
            if in_filter != out_filter:
                orig_x = tf.nn.avg_pool(orig_x, stride, stride, "VALID")
                orig_x = tf.pad(
                    orig_x,
                    [[0, 0], [0, 0], [0, 0], [(out_filter - in_filter) // 2,
                                              (out_filter - in_filter) // 2]])
            x += orig_x

        return x

    def _bottleneck_residual(self,
                             x,
                             in_filter,
                             out_filter,
                             stride,
                             activate_before_residual=False):
        """Bottleneck residual unit with 3 sub layers."""
        if activate_before_residual:
            with tf.variable_scope("common_bn_relu"):
                x = self._batch_norm("init_bn", x)
                x = self._relu(x, self.hps.relu_leakiness)
                orig_x = x
        else:
            with tf.variable_scope("residual_bn_relu"):
                orig_x = x
                x = self._batch_norm("init_bn", x)
                x = self._relu(x, self.hps.relu_leakiness)

        with tf.variable_scope("sub1"):
            # Floor division keeps the filter count an integer under
            # `from __future__ import division`.
            x = self._conv("conv1", x, 1, in_filter, out_filter // 4, stride)

        with tf.variable_scope("sub2"):
            x = self._batch_norm("bn2", x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv("conv2", x, 3, out_filter // 4, out_filter // 4,
                           [1, 1, 1, 1])

        with tf.variable_scope("sub3"):
            x = self._batch_norm("bn3", x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv("conv3", x, 1, out_filter // 4, out_filter,
                           [1, 1, 1, 1])

        with tf.variable_scope("sub_add"):
            if in_filter != out_filter:
                orig_x = self._conv("project", orig_x, 1, in_filter,
                                    out_filter, stride)
            x += orig_x

        return x

    def _decay(self):
        """L2 weight decay loss."""
        costs = []
        for var in tf.trainable_variables():
            if var.op.name.find(r"DW") > 0:
                costs.append(tf.nn.l2_loss(var))

        return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))

    def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
        """Convolution."""
        with tf.variable_scope(name):
            n = filter_size * filter_size * out_filters
            # He initialization: stddev = sqrt(2 / n).
            kernel = tf.get_variable(
                "DW", [filter_size, filter_size, in_filters, out_filters],
                tf.float32,
                initializer=tf.random_normal_initializer(
                    stddev=np.sqrt(2.0 / n)))
            return tf.nn.conv2d(x, kernel, strides, padding="SAME")

    def _relu(self, x, leakiness=0.0):
        """Relu, with optional leaky support."""
        return tf.where(tf.less(x, 0.0), leakiness * x, x, name="leaky_relu")

    def _fully_connected(self, x, out_dim):
        """FullyConnected layer for final output."""
        x = tf.reshape(x, [self.hps.batch_size, -1])
        w = tf.get_variable(
            "DW", [x.get_shape()[1], out_dim],
            initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
        b = tf.get_variable(
            "biases", [out_dim], initializer=tf.constant_initializer())
        return tf.nn.xw_plus_b(x, w, b)

    def _global_avg_pool(self, x):
        """Global average pooling over the spatial dimensions."""
        assert x.get_shape().ndims == 4
        return tf.reduce_mean(x, [1, 2])
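
# A minimal usage sketch (illustrative only). It assumes TF1 graph mode and
# placeholder inputs with the shapes documented in ResNet.__init__; the real
# input pipeline and HParams live in the training script:
#
#     images = tf.placeholder(tf.float32, [128, 32, 32, 3])
#     labels = tf.placeholder(tf.float32, [128, 10])
#     model = ResNet(hps, images, labels, "train")
#     model.build_graph()
#     # model.train_op, model.cost, model.lrn_rate and model.variables are
#     # now available to drive a tf.Session-based training loop.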