mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* Fix MNIST downloading problems in parameter server examples. * Improve seeding. * Fixes.
76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import numpy as np
|
|
|
|
import ray
|
|
import model
|
|
|
|
def _build_parser():
    """Create the command-line parser for the synchronous parameter server.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--num-workers`` (int,
        default 4) and ``--redis-address`` (str, default None).
    """
    cli = argparse.ArgumentParser(
        description="Run the synchronous parameter server example.")
    cli.add_argument("--num-workers", type=int, default=4,
                     help="The number of workers to use.")
    cli.add_argument("--redis-address", type=str, default=None,
                     help="The Redis address of the cluster.")
    return cli


# Module-level parser consumed by the ``__main__`` section of this script.
parser = _build_parser()
|
|
|
|
|
|
@ray.remote
class ParameterServer(object):
    """Ray actor holding the authoritative copy of the model weights.

    Workers send gradients, the server averages them, applies the update to
    its own ``model.SimpleCNN`` and hands back the new flat weights.
    """

    def __init__(self, learning_rate):
        """Build the server-side model.

        Args:
            learning_rate: Passed straight through to ``model.SimpleCNN``.
        """
        self.net = model.SimpleCNN(learning_rate=learning_rate)

    def apply_gradients(self, *gradients):
        """Average the workers' gradients, apply them and return new weights.

        Args:
            *gradients: One gradient per worker, each as returned by
                ``SimpleCNN.compute_gradients``. Presumably each is a
                sequence holding one ndarray per model variable
                (TODO confirm against model.py).

        Returns:
            The updated weights, from ``self.net.variables.get_flat()``.
        """
        # Average variable-by-variable rather than calling
        # np.mean(gradients, axis=0) on the whole nested structure. When the
        # per-variable arrays have different shapes, that call has to build a
        # ragged array, which NumPy >= 1.24 rejects with ValueError (older
        # versions only warn and fall back to an object array).
        averaged = [np.mean(per_variable, axis=0)
                    for per_variable in zip(*gradients)]
        self.net.apply_gradients(averaged)
        return self.net.variables.get_flat()

    def get_weights(self):
        """Return the current weights from ``self.net.variables.get_flat()``."""
        return self.net.variables.get_flat()
|
|
|
|
|
|
@ray.remote
class Worker(object):
    """Ray actor that computes gradients on its own stream of MNIST batches."""

    def __init__(self, worker_index, batch_size=50):
        """Load this worker's copy of MNIST and build a local model.

        Args:
            worker_index: Stored on the actor, and passed as ``seed`` to
                ``model.download_mnist_retry``.
            batch_size: Number of training examples drawn per gradient call.
        """
        self.worker_index = worker_index
        self.batch_size = batch_size
        # Dataset first, network second, matching the original order; kept
        # this way in case the seeded download influences later randomness.
        self.mnist = model.download_mnist_retry(seed=worker_index)
        self.net = model.SimpleCNN()

    def compute_gradients(self, weights):
        """Load ``weights`` into the local model and return gradients on one batch.

        Args:
            weights: Flat weights, as produced by the parameter server's
                ``variables.get_flat()``.

        Returns:
            Whatever ``self.net.compute_gradients`` returns for the batch.
        """
        self.net.variables.set_flat(weights)
        images, labels = self.mnist.train.next_batch(self.batch_size)
        return self.net.compute_gradients(images, labels)
|
|
|
|
|
|
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(redis_address=args.redis_address)

    # Driver-side copy of the network, used only to evaluate the weights
    # produced by the parameter server.
    net = model.SimpleCNN()
    # The server's learning rate is scaled by the number of workers.
    ps = ParameterServer.remote(1e-4 * args.num_workers)

    # One gradient-computing actor per requested worker.
    workers = [Worker.remote(worker_index)
               for worker_index in range(args.num_workers)]

    # Dataset used on the driver for the periodic accuracy checks.
    mnist = model.download_mnist_retry()

    iteration = 0
    current_weights = ps.get_weights.remote()
    while True:
        # Every worker computes gradients against the same weights; the
        # server averages them and yields the next weights as an object ref.
        gradients = [w.compute_gradients.remote(current_weights)
                     for w in workers]
        current_weights = ps.apply_gradients.remote(*gradients)

        if not iteration % 10:
            # Pull the latest weights into the local net and score a test batch.
            net.variables.set_flat(ray.get(current_weights))
            test_xs, test_ys = mnist.test.next_batch(1000)
            accuracy = net.compute_accuracy(test_xs, test_ys)
            print("Iteration {}: accuracy is {}".format(iteration, accuracy))

        iteration += 1
|