diff --git a/lib/orchpy/orchpy/services.py b/lib/orchpy/orchpy/services.py index d4fb6a126..6d20c4406 100644 --- a/lib/orchpy/orchpy/services.py +++ b/lib/orchpy/orchpy/services.py @@ -9,42 +9,39 @@ all_processes = [] def cleanup(): global all_processes - for p, port in all_processes: + for p, address in all_processes: if p.poll() is not None: # process has already terminated - print "Process at port " + str(port) + " has already terminated." + print "Process at address " + address + " has already terminated." continue - print "Attempting to kill process at port " + str(port) + "." + print "Attempting to kill process at address " + address + "." p.kill() time.sleep(0.05) # is this necessary? if p.poll() is not None: - print "Successfully killed process at port " + str(port) + "." + print "Successfully killed process at address " + address + "." continue - print "Kill attempt failed, attempting to terminate process at port " + str(port) + "." + print "Kill attempt failed, attempting to terminate process at address " + address + "." p.terminate() time.sleep(0.05) # is this necessary? if p.poll is not None: - print "Successfully terminated process at port " + str(port) + "." + print "Successfully terminated process at address " + address + "." continue print "Termination attempt failed, giving up." all_processes = [] atexit.register(cleanup) -def start_scheduler(host, port): - scheduler_address = host + ":" + str(port) - p = subprocess.Popen([os.path.join(_services_path, "scheduler"), str(scheduler_address)]) - all_processes.append((p, port)) +def start_scheduler(scheduler_address): + p = subprocess.Popen([os.path.join(_services_path, "scheduler"), scheduler_address]) + all_processes.append((p, scheduler_address)) -def start_objstore(host, port): - objstore_address = host + ":" + str(port) - p = subprocess.Popen([os.path.join(_services_path, "objstore"), str(objstore_address)]) - all_processes.append((p, port)) +def start_objstore(objstore_address): + p = subprocess.Popen([os.path.join(_services_path, "objstore"), objstore_address]) + all_processes.append((p, objstore_address)) -def start_worker(test_path, host, scheduler_port, worker_port, objstore_port): +def start_worker(test_path, scheduler_address, objstore_address, worker_address): p = subprocess.Popen(["python", test_path, - "--ip_address=" + host, - "--scheduler_port=" + str(scheduler_port), - "--objstore_port=" + str(objstore_port), - "--worker_port=" + str(worker_port)]) - all_processes.append((p, worker_port)) + "--scheduler-address=" + scheduler_address, + "--objstore-address=" + objstore_address, + "--worker-address=" + worker_address]) + all_processes.append((p, worker_address)) diff --git a/lib/orchpy/orchpy/worker.py b/lib/orchpy/orchpy/worker.py index 71c532170..5b607dcc6 100644 --- a/lib/orchpy/orchpy/worker.py +++ b/lib/orchpy/orchpy/worker.py @@ -83,24 +83,24 @@ def distributed(arg_types, return_types, worker=global_worker): def get_arguments_for_execution(function, args, worker=global_worker): arguments = [] # check the number of args - if len(args) != len(function.types) and function.types[-1] is not None: - raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.types), len(args))) - elif len(args) < len(function.types) - 1 and function.types[-1] is None: - raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.types) - 1, len(args))) + if len(args) != len(function.arg_types) and function.arg_types[-1] is not None: + raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args))) + elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None: + raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args))) for (i, arg) in enumerate(args): print "Pulling argument {} for function {}.".format(i, function.__name__) - if i < len(function.types) - 1: - expected_type = function.types[i] - elif i == len(function.types) - 1 and function.types[-1] is not None: - expected_type = function.types[-1] - elif function.types[-1] is None and len(function.types > 1): - expected_type = function.types[-2] + if i < len(function.arg_types) - 1: + expected_type = function.arg_types[i] + elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None: + expected_type = function.arg_types[-1] + elif function.arg_types[-1] is None and len(function.arg_types > 1): + expected_type = function.arg_types[-2] else: assert False, "This code should be unreachable." - argument = worker.get_object(arg) if type(arg) == orchpy.ObjRef else arg - if type(arg) == orchpy.ObjRef: + argument = worker.get_object(arg) if type(arg) == orchpy.lib.ObjRef else arg + if type(arg) == orchpy.lib.ObjRef: # get the object from the local object store # TODO(rkn): Do we know that it is already there? Maybe we should call pull(arg, worker). argument = worker.get_object(arg) diff --git a/test/runtest.py b/test/runtest.py index cce69a113..4fd6015db 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -71,9 +71,9 @@ class ObjStoreTest(unittest.TestCase): worker1_port = new_worker_port() worker2_port = new_worker_port() - services.start_scheduler(IP_ADDRESS, scheduler_port) - services.start_objstore(IP_ADDRESS, objstore1_port) - services.start_objstore(IP_ADDRESS, objstore2_port) + services.start_scheduler(address(IP_ADDRESS, scheduler_port)) + services.start_objstore(address(IP_ADDRESS, objstore1_port)) + services.start_objstore(address(IP_ADDRESS, objstore2_port)) time.sleep(0.2) @@ -110,8 +110,8 @@ class SchedulerTest(unittest.TestCase): worker1_port = new_worker_port() worker2_port = new_worker_port() - services.start_scheduler(IP_ADDRESS, scheduler_port) - services.start_objstore(IP_ADDRESS, objstore_port) + services.start_scheduler(address(IP_ADDRESS, scheduler_port)) + services.start_objstore(address(IP_ADDRESS, objstore_port)) time.sleep(0.2) @@ -125,11 +125,17 @@ class SchedulerTest(unittest.TestCase): test_dir = os.path.dirname(os.path.abspath(__file__)) test_path = os.path.join(test_dir, "testrecv.py") - services.start_worker(test_path, IP_ADDRESS, scheduler_port, worker2_port, objstore_port) + services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port)) time.sleep(0.2) - worker1.remote_call("print_string", ["hi"]) + value_before = "test_string" + objref = worker1.remote_call("__main__.print_string", [value_before]) + + time.sleep(0.2) + + # value_after = worker.pull(objref, worker1) + # self.assertEqual(value_before, value_after) time.sleep(0.1) @@ -137,5 +143,46 @@ class SchedulerTest(unittest.TestCase): services.cleanup() +class WorkerTest(unittest.TestCase): + + def testPushPull(self): + scheduler_port = new_scheduler_port() + objstore_port = new_objstore_port() + worker1_port = new_worker_port() + + services.start_scheduler(address(IP_ADDRESS, scheduler_port)) + services.start_objstore(address(IP_ADDRESS, objstore_port)) + + time.sleep(0.2) + + worker1 = worker.Worker() + worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1) + + for i in range(100): + value_before = i * 10 ** 6 + objref = worker.push(value_before, worker1) + value_after = worker.pull(objref, worker1) + self.assertEqual(value_before, value_after) + + for i in range(100): + value_before = i * 10 ** 6 * 1.0 + objref = worker.push(value_before, worker1) + value_after = worker.pull(objref, worker1) + self.assertEqual(value_before, value_after) + + for i in range(100): + value_before = "h" * i + objref = worker.push(value_before, worker1) + value_after = worker.pull(objref, worker1) + self.assertEqual(value_before, value_after) + + for i in range(100): + value_before = [1] * i + objref = worker.push(value_before, worker1) + value_after = worker.pull(objref, worker1) + self.assertEqual(value_before, value_after) + + services.cleanup() + if __name__ == '__main__': unittest.main() diff --git a/test/shell.py b/test/shell.py index 5cc683d9d..d9ede1d9b 100644 --- a/test/shell.py +++ b/test/shell.py @@ -8,10 +8,9 @@ import orchestra_pb2 import types_pb2 parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.') -parser.add_argument("--ip_address", default="127.0.0.1", help="the IP address to use for both the scheduler and objstore") -parser.add_argument("--scheduler_port", default=10001, type=int, help="the scheduler's port") -parser.add_argument("--objstore_port", default=20001, type=int, help="the objstore's port") -parser.add_argument("--worker_port", default=40001, type=int, help="the worker's port") +parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address") +parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address") +parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address") @worker.distributed([str], [str]) def print_string(string): @@ -24,8 +23,8 @@ def print_string(string): def handle_int(a, b): return a + 1, b + 1 -def connect_to_scheduler(host, port): - channel = implementations.insecure_channel(host, port) +def connect_to_scheduler(address): + channel = implementations.insecure_channel(address) return orchestra_pb2.beta_create_Scheduler_stub(channel) def address(host, port): @@ -33,7 +32,7 @@ def address(host, port): if __name__ == '__main__': args = parser.parse_args() - scheduler_stub = connect_to_scheduler(args.ip_address, args.scheduler_port) - worker.connect(address(args.ip_address, args.scheduler_port), address(args.ip_address, args.objstore_port), address(args.ip_address, args.worker_port)) + scheduler_stub = connect_to_scheduler(args.scheduler_address) + worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)) import IPython IPython.embed() diff --git a/test/testrecv.py b/test/testrecv.py index 759fbf743..22064ac37 100644 --- a/test/testrecv.py +++ b/test/testrecv.py @@ -5,10 +5,9 @@ import orchpy.services as services import orchpy.worker as worker parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.') -parser.add_argument("--ip_address", default="127.0.0.1", help="the IP address to use for both the scheduler and objstore") -parser.add_argument("--scheduler_port", default=10001, type=int, help="the scheduler's port") -parser.add_argument("--objstore_port", default=20001, type=int, help="the objstore's port") -parser.add_argument("--worker_port", default=40001, type=int, help="the worker's port") +parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address") +parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address") +parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address") @worker.distributed([str], [str]) def print_string(string): @@ -21,12 +20,9 @@ def print_string(string): def handle_int(a, b): return a + 1, b + 1 -def address(host, port): - return host + ":" + str(port) - if __name__ == '__main__': args = parser.parse_args() - worker.connect(address(args.ip_address, args.scheduler_port), address(args.ip_address, args.objstore_port), address(args.ip_address, args.worker_port)) + worker.connect(args.scheduler_address, args.objstore_address, args.worker_address) worker.global_worker.register_function(print_string) worker.global_worker.register_function(handle_int)