2016-02-22 13:55:06 -08:00
|
|
|
import unittest
|
2016-03-10 12:40:05 -08:00
|
|
|
import orchpy
|
2016-03-16 18:11:43 -07:00
|
|
|
import orchpy.serialization as serialization
|
2016-02-22 13:55:06 -08:00
|
|
|
import orchpy.services as services
|
|
|
|
import orchpy.worker as worker
|
|
|
|
import numpy as np
|
|
|
|
import time
|
2016-03-01 01:02:08 -08:00
|
|
|
import subprocess32 as subprocess
|
|
|
|
import os
|
|
|
|
|
|
|
|
from google.protobuf.text_format import *
|
2016-02-22 13:55:06 -08:00
|
|
|
|
|
|
|
import orchestra_pb2
|
|
|
|
import types_pb2
|
|
|
|
|
2016-04-18 13:05:36 -07:00
|
|
|
import test_functions
|
|
|
|
import arrays.single as single
|
|
|
|
import arrays.dist as dist
|
|
|
|
|
2016-03-10 12:35:31 -08:00
|
|
|
class SerializationTest(unittest.TestCase):
    """Round-trip tests for orchpy.serialization on a single-driver cluster."""

    def roundTripTest(self, worker, data):
        """Serialize `data` via `worker.handle`, deserialize it, and assert equality.

        Args:
            worker: a driver returned by services.start_cluster; only its
                `handle` attribute is used.
            data: the value to round-trip.
        """
        serialized, _ = serialization.serialize(worker.handle, data)
        result = serialization.deserialize(worker.handle, serialized)
        self.assertEqual(data, result)

    def _arrayRoundTripTest(self, worker, array):
        """Round-trip a numpy array and assert it compares equal element-wise."""
        serialized, _ = serialization.serialize(worker.handle, array)
        result = serialization.deserialize(worker.handle, serialized)
        self.assertTrue((array == result).all())

    def numpyTypeTest(self, worker, typ):
        """Round-trip a random 100x100 integer-valued array cast to dtype `typ`."""
        a = np.random.randint(0, 10, size=(100, 100)).astype(typ)
        self._arrayRoundTripTest(worker, a)

    def testSerialize(self):
        [w] = services.start_cluster(return_drivers=True)
        # Tear the cluster down even when an assertion below fails, so a
        # failure here does not leave a running cluster behind for later tests.
        self.addCleanup(services.cleanup)

        # Plain Python values.
        # NOTE(review): in the dict literal the keys 1 and 1.0 hash and compare
        # equal, so it collapses to {"hello": "world", 1: 45}; float keys are
        # not actually exercised here.
        python_values = [
            [1, "hello", 3.0],
            42,
            "hello world",
            42.0,
            (1.0, "hi"),
            {"hello": "world", 1: 42, 1.0: 45},
            {},
        ]
        for data in python_values:
            self.roundTripTest(w, data)

        # Numpy arrays of several dtypes.
        self._arrayRoundTripTest(w, np.zeros((100, 100)))
        for typ in ["int8", "uint8", "float32", "float64"]:
            self.numpyTypeTest(w, typ)
        # TODO(pcm): enable once serialization supports these dtypes:
        # self.numpyTypeTest(w, 'int16')
        # self.numpyTypeTest(w, 'int32')

        # A numpy object array whose elements are object references.
        refs = [orchpy.push(0, w) for _ in range(4)]
        a = np.array([[refs[0], refs[1]], [refs[2], refs[3]]])
        self._arrayRoundTripTest(w, a)
|
2016-03-12 15:25:45 -08:00
|
|
|
|
2016-02-22 13:55:06 -08:00
|
|
|
class ObjStoreTest(unittest.TestCase):
    """Object store tests: pushing objects and pulling them back to a driver."""

    # Test setting up object stores, transferring data between them and
    # retrieving data to a client.
    def testObjStore(self):
        [w1, w2] = services.start_cluster(return_drivers=True,
                                          num_objstores=2,
                                          num_workers_per_objstore=0)
        # Tear the cluster down even when an assertion below fails.
        self.addCleanup(services.cleanup)

        # Pushing and pulling an object shouldn't change it.
        for data in ["h", "h" * 10000, 0, 0.0]:
            objref = orchpy.push(data, w1)
            result = orchpy.pull(objref, w1)
            self.assertEqual(result, data)

        # TODO: re-enable the cross-object-store check. Pushing an object,
        # shipping it to another object store, and pulling it there shouldn't
        # change it. The disabled code below targets an older API; the names
        # worker1, worker2, objstore1_stub, address, IP_ADDRESS and
        # TIMEOUT_SECONDS are not defined in this file. Until then w2 is unused.
        # for data in ["h", "h" * 10000, 0, 0.0]:
        #   objref = worker.push(data, worker1)
        #   response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref.val, objstore_address=address(IP_ADDRESS, objstore2_port)), TIMEOUT_SECONDS)
        #   result = worker.pull(objref, worker2)
        #   self.assertEqual(result, data)
|
2016-02-22 16:06:16 -08:00
|
|
|
|
2016-02-22 13:55:06 -08:00
|
|
|
class SchedulerTest(unittest.TestCase):
    """Scheduler tests: dispatching a remote call to a worker."""

    def testCall(self):
        """A remote call to test_functions.print_string returns its argument."""
        test_dir = os.path.dirname(os.path.abspath(__file__))
        test_path = os.path.join(test_dir, "testrecv.py")
        [w] = services.start_cluster(return_drivers=True,
                                     num_workers_per_objstore=1,
                                     worker_path=test_path)
        # Tear the cluster down even when an assertion below fails.
        self.addCleanup(services.cleanup)

        value_before = "test_string"
        objref = w.remote_call("test_functions.print_string", [value_before])

        # NOTE(review): presumably gives the worker time to run the task
        # before pulling; confirm whether orchpy.pull already blocks.
        time.sleep(0.2)

        value_after = orchpy.pull(objref[0], w)
        self.assertEqual(value_before, value_after)

        time.sleep(0.1)
|
2016-02-22 13:55:06 -08:00
|
|
|
|
2016-03-10 14:40:46 -08:00
|
|
|
class WorkerTest(unittest.TestCase):
    """Worker tests: values survive a push/pull round trip unchanged."""

    def testPushPull(self):
        [w] = services.start_cluster(return_drivers=True)
        # Tear the cluster down even when an assertion below fails.
        self.addCleanup(services.cleanup)

        # Each factory maps i in range(100) to a test value of one kind.
        value_factories = [
            lambda i: i * 10 ** 6,        # ints
            lambda i: i * 10 ** 6 * 1.0,  # floats
            lambda i: "h" * i,            # strings of growing length
            lambda i: [1] * i,            # lists of growing length
        ]
        for make_value in value_factories:
            for i in range(100):
                value_before = make_value(i)
                objref = orchpy.push(value_before, w)
                value_after = orchpy.pull(objref, w)
                self.assertEqual(value_before, value_after)
|
|
|
|
|
2016-04-05 00:34:23 -07:00
|
|
|
class APITest(unittest.TestCase):
    """Public API tests."""

    def testObjRefAliasing(self):
        """Each test_alias_* remote function yields np.ones([3, 4, 5]) when pulled."""
        test_dir = os.path.dirname(os.path.abspath(__file__))
        test_path = os.path.join(test_dir, "testrecv.py")
        [w] = services.start_cluster(return_drivers=True,
                                     num_workers_per_objstore=3,
                                     worker_path=test_path)
        # Tear the cluster down even when an assertion below fails.
        self.addCleanup(services.cleanup)

        expected = np.ones([3, 4, 5])
        for function_name in ["test_functions.test_alias_f",
                              "test_functions.test_alias_g",
                              "test_functions.test_alias_h"]:
            objref = w.remote_call(function_name, [])
            # np.all instead of np.alltrue: alltrue was deprecated in NumPy 1.25
            # and removed in NumPy 2.0; the two are equivalent.
            self.assertTrue(np.all(orchpy.pull(objref[0], w) == expected),
                            function_name)
|
|
|
|
|
|
|
|
class ReferenceCountingTest(unittest.TestCase):
    """Scheduler reference counts track live object references on the driver."""

    def _referenceCounts(self):
        """Return the scheduler's current reference-count table."""
        return orchpy.scheduler_info()["reference_counts"]

    def testDeallocation(self):
        test_dir = os.path.dirname(os.path.abspath(__file__))
        test_path = os.path.join(test_dir, "testrecv.py")
        services.start_cluster(return_drivers=False,
                               num_workers_per_objstore=3,
                               worker_path=test_path)
        # Tear the cluster down even when an assertion below fails.
        self.addCleanup(services.cleanup)

        # A reference count of -1 indicates the object has been deallocated.

        x = test_functions.test_alias_f()
        orchpy.pull(x)
        time.sleep(0.1)
        objref_val = x.val
        self.assertEqual(self._referenceCounts()[objref_val], 1)

        del x
        # Give the release time to reach the scheduler, as the later checks
        # in this test do.
        time.sleep(0.1)
        self.assertEqual(self._referenceCounts()[objref_val], -1)

        y = test_functions.test_alias_h()
        orchpy.pull(y)
        time.sleep(0.1)
        objref_val = y.val
        self.assertEqual(self._referenceCounts()[objref_val:(objref_val + 3)], [1, 0, 0])

        del y
        time.sleep(0.1)
        self.assertEqual(self._referenceCounts()[objref_val:(objref_val + 3)], [-1, -1, -1])

        z = dist.zeros([dist.BLOCK_SIZE, 2 * dist.BLOCK_SIZE], "float")
        time.sleep(0.1)
        objref_val = z.val
        self.assertEqual(self._referenceCounts()[objref_val:(objref_val + 3)], [1, 1, 1])

        del z
        time.sleep(0.1)
        self.assertEqual(self._referenceCounts()[objref_val:(objref_val + 3)], [-1, -1, -1])

        x = single.zeros([10, 10], "float")
        y = single.zeros([10, 10], "float")
        z = single.dot(x, y)
        objref_val = x.val
        time.sleep(0.1)
        self.assertEqual(self._referenceCounts()[objref_val:(objref_val + 3)], [1, 1, 1])

        del x
        time.sleep(0.1)
        self.assertEqual(self._referenceCounts()[objref_val:(objref_val + 3)], [-1, 1, 1])

        del y
        time.sleep(0.1)
        self.assertEqual(self._referenceCounts()[objref_val:(objref_val + 3)], [-1, -1, 1])

        del z
        time.sleep(0.1)
        self.assertEqual(self._referenceCounts()[objref_val:(objref_val + 3)], [-1, -1, -1])
|
|
|
|
|
2016-02-22 13:55:06 -08:00
|
|
|
if '__main__' == __name__:
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()
|