Fix bug in tensorflow tests. (#218)

* Fix bug in tensorflow tests.

* Address comment.
This commit is contained in:
Robert Nishihara 2017-01-19 20:29:05 -08:00 committed by Philipp Moritz
parent 9bb8162621
commit 7151ed5cdf

View file

@ -86,7 +86,7 @@ class TensorFlowTest(unittest.TestCase):
ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer) ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer) ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
net_vars1, init1, sess1 = ray.env.net1 net_vars1, init1, sess1 = ray.env.net1
net_vars2, init2, sess2 = ray.env.net2 net_vars2, init2, sess2 = ray.env.net2
@ -108,7 +108,7 @@ class TensorFlowTest(unittest.TestCase):
ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer) ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer) ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
net_vars1, init1, sess1 = ray.env.net1 net_vars1, init1, sess1 = ray.env.net1
net_vars2, init2, sess2 = ray.env.net2 net_vars2, init2, sess2 = ray.env.net2
@ -117,41 +117,32 @@ class TensorFlowTest(unittest.TestCase):
sess2.run(init2) sess2.run(init2)
@ray.remote @ray.remote
def get_vars1(): def set_and_get_weights(weights1, weights2):
return ray.env.net1[0].get_weights() ray.env.net1[0].set_weights(weights1)
ray.env.net2[0].set_weights(weights2)
return ray.env.net1[0].get_weights(), ray.env.net2[0].get_weights()
@ray.remote # Make sure the two networks have different weights. TODO(rkn): Note that
def get_vars2(): # equality comparisons of numpy arrays normally does not work. This only
return ray.env.net2[0].get_weights() # works because at the moment they have size 1.
@ray.remote
def set_vars1(weights):
ray.env.net1[0].set_weights(weights)
@ray.remote
def set_vars2(weights):
ray.env.net2[0].set_weights(weights)
# Get the weights.
weights1 = net_vars1.get_weights() weights1 = net_vars1.get_weights()
weights2 = net_vars2.get_weights() weights2 = net_vars2.get_weights()
self.assertNotEqual(weights1, weights2) self.assertNotEqual(weights1, weights2)
# Set the weights and get the weights, and make sure they are unchanged.
new_weights1, new_weights2 = ray.get(set_and_get_weights.remote(weights1, weights2))
self.assertEqual(weights1, new_weights1)
self.assertEqual(weights2, new_weights2)
# Swap the weights. # Swap the weights.
set_vars2.remote(weights1) new_weights2, new_weights1 = ray.get(set_and_get_weights.remote(weights2, weights1))
set_vars1.remote(weights2) self.assertEqual(weights1, new_weights1)
self.assertEqual(weights2, new_weights2)
# Get the new weights.
new_weights1 = ray.get(get_vars1.remote())
new_weights2 = ray.get(get_vars2.remote())
self.assertNotEqual(new_weights1, new_weights2)
# Check that the weights were swapped.
self.assertEqual(weights1, new_weights2)
self.assertEqual(weights2, new_weights1)
ray.worker.cleanup() ray.worker.cleanup()
# This test creates an additional network on the driver so that the tensorflow
# variables on the driver and the worker differ.
def testNetworkDriverWorkerIndependent(self): def testNetworkDriverWorkerIndependent(self):
ray.init(num_workers=1) ray.init(num_workers=1)
@ -167,23 +158,15 @@ class TensorFlowTest(unittest.TestCase):
net_vars2, init2, sess2 = ray.env.net net_vars2, init2, sess2 = ray.env.net
sess2.run(init2) sess2.run(init2)
# Get the weights.
weights1 = net_vars1.get_weights()
weights2 = net_vars2.get_weights() weights2 = net_vars2.get_weights()
self.assertNotEqual(weights1, weights2)
# Swap the weights. @ray.remote
net_vars1.set_weights(weights2) def set_and_get_weights(weights):
net_vars2.set_weights(weights1) ray.env.net[0].set_weights(weights)
return ray.env.net[0].get_weights()
# Get the new weights. new_weights2 = ray.get(set_and_get_weights.remote(net_vars2.get_weights()))
new_weights1 = net_vars1.get_weights() self.assertEqual(weights2, new_weights2)
new_weights2 = net_vars2.get_weights()
self.assertNotEqual(new_weights1, new_weights2)
# Check that the weights were swapped.
self.assertEqual(weights1, new_weights2)
self.assertEqual(weights2, new_weights1)
ray.worker.cleanup() ray.worker.cleanup()