mirror of
https://github.com/vale981/ray
synced 2025-04-23 06:25:52 -04:00
Fix MNIST downloading problems in parameter server examples. (#1457)
* Fix MNIST downloading problems in parameter server examples.
* Improve seeding.
* Fixes.
This commit is contained in:
parent
0a01d3c71f
commit
e96acc26f7
3 changed files with 18 additions and 11 deletions
|
@@ -3,11 +3,9 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
import time
|
||||
|
||||
import ray
|
||||
|
||||
import model
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the asynchronous parameter "
|
||||
|
@@ -35,9 +33,9 @@ class ParameterServer(object):
|
|||
|
||||
|
||||
@ray.remote
|
||||
def worker_task(ps, batch_size=50):
|
||||
def worker_task(ps, worker_index, batch_size=50):
|
||||
# Download MNIST.
|
||||
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
|
||||
mnist = model.download_mnist_retry(seed=worker_index)
|
||||
|
||||
# Initialize the model.
|
||||
net = model.SimpleCNN()
|
||||
|
@@ -65,10 +63,10 @@ if __name__ == "__main__":
|
|||
ps = ParameterServer.remote(all_keys, all_values)
|
||||
|
||||
# Start some training tasks.
|
||||
worker_tasks = [worker_task.remote(ps) for _ in range(args.num_workers)]
|
||||
worker_tasks = [worker_task.remote(ps, i) for i in range(args.num_workers)]
|
||||
|
||||
# Download MNIST.
|
||||
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
|
||||
mnist = model.download_mnist_retry()
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
|
|
|
@@ -8,6 +8,18 @@ from __future__ import print_function
|
|||
|
||||
import ray
|
||||
import tensorflow as tf
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
import time
|
||||
|
||||
|
||||
def download_mnist_retry(seed=0, max_num_retries=20):
|
||||
for _ in range(max_num_retries):
|
||||
try:
|
||||
return input_data.read_data_sets("MNIST_data", one_hot=True,
|
||||
seed=seed)
|
||||
except tf.errors.AlreadyExistsError:
|
||||
time.sleep(1)
|
||||
raise Exception("Failed to download MNIST.")
|
||||
|
||||
|
||||
class SimpleCNN(object):
|
||||
|
|
|
@@ -3,9 +3,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
||||
import ray
|
||||
import model
|
||||
|
@@ -36,8 +34,7 @@ class Worker(object):
|
|||
def __init__(self, worker_index, batch_size=50):
|
||||
self.worker_index = worker_index
|
||||
self.batch_size = batch_size
|
||||
self.mnist = input_data.read_data_sets("MNIST_data", one_hot=True,
|
||||
seed=worker_index)
|
||||
self.mnist = model.download_mnist_retry(seed=worker_index)
|
||||
self.net = model.SimpleCNN()
|
||||
|
||||
def compute_gradients(self, weights):
|
||||
|
@@ -60,7 +57,7 @@ if __name__ == "__main__":
|
|||
for worker_index in range(args.num_workers)]
|
||||
|
||||
# Download MNIST.
|
||||
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
|
||||
mnist = model.download_mnist_retry()
|
||||
|
||||
i = 0
|
||||
current_weights = ps.get_weights.remote()
|
||||
|
|
Loading…
Add table
Reference in a new issue