diff --git a/lib/orchpy/orchpy/services.py b/lib/orchpy/orchpy/services.py index 99a8b217b..4153fad74 100644 --- a/lib/orchpy/orchpy/services.py +++ b/lib/orchpy/orchpy/services.py @@ -3,10 +3,37 @@ import os import atexit import time +import orchpy +import orchpy.worker as worker + _services_path = os.path.dirname(os.path.abspath(__file__)) all_processes = [] +IP_ADDRESS = "127.0.0.1" +TIMEOUT_SECONDS = 5 + +def address(host, port): + return host + ":" + str(port) + +scheduler_port_counter = 0 +def new_scheduler_port(): + global scheduler_port_counter + scheduler_port_counter += 1 + return 10000 + scheduler_port_counter + +worker_port_counter = 0 +def new_worker_port(): + global worker_port_counter + worker_port_counter += 1 + return 40000 + worker_port_counter + +objstore_port_counter = 0 +def new_objstore_port(): + global objstore_port_counter + objstore_port_counter += 1 + return 20000 + objstore_port_counter + def cleanup(): global all_processes for p, address in all_processes: @@ -45,3 +72,20 @@ def start_worker(test_path, scheduler_address, objstore_address, worker_address) "--objstore-address=" + objstore_address, "--worker-address=" + worker_address]) all_processes.append((p, worker_address)) + +def start_cluster(driver_worker=None, num_workers=0, worker_path=None): + if num_workers > 0 and worker_path is None: + raise Exception("Attempting to start a cluster with some workers, but `worker_path` is None.") + scheduler_address = address(IP_ADDRESS, new_scheduler_port()) + objstore_address = address(IP_ADDRESS, new_objstore_port()) + start_scheduler(scheduler_address) + time.sleep(0.1) + start_objstore(scheduler_address, objstore_address) + time.sleep(0.2) + if driver_worker is not None: + orchpy.connect(scheduler_address, objstore_address, address(IP_ADDRESS, new_worker_port()), driver_worker) + else: + orchpy.connect(scheduler_address, objstore_address, address(IP_ADDRESS, new_worker_port())) + for _ in range(num_workers): + start_worker(worker_path, scheduler_address, objstore_address, address(IP_ADDRESS, new_worker_port())) + time.sleep(0.3) diff --git a/test/arrays_test.py b/test/arrays_test.py index f2b6e972d..295bddf09 100644 --- a/test/arrays_test.py +++ b/test/arrays_test.py @@ -16,61 +16,12 @@ from grpc.beta import implementations import orchestra_pb2 import types_pb2 -IP_ADDRESS = "127.0.0.1" -TIMEOUT_SECONDS = 5 - -def connect_to_scheduler(host, port): - channel = implementations.insecure_channel(host, port) - return orchestra_pb2.beta_create_Scheduler_stub(channel) - -def connect_to_objstore(host, port): - channel = implementations.insecure_channel(host, port) - return orchestra_pb2.beta_create_ObjStore_stub(channel) - -def address(host, port): - return host + ":" + str(port) - -scheduler_port_counter = 0 -def new_scheduler_port(): - global scheduler_port_counter - scheduler_port_counter += 1 - return 10000 + scheduler_port_counter - -worker_port_counter = 0 -def new_worker_port(): - global worker_port_counter - worker_port_counter += 1 - return 40000 + worker_port_counter - -objstore_port_counter = 0 -def new_objstore_port(): - global objstore_port_counter - objstore_port_counter += 1 - return 20000 + objstore_port_counter - class ArraysSingleTest(unittest.TestCase): def testMethods(self): - scheduler_port = new_scheduler_port() - objstore_port = new_objstore_port() - worker1_port = new_worker_port() - worker2_port = new_worker_port() - - services.start_scheduler(address(IP_ADDRESS, scheduler_port)) - - time.sleep(0.1) - - services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port)) - - time.sleep(0.2) - - orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port)) - test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") - services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port)) - - time.sleep(0.2) + services.start_cluster(num_workers=1, worker_path=test_path) # test eye ref = single.eye(3, "float") @@ -110,26 +61,9 @@ class ArraysDistTest(unittest.TestCase): self.assertEqual(x.objrefs[0, 0, 0].val, y.objrefs[0, 0, 0].val) def testAssemble(self): - scheduler_port = new_scheduler_port() - objstore_port = new_objstore_port() - worker1_port = new_worker_port() - worker2_port = new_worker_port() - - services.start_scheduler(address(IP_ADDRESS, scheduler_port)) - - time.sleep(0.1) - - services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port)) - - time.sleep(0.2) - - orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port)) - test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") - services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port)) - - time.sleep(0.2) + services.start_cluster(num_workers=1, worker_path=test_path) a = single.ones([dist.BLOCK_SIZE, dist.BLOCK_SIZE], "float") b = single.zeros([dist.BLOCK_SIZE, dist.BLOCK_SIZE], "float") @@ -140,30 +74,9 @@ class ArraysDistTest(unittest.TestCase): services.cleanup() def testMethods(self): - scheduler_port = new_scheduler_port() - objstore_port = new_objstore_port() - worker1_port = new_worker_port() - worker2_port = new_worker_port() - worker3_port = new_worker_port() - worker4_port = new_worker_port() - - services.start_scheduler(address(IP_ADDRESS, scheduler_port)) - - time.sleep(0.1) - - services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port)) - - time.sleep(0.2) - - orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port)) - test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") - services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port)) - services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker3_port)) - services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker4_port)) - - time.sleep(0.2) + services.start_cluster(num_workers=3, worker_path=test_path) x = dist.zeros([9, 25, 51], "float") self.assertTrue(np.alltrue(orchpy.pull(dist.assemble(x)) == np.zeros([9, 25, 51]))) diff --git a/test/runtest.py b/test/runtest.py index 53f4e88b1..284155fca 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -8,46 +8,11 @@ import time import subprocess32 as subprocess import os -import arrays.single as single - from google.protobuf.text_format import * -from grpc.beta import implementations import orchestra_pb2 import types_pb2 -IP_ADDRESS = "127.0.0.1" -TIMEOUT_SECONDS = 5 - -def connect_to_scheduler(host, port): - channel = implementations.insecure_channel(host, port) - return orchestra_pb2.beta_create_Scheduler_stub(channel) - -def connect_to_objstore(host, port): - channel = implementations.insecure_channel(host, port) - return orchestra_pb2.beta_create_ObjStore_stub(channel) - -def address(host, port): - return host + ":" + str(port) - -scheduler_port_counter = 0 -def new_scheduler_port(): - global scheduler_port_counter - scheduler_port_counter += 1 - return 10000 + scheduler_port_counter - -worker_port_counter = 0 -def new_worker_port(): - global worker_port_counter - worker_port_counter += 1 - return 40000 + worker_port_counter - -objstore_port_counter = 0 -def new_objstore_port(): - global objstore_port_counter - objstore_port_counter += 1 - return 20000 + objstore_port_counter - class SerializationTest(unittest.TestCase): def roundTripTest(self, data): @@ -91,21 +56,8 @@ class SerializationTest(unittest.TestCase): class OrchPyLibTest(unittest.TestCase): def testOrchPyLib(self): - scheduler_port = new_scheduler_port() - objstore_port = new_objstore_port() - worker_port = new_worker_port() - - services.start_scheduler(address(IP_ADDRESS, scheduler_port)) - - time.sleep(0.1) - - services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port)) - - time.sleep(0.2) - w = worker.Worker() - - orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w) + services.start_cluster(driver_worker=w) w.put_object(orchpy.lib.ObjRef(0), 'hello world') result = w.get_object(orchpy.lib.ObjRef(0)) @@ -118,35 +70,13 @@ class ObjStoreTest(unittest.TestCase): # Test setting up object stores, transfering data between them and retrieving data to a client def testObjStore(self): - scheduler_port = new_scheduler_port() - objstore1_port = new_objstore_port() - objstore2_port = new_objstore_port() - worker1_port = new_worker_port() - worker2_port = new_worker_port() - - services.start_scheduler(address(IP_ADDRESS, scheduler_port)) - - time.sleep(0.1) - - services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port)) - services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port)) - - time.sleep(0.2) - - scheduler_stub = connect_to_scheduler(IP_ADDRESS, scheduler_port) - objstore1_stub = connect_to_objstore(IP_ADDRESS, objstore1_port) - objstore2_stub = connect_to_objstore(IP_ADDRESS, objstore2_port) - - worker1 = worker.Worker() - orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1) - - worker2 = worker.Worker() - orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2) + w = worker.Worker() + services.start_cluster(driver_worker=w) # pushing and pulling an object shouldn't change it for data in ["h", "h" * 10000, 0, 0.0]: - objref = orchpy.push(data, worker1) - result = orchpy.pull(objref, worker1) + objref = orchpy.push(data, w) + result = orchpy.pull(objref, w) self.assertEqual(result, data) # pushing an object, shipping it to another worker, and pulling it shouldn't change it @@ -161,90 +91,70 @@ class ObjStoreTest(unittest.TestCase): class SchedulerTest(unittest.TestCase): def testCall(self): - scheduler_port = new_scheduler_port() - objstore_port = new_objstore_port() - worker1_port = new_worker_port() - worker2_port = new_worker_port() - - services.start_scheduler(address(IP_ADDRESS, scheduler_port)) - - time.sleep(0.1) - - services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port)) - - time.sleep(0.2) - - scheduler_stub = connect_to_scheduler(IP_ADDRESS, scheduler_port) - objstore_stub = connect_to_objstore(IP_ADDRESS, objstore_port) - - time.sleep(0.2) - - worker1 = worker.Worker() - orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1) - test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") - services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port)) - - time.sleep(0.2) + w = worker.Worker() + services.start_cluster(driver_worker=w, num_workers=1, worker_path=test_path) value_before = "test_string" - objref = worker1.remote_call("__main__.print_string", [value_before]) + objref = w.remote_call("__main__.print_string", [value_before]) time.sleep(0.2) - value_after = orchpy.pull(objref[0], worker1) + value_after = orchpy.pull(objref[0], w) self.assertEqual(value_before, value_after) time.sleep(0.1) - reply = scheduler_stub.SchedulerDebugInfo(orchestra_pb2.SchedulerDebugInfoRequest(), TIMEOUT_SECONDS) - services.cleanup() class WorkerTest(unittest.TestCase): def testPushPull(self): - scheduler_port = new_scheduler_port() - objstore_port = new_objstore_port() - worker1_port = new_worker_port() - - services.start_scheduler(address(IP_ADDRESS, scheduler_port)) - - time.sleep(0.1) - - services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port)) - - time.sleep(0.2) - - worker1 = worker.Worker() - orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1) + w = worker.Worker() + services.start_cluster(driver_worker=w) for i in range(100): value_before = i * 10 ** 6 - objref = orchpy.push(value_before, worker1) - value_after = orchpy.pull(objref, worker1) + objref = orchpy.push(value_before, w) + value_after = orchpy.pull(objref, w) self.assertEqual(value_before, value_after) for i in range(100): value_before = i * 10 ** 6 * 1.0 - objref = orchpy.push(value_before, worker1) - value_after = orchpy.pull(objref, worker1) + objref = orchpy.push(value_before, w) + value_after = orchpy.pull(objref, w) self.assertEqual(value_before, value_after) for i in range(100): value_before = "h" * i - objref = orchpy.push(value_before, worker1) - value_after = orchpy.pull(objref, worker1) + objref = orchpy.push(value_before, w) + value_after = orchpy.pull(objref, w) self.assertEqual(value_before, value_after) for i in range(100): value_before = [1] * i - objref = orchpy.push(value_before, worker1) - value_after = orchpy.pull(objref, worker1) + objref = orchpy.push(value_before, w) + value_after = orchpy.pull(objref, w) self.assertEqual(value_before, value_after) services.cleanup() +""" +class APITest(unittest.TestCase): + + def testObjRefAliasing(self): + services.start_scheduler(address(IP_ADDRESS, new_scheduler_port())) + time.sleep(0.1) + services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, new_objstore_port())) + time.sleep(0.2) + worker1 = worker.Worker() + orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, new_worker_port()), worker1) + test_dir = os.path.dirname(os.path.abspath(__file__)) + test_path = os.path.join(test_dir, "testrecv.py") + services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, new_worker_port())) +""" + + if __name__ == '__main__': unittest.main() diff --git a/test/testrecv.py b/test/testrecv.py index 7768000cf..3bac951d0 100644 --- a/test/testrecv.py +++ b/test/testrecv.py @@ -1,5 +1,6 @@ import sys import argparse +import numpy as np import arrays.single as single import arrays.dist as dist @@ -13,6 +14,15 @@ parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address") parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address") +@orchpy.distributed([], [np.ndarray]) +def test_alias_f(): + return np.ones([3, 4, 5]) + +@orchpy.distributed([], [np.ndarray]) +def test_alias_g(): + return f() + + @orchpy.distributed([str], [str]) def print_string(string): print "called print_string with", string