mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Fix bug in tensorflow tests. (#218)
* Fix bug in tensorflow tests. * Address comment.
This commit is contained in:
parent
9bb8162621
commit
7151ed5cdf
1 changed files with 25 additions and 42 deletions
|
@ -86,7 +86,7 @@ class TensorFlowTest(unittest.TestCase):
|
|||
|
||||
ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
|
||||
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
|
||||
|
||||
|
||||
net_vars1, init1, sess1 = ray.env.net1
|
||||
net_vars2, init2, sess2 = ray.env.net2
|
||||
|
||||
|
@ -108,7 +108,7 @@ class TensorFlowTest(unittest.TestCase):
|
|||
|
||||
ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
|
||||
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
|
||||
|
||||
|
||||
net_vars1, init1, sess1 = ray.env.net1
|
||||
net_vars2, init2, sess2 = ray.env.net2
|
||||
|
||||
|
@ -117,41 +117,32 @@ class TensorFlowTest(unittest.TestCase):
|
|||
sess2.run(init2)
|
||||
|
||||
@ray.remote
|
||||
def get_vars1():
|
||||
return ray.env.net1[0].get_weights()
|
||||
def set_and_get_weights(weights1, weights2):
|
||||
ray.env.net1[0].set_weights(weights1)
|
||||
ray.env.net2[0].set_weights(weights2)
|
||||
return ray.env.net1[0].get_weights(), ray.env.net2[0].get_weights()
|
||||
|
||||
@ray.remote
|
||||
def get_vars2():
|
||||
return ray.env.net2[0].get_weights()
|
||||
|
||||
@ray.remote
|
||||
def set_vars1(weights):
|
||||
ray.env.net1[0].set_weights(weights)
|
||||
|
||||
@ray.remote
|
||||
def set_vars2(weights):
|
||||
ray.env.net2[0].set_weights(weights)
|
||||
|
||||
# Get the weights.
|
||||
# Make sure the two networks have different weights. TODO(rkn): Note that
|
||||
# equality comparisons of numpy arrays normally does not work. This only
|
||||
# works because at the moment they have size 1.
|
||||
weights1 = net_vars1.get_weights()
|
||||
weights2 = net_vars2.get_weights()
|
||||
self.assertNotEqual(weights1, weights2)
|
||||
|
||||
# Set the weights and get the weights, and make sure they are unchanged.
|
||||
new_weights1, new_weights2 = ray.get(set_and_get_weights.remote(weights1, weights2))
|
||||
self.assertEqual(weights1, new_weights1)
|
||||
self.assertEqual(weights2, new_weights2)
|
||||
|
||||
# Swap the weights.
|
||||
set_vars2.remote(weights1)
|
||||
set_vars1.remote(weights2)
|
||||
|
||||
# Get the new weights.
|
||||
new_weights1 = ray.get(get_vars1.remote())
|
||||
new_weights2 = ray.get(get_vars2.remote())
|
||||
self.assertNotEqual(new_weights1, new_weights2)
|
||||
|
||||
# Check that the weights were swapped.
|
||||
self.assertEqual(weights1, new_weights2)
|
||||
self.assertEqual(weights2, new_weights1)
|
||||
new_weights2, new_weights1 = ray.get(set_and_get_weights.remote(weights2, weights1))
|
||||
self.assertEqual(weights1, new_weights1)
|
||||
self.assertEqual(weights2, new_weights2)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
# This test creates an additional network on the driver so that the tensorflow
|
||||
# variables on the driver and the worker differ.
|
||||
def testNetworkDriverWorkerIndependent(self):
|
||||
ray.init(num_workers=1)
|
||||
|
||||
|
@ -167,23 +158,15 @@ class TensorFlowTest(unittest.TestCase):
|
|||
net_vars2, init2, sess2 = ray.env.net
|
||||
sess2.run(init2)
|
||||
|
||||
# Get the weights.
|
||||
weights1 = net_vars1.get_weights()
|
||||
weights2 = net_vars2.get_weights()
|
||||
self.assertNotEqual(weights1, weights2)
|
||||
|
||||
# Swap the weights.
|
||||
net_vars1.set_weights(weights2)
|
||||
net_vars2.set_weights(weights1)
|
||||
@ray.remote
|
||||
def set_and_get_weights(weights):
|
||||
ray.env.net[0].set_weights(weights)
|
||||
return ray.env.net[0].get_weights()
|
||||
|
||||
# Get the new weights.
|
||||
new_weights1 = net_vars1.get_weights()
|
||||
new_weights2 = net_vars2.get_weights()
|
||||
self.assertNotEqual(new_weights1, new_weights2)
|
||||
|
||||
# Check that the weights were swapped.
|
||||
self.assertEqual(weights1, new_weights2)
|
||||
self.assertEqual(weights2, new_weights1)
|
||||
new_weights2 = ray.get(set_and_get_weights.remote(net_vars2.get_weights()))
|
||||
self.assertEqual(weights2, new_weights2)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue