Fix bug in tensorflow tests. (#218)

* Fix bug in tensorflow tests.

* Address comment.
This commit is contained in:
Robert Nishihara 2017-01-19 20:29:05 -08:00 committed by Philipp Moritz
parent 9bb8162621
commit 7151ed5cdf

View file

@ -86,7 +86,7 @@ class TensorFlowTest(unittest.TestCase):
ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
net_vars1, init1, sess1 = ray.env.net1
net_vars2, init2, sess2 = ray.env.net2
@ -108,7 +108,7 @@ class TensorFlowTest(unittest.TestCase):
ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
net_vars1, init1, sess1 = ray.env.net1
net_vars2, init2, sess2 = ray.env.net2
@ -117,41 +117,32 @@ class TensorFlowTest(unittest.TestCase):
sess2.run(init2)
@ray.remote
def get_vars1():
return ray.env.net1[0].get_weights()
def set_and_get_weights(weights1, weights2):
ray.env.net1[0].set_weights(weights1)
ray.env.net2[0].set_weights(weights2)
return ray.env.net1[0].get_weights(), ray.env.net2[0].get_weights()
@ray.remote
def get_vars2():
return ray.env.net2[0].get_weights()
@ray.remote
def set_vars1(weights):
ray.env.net1[0].set_weights(weights)
@ray.remote
def set_vars2(weights):
ray.env.net2[0].set_weights(weights)
# Get the weights.
# Make sure the two networks have different weights. TODO(rkn): Note that
# equality comparisons of numpy arrays normally does not work. This only
# works because at the moment they have size 1.
weights1 = net_vars1.get_weights()
weights2 = net_vars2.get_weights()
self.assertNotEqual(weights1, weights2)
# Set the weights and get the weights, and make sure they are unchanged.
new_weights1, new_weights2 = ray.get(set_and_get_weights.remote(weights1, weights2))
self.assertEqual(weights1, new_weights1)
self.assertEqual(weights2, new_weights2)
# Swap the weights.
set_vars2.remote(weights1)
set_vars1.remote(weights2)
# Get the new weights.
new_weights1 = ray.get(get_vars1.remote())
new_weights2 = ray.get(get_vars2.remote())
self.assertNotEqual(new_weights1, new_weights2)
# Check that the weights were swapped.
self.assertEqual(weights1, new_weights2)
self.assertEqual(weights2, new_weights1)
new_weights2, new_weights1 = ray.get(set_and_get_weights.remote(weights2, weights1))
self.assertEqual(weights1, new_weights1)
self.assertEqual(weights2, new_weights2)
ray.worker.cleanup()
# This test creates an additional network on the driver so that the tensorflow
# variables on the driver and the worker differ.
def testNetworkDriverWorkerIndependent(self):
ray.init(num_workers=1)
@ -167,23 +158,15 @@ class TensorFlowTest(unittest.TestCase):
net_vars2, init2, sess2 = ray.env.net
sess2.run(init2)
# Get the weights.
weights1 = net_vars1.get_weights()
weights2 = net_vars2.get_weights()
self.assertNotEqual(weights1, weights2)
# Swap the weights.
net_vars1.set_weights(weights2)
net_vars2.set_weights(weights1)
@ray.remote
def set_and_get_weights(weights):
ray.env.net[0].set_weights(weights)
return ray.env.net[0].get_weights()
# Get the new weights.
new_weights1 = net_vars1.get_weights()
new_weights2 = net_vars2.get_weights()
self.assertNotEqual(new_weights1, new_weights2)
# Check that the weights were swapped.
self.assertEqual(weights1, new_weights2)
self.assertEqual(weights2, new_weights1)
new_weights2 = ray.get(set_and_get_weights.remote(net_vars2.get_weights()))
self.assertEqual(weights2, new_weights2)
ray.worker.cleanup()