Correctly setting the input to Train (#3853)

In the ResNetTrainActor class, the data are now exactly build using the Train flag for the cifar_input script.
This commit is contained in:
LorenzoCevolani 2019-07-27 20:08:35 +02:00 committed by Richard Liaw
parent a62c5f40f6
commit 10cbcced7e

View file

@ -107,7 +107,7 @@ class ResNetTrainActor(object):
with tf.device("/gpu:0" if num_gpus > 0 else "/cpu:0"):
# Build the model.
images, labels = cifar_input.build_input(data, hps.batch_size,
dataset, False)
dataset, True)
self.model = resnet_model.ResNet(hps, images, labels, "train")
self.model.build_graph()
config = tf.ConfigProto(allow_soft_placement=True)