Fix MNIST downloading problems in parameter server examples. (#1457)

* Fix MNIST downloading problems in parameter server examples.

* Improve seeding.

* Fixes.
This commit is contained in:
Robert Nishihara 2018-01-25 14:14:37 -08:00 committed by Richard Liaw
parent 0a01d3c71f
commit e96acc26f7
3 changed files with 18 additions and 11 deletions

View file

@ -3,11 +3,9 @@ from __future__ import division
from __future__ import print_function
import argparse
from tensorflow.examples.tutorials.mnist import input_data
import time
import ray
import model
parser = argparse.ArgumentParser(description="Run the asynchronous parameter "
@ -35,9 +33,9 @@ class ParameterServer(object):
@ray.remote
def worker_task(ps, batch_size=50):
def worker_task(ps, worker_index, batch_size=50):
# Download MNIST.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
mnist = model.download_mnist_retry(seed=worker_index)
# Initialize the model.
net = model.SimpleCNN()
@ -65,10 +63,10 @@ if __name__ == "__main__":
ps = ParameterServer.remote(all_keys, all_values)
# Start some training tasks.
worker_tasks = [worker_task.remote(ps) for _ in range(args.num_workers)]
worker_tasks = [worker_task.remote(ps, i) for i in range(args.num_workers)]
# Download MNIST.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
mnist = model.download_mnist_retry()
i = 0
while True:

View file

@ -8,6 +8,18 @@ from __future__ import print_function
import ray
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
def download_mnist_retry(seed=0, max_num_retries=20):
for _ in range(max_num_retries):
try:
return input_data.read_data_sets("MNIST_data", one_hot=True,
seed=seed)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
class SimpleCNN(object):

View file

@ -3,9 +3,7 @@ from __future__ import division
from __future__ import print_function
import argparse
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import ray
import model
@ -36,8 +34,7 @@ class Worker(object):
def __init__(self, worker_index, batch_size=50):
self.worker_index = worker_index
self.batch_size = batch_size
self.mnist = input_data.read_data_sets("MNIST_data", one_hot=True,
seed=worker_index)
self.mnist = model.download_mnist_retry(seed=worker_index)
self.net = model.SimpleCNN()
def compute_gradients(self, weights):
@ -60,7 +57,7 @@ if __name__ == "__main__":
for worker_index in range(args.num_workers)]
# Download MNIST.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
mnist = model.download_mnist_retry()
i = 0
current_weights = ps.get_weights.remote()