diff --git a/test/tensorflow_test.py b/test/tensorflow_test.py index 1c0a43caa..62130ac1d 100644 --- a/test/tensorflow_test.py +++ b/test/tensorflow_test.py @@ -86,7 +86,7 @@ class TensorFlowTest(unittest.TestCase): ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer) ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer) - + net_vars1, init1, sess1 = ray.env.net1 net_vars2, init2, sess2 = ray.env.net2 @@ -108,7 +108,7 @@ class TensorFlowTest(unittest.TestCase): ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer) ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer) - + net_vars1, init1, sess1 = ray.env.net1 net_vars2, init2, sess2 = ray.env.net2 @@ -117,41 +117,32 @@ class TensorFlowTest(unittest.TestCase): sess2.run(init2) @ray.remote - def get_vars1(): - return ray.env.net1[0].get_weights() + def set_and_get_weights(weights1, weights2): + ray.env.net1[0].set_weights(weights1) + ray.env.net2[0].set_weights(weights2) + return ray.env.net1[0].get_weights(), ray.env.net2[0].get_weights() - @ray.remote - def get_vars2(): - return ray.env.net2[0].get_weights() - - @ray.remote - def set_vars1(weights): - ray.env.net1[0].set_weights(weights) - - @ray.remote - def set_vars2(weights): - ray.env.net2[0].set_weights(weights) - - # Get the weights. + # Make sure the two networks have different weights. TODO(rkn): Note that + # equality comparisons of numpy arrays normally does not work. This only + # works because at the moment they have size 1. weights1 = net_vars1.get_weights() weights2 = net_vars2.get_weights() self.assertNotEqual(weights1, weights2) + # Set the weights and get the weights, and make sure they are unchanged. + new_weights1, new_weights2 = ray.get(set_and_get_weights.remote(weights1, weights2)) + self.assertEqual(weights1, new_weights1) + self.assertEqual(weights2, new_weights2) + # Swap the weights. - set_vars2.remote(weights1) - set_vars1.remote(weights2) - - # Get the new weights. - new_weights1 = ray.get(get_vars1.remote()) - new_weights2 = ray.get(get_vars2.remote()) - self.assertNotEqual(new_weights1, new_weights2) - - # Check that the weights were swapped. - self.assertEqual(weights1, new_weights2) - self.assertEqual(weights2, new_weights1) + new_weights2, new_weights1 = ray.get(set_and_get_weights.remote(weights2, weights1)) + self.assertEqual(weights1, new_weights1) + self.assertEqual(weights2, new_weights2) ray.worker.cleanup() + # This test creates an additional network on the driver so that the tensorflow + # variables on the driver and the worker differ. def testNetworkDriverWorkerIndependent(self): ray.init(num_workers=1) @@ -167,23 +158,15 @@ class TensorFlowTest(unittest.TestCase): net_vars2, init2, sess2 = ray.env.net sess2.run(init2) - # Get the weights. - weights1 = net_vars1.get_weights() weights2 = net_vars2.get_weights() - self.assertNotEqual(weights1, weights2) - # Swap the weights. - net_vars1.set_weights(weights2) - net_vars2.set_weights(weights1) + @ray.remote + def set_and_get_weights(weights): + ray.env.net[0].set_weights(weights) + return ray.env.net[0].get_weights() - # Get the new weights. - new_weights1 = net_vars1.get_weights() - new_weights2 = net_vars2.get_weights() - self.assertNotEqual(new_weights1, new_weights2) - - # Check that the weights were swapped. - self.assertEqual(weights1, new_weights2) - self.assertEqual(weights2, new_weights1) + new_weights2 = ray.get(set_and_get_weights.remote(net_vars2.get_weights())) + self.assertEqual(weights2, new_weights2) ray.worker.cleanup()