simplify code now that pull holds a reference (#146)

This commit is contained in:
Robert Nishihara 2016-06-22 11:24:59 -07:00 committed by Philipp Moritz
parent aa24907f94
commit fc5c40fb95

View file

@ -75,34 +75,25 @@ class ArraysDistTest(unittest.TestCase):
services.start_singlenode_cluster(return_drivers=False, num_objstores=2, num_workers_per_objstore=5, worker_path=worker_path) services.start_singlenode_cluster(return_drivers=False, num_objstores=2, num_workers_per_objstore=5, worker_path=worker_path)
x = da.zeros([9, 25, 51], "float") x = da.zeros([9, 25, 51], "float")
y = da.assemble(x) self.assertTrue(np.alltrue(ray.pull(da.assemble(x)) == np.zeros([9, 25, 51])))
self.assertTrue(np.alltrue(ray.pull(y) == np.zeros([9, 25, 51])))
x = da.ones([11, 25, 49], dtype_name="float") x = da.ones([11, 25, 49], dtype_name="float")
y = da.assemble(x) self.assertTrue(np.alltrue(ray.pull(da.assemble(x)) == np.ones([11, 25, 49])))
self.assertTrue(np.alltrue(ray.pull(y) == np.ones([11, 25, 49])))
x = da.random.normal([11, 25, 49]) x = da.random.normal([11, 25, 49])
y = da.copy(x) y = da.copy(x)
z = da.assemble(x) self.assertTrue(np.alltrue(ray.pull(da.assemble(x)) == ray.pull(da.assemble(y))))
w = da.assemble(y)
self.assertTrue(np.alltrue(ray.pull(z) == ray.pull(w)))
x = da.eye(25, dtype_name="float") x = da.eye(25, dtype_name="float")
y = da.assemble(x) self.assertTrue(np.alltrue(ray.pull(da.assemble(x)) == np.eye(25)))
self.assertTrue(np.alltrue(ray.pull(y) == np.eye(25)))
x = da.random.normal([25, 49]) x = da.random.normal([25, 49])
y = da.triu(x) y = da.triu(x)
z = da.assemble(y) self.assertTrue(np.alltrue(ray.pull(da.assemble(y)) == np.triu(ray.pull(da.assemble(x)))))
w = da.assemble(x)
self.assertTrue(np.alltrue(ray.pull(z) == np.triu(ray.pull(w))))
x = da.random.normal([25, 49]) x = da.random.normal([25, 49])
y = da.tril(x) y = da.tril(x)
z = da.assemble(y) self.assertTrue(np.alltrue(ray.pull(da.assemble(y)) == np.tril(ray.pull(da.assemble(x)))))
w = da.assemble(x)
self.assertTrue(np.alltrue(ray.pull(z) == np.tril(ray.pull(w))))
x = da.random.normal([25, 49]) x = da.random.normal([25, 49])
y = da.random.normal([49, 18]) y = da.random.normal([49, 18])
@ -117,35 +108,25 @@ class ArraysDistTest(unittest.TestCase):
x = da.random.normal([23, 42]) x = da.random.normal([23, 42])
y = da.random.normal([23, 42]) y = da.random.normal([23, 42])
z = da.add(x, y) z = da.add(x, y)
z_full = da.assemble(z) self.assertTrue(np.allclose(ray.pull(da.assemble(z)), ray.pull(da.assemble(x)) + ray.pull(da.assemble(y))))
x_full = da.assemble(x)
y_full = da.assemble(y)
self.assertTrue(np.allclose(ray.pull(z_full), ray.pull(x_full) + ray.pull(y_full)))
# test subtract # test subtract
x = da.random.normal([33, 40]) x = da.random.normal([33, 40])
y = da.random.normal([33, 40]) y = da.random.normal([33, 40])
z = da.subtract(x, y) z = da.subtract(x, y)
z_full = da.assemble(z) self.assertTrue(np.allclose(ray.pull(da.assemble(z)), ray.pull(da.assemble(x)) - ray.pull(da.assemble(y))))
x_full = da.assemble(x)
y_full = da.assemble(y)
self.assertTrue(np.allclose(ray.pull(z_full), ray.pull(x_full) - ray.pull(y_full)))
# test transpose # test transpose
x = da.random.normal([234, 432]) x = da.random.normal([234, 432])
y = da.transpose(x) y = da.transpose(x)
x_full = da.assemble(x) self.assertTrue(np.alltrue(ray.pull(da.assemble(x)).T == ray.pull(da.assemble(y))))
y_full = da.assemble(y)
self.assertTrue(np.alltrue(ray.pull(x_full).T == ray.pull(y_full)))
# test numpy_to_dist # test numpy_to_dist
x = da.random.normal([23, 45]) x = da.random.normal([23, 45])
y = da.assemble(x) y = da.assemble(x)
z = da.numpy_to_dist(y) z = da.numpy_to_dist(y)
w = da.assemble(z) w = da.assemble(z)
x_full = da.assemble(x) self.assertTrue(np.alltrue(ray.pull(da.assemble(x)) == ray.pull(da.assemble(z))))
z_full = da.assemble(z)
self.assertTrue(np.alltrue(ray.pull(x_full) == ray.pull(z_full)))
self.assertTrue(np.alltrue(ray.pull(y) == ray.pull(w))) self.assertTrue(np.alltrue(ray.pull(y) == ray.pull(w)))
# test da.tsqr # test da.tsqr
@ -153,10 +134,8 @@ class ArraysDistTest(unittest.TestCase):
x = da.random.normal(shape) x = da.random.normal(shape)
K = min(shape) K = min(shape)
q, r = da.linalg.tsqr(x) q, r = da.linalg.tsqr(x)
x_full = da.assemble(x) x_val = ray.pull(da.assemble(x))
x_val = ray.pull(x_full) q_val = ray.pull(da.assemble(q))
q_full = da.assemble(q)
q_val = ray.pull(q_full)
r_val = ray.pull(r) r_val = ray.pull(r)
self.assertTrue(r_val.shape == (K, shape[1])) self.assertTrue(r_val.shape == (K, shape[1]))
self.assertTrue(np.alltrue(r_val == np.triu(r_val))) self.assertTrue(np.alltrue(r_val == np.triu(r_val)))
@ -173,8 +152,7 @@ class ArraysDistTest(unittest.TestCase):
l, u, s = da.linalg.modified_lu(da.numpy_to_dist(q)) l, u, s = da.linalg.modified_lu(da.numpy_to_dist(q))
q_val = ray.pull(q) q_val = ray.pull(q)
r_val = ray.pull(r) r_val = ray.pull(r)
l_full = da.assemble(l) l_val = ray.pull(da.assemble(l))
l_val = ray.pull(l_full)
u_val = ray.pull(u) u_val = ray.pull(u)
s_val = ray.pull(s) s_val = ray.pull(s)
s_mat = np.zeros((d1, d2)) s_mat = np.zeros((d1, d2))
@ -192,10 +170,8 @@ class ArraysDistTest(unittest.TestCase):
print "testing dist_tsqr_hr with d1 = " + str(d1) + ", d2 = " + str(d2) print "testing dist_tsqr_hr with d1 = " + str(d1) + ", d2 = " + str(d2)
a = da.random.normal([d1, d2]) a = da.random.normal([d1, d2])
y, t, y_top, r = da.linalg.tsqr_hr(a) y, t, y_top, r = da.linalg.tsqr_hr(a)
a_full = da.assemble(a) a_val = ray.pull(da.assemble(a))
a_val = ray.pull(a_full) y_val = ray.pull(da.assemble(y))
y_full = da.assemble(y)
y_val = ray.pull(y_full)
t_val = ray.pull(t) t_val = ray.pull(t)
y_top_val = ray.pull(y_top) y_top_val = ray.pull(y_top)
r_val = ray.pull(r) r_val = ray.pull(r)
@ -213,13 +189,9 @@ class ArraysDistTest(unittest.TestCase):
a = da.random.normal([d1, d2]) a = da.random.normal([d1, d2])
K = min(d1, d2) K = min(d1, d2)
q, r = da.linalg.qr(a) q, r = da.linalg.qr(a)
a_full = da.assemble(a) a_val = ray.pull(da.assemble(a))
q_full = da.assemble(q) q_val = ray.pull(da.assemble(q))
r_full = da.assemble(r) r_val = ray.pull(da.assemble(r))
a_val = ray.pull(a_full)
q_val = ray.pull(q_full)
r_val = ray.pull(r_full)
self.assertTrue(q_val.shape == (d1, K)) self.assertTrue(q_val.shape == (d1, K))
self.assertTrue(r_val.shape == (K, d2)) self.assertTrue(r_val.shape == (K, d2))
self.assertTrue(np.allclose(np.dot(q_val.T, q_val), np.eye(K))) self.assertTrue(np.allclose(np.dot(q_val.T, q_val), np.eye(K)))