mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
commit
2535d267ab
5 changed files with 94 additions and 55 deletions
|
@ -9,42 +9,39 @@ all_processes = []
|
|||
|
||||
def cleanup():
    """Stop every process recorded in ``all_processes`` and reset the list.

    Each entry of ``all_processes`` is a ``(process, address)`` pair. For
    every process that is still running, ``kill()`` is tried first and
    ``terminate()`` second; the outcome of each step is printed. The list is
    emptied afterwards regardless of whether every process could be stopped.
    """
    global all_processes
    # Single-argument print(...) produces the same output whether or not
    # print is a statement, so the messages are unchanged.
    for p, address in all_processes:
        if p.poll() is not None:  # process has already terminated
            print("Process at address " + address + " has already terminated.")
            continue
        print("Attempting to kill process at address " + address + ".")
        p.kill()
        time.sleep(0.05)  # is this necessary?
        if p.poll() is not None:
            print("Successfully killed process at address " + address + ".")
            continue
        print("Kill attempt failed, attempting to terminate process at address " + address + ".")
        p.terminate()
        time.sleep(0.05)  # is this necessary?
        # poll must be *called* here: the bare attribute `p.poll` is a bound
        # method and therefore never None, which made this check always
        # succeed and the "giving up" message below unreachable.
        if p.poll() is not None:
            print("Successfully terminated process at address " + address + ".")
            continue
        print("Termination attempt failed, giving up.")
    all_processes = []
|
||||
|
||||
# Run cleanup() automatically when the interpreter exits so that processes
# recorded in all_processes are not left behind.
atexit.register(cleanup)
|
||||
|
||||
def start_scheduler(scheduler_address):
    """Launch the scheduler binary and record it for later cleanup.

    ``scheduler_address`` is passed to the binary as its only command-line
    argument and is stored next to the process handle in ``all_processes``.
    """
    scheduler_binary = os.path.join(_services_path, "scheduler")
    process = subprocess.Popen([scheduler_binary, scheduler_address])
    all_processes.append((process, scheduler_address))
|
||||
|
||||
def start_objstore(objstore_address):
    """Launch the object store binary and record it for later cleanup.

    ``objstore_address`` is passed to the binary as its only command-line
    argument and is stored next to the process handle in ``all_processes``.
    """
    objstore_binary = os.path.join(_services_path, "objstore")
    process = subprocess.Popen([objstore_binary, objstore_address])
    all_processes.append((process, objstore_address))
|
||||
|
||||
def start_worker(test_path, scheduler_address, objstore_address, worker_address):
    """Run the script at ``test_path`` as a worker process and record it.

    The three addresses are forwarded to the script as the
    ``--scheduler-address``, ``--objstore-address`` and ``--worker-address``
    options. The process handle is stored in ``all_processes`` together with
    ``worker_address``.
    """
    command = [
        "python",
        test_path,
        "--scheduler-address=" + scheduler_address,
        "--objstore-address=" + objstore_address,
        "--worker-address=" + worker_address,
    ]
    process = subprocess.Popen(command)
    all_processes.append((process, worker_address))
|
||||
|
|
|
@ -83,24 +83,24 @@ def distributed(arg_types, return_types, worker=global_worker):
|
|||
def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
arguments = []
|
||||
# check the number of args
|
||||
if len(args) != len(function.types) and function.types[-1] is not None:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.types), len(args)))
|
||||
elif len(args) < len(function.types) - 1 and function.types[-1] is None:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.types) - 1, len(args)))
|
||||
if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
|
||||
elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
print "Pulling argument {} for function {}.".format(i, function.__name__)
|
||||
if i < len(function.types) - 1:
|
||||
expected_type = function.types[i]
|
||||
elif i == len(function.types) - 1 and function.types[-1] is not None:
|
||||
expected_type = function.types[-1]
|
||||
elif function.types[-1] is None and len(function.types > 1):
|
||||
expected_type = function.types[-2]
|
||||
if i < len(function.arg_types) - 1:
|
||||
expected_type = function.arg_types[i]
|
||||
elif i == len(function.arg_types) - 1 and function.arg_types[-1] is not None:
|
||||
expected_type = function.arg_types[-1]
|
||||
elif function.arg_types[-1] is None and len(function.arg_types > 1):
|
||||
expected_type = function.arg_types[-2]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
argument = worker.get_object(arg) if type(arg) == orchpy.ObjRef else arg
|
||||
if type(arg) == orchpy.ObjRef:
|
||||
argument = worker.get_object(arg) if type(arg) == orchpy.lib.ObjRef else arg
|
||||
if type(arg) == orchpy.lib.ObjRef:
|
||||
# get the object from the local object store
|
||||
# TODO(rkn): Do we know that it is already there? Maybe we should call pull(arg, worker).
|
||||
argument = worker.get_object(arg)
|
||||
|
|
|
@ -71,9 +71,9 @@ class ObjStoreTest(unittest.TestCase):
|
|||
worker1_port = new_worker_port()
|
||||
worker2_port = new_worker_port()
|
||||
|
||||
services.start_scheduler(IP_ADDRESS, scheduler_port)
|
||||
services.start_objstore(IP_ADDRESS, objstore1_port)
|
||||
services.start_objstore(IP_ADDRESS, objstore2_port)
|
||||
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
|
||||
services.start_objstore(address(IP_ADDRESS, objstore1_port))
|
||||
services.start_objstore(address(IP_ADDRESS, objstore2_port))
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
|
@ -110,8 +110,8 @@ class SchedulerTest(unittest.TestCase):
|
|||
worker1_port = new_worker_port()
|
||||
worker2_port = new_worker_port()
|
||||
|
||||
services.start_scheduler(IP_ADDRESS, scheduler_port)
|
||||
services.start_objstore(IP_ADDRESS, objstore_port)
|
||||
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
|
||||
services.start_objstore(address(IP_ADDRESS, objstore_port))
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
|
@ -125,11 +125,17 @@ class SchedulerTest(unittest.TestCase):
|
|||
|
||||
test_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_path = os.path.join(test_dir, "testrecv.py")
|
||||
services.start_worker(test_path, IP_ADDRESS, scheduler_port, worker2_port, objstore_port)
|
||||
services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port))
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
worker1.remote_call("print_string", ["hi"])
|
||||
value_before = "test_string"
|
||||
objref = worker1.remote_call("__main__.print_string", [value_before])
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
# value_after = worker.pull(objref, worker1)
|
||||
# self.assertEqual(value_before, value_after)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
|
@ -137,5 +143,46 @@ class SchedulerTest(unittest.TestCase):
|
|||
|
||||
services.cleanup()
|
||||
|
||||
class WorkerTest(unittest.TestCase):
    """Tests that exercise a worker connected to a scheduler and object store."""

    def testPushPull(self):
        """Push values of several types and check that pull returns them unchanged.

        Covers 100 ints, 100 floats, 100 strings of growing length and 100
        lists of growing length, in that order.
        """
        scheduler_port = new_scheduler_port()
        objstore_port = new_objstore_port()
        worker1_port = new_worker_port()

        # try/finally guarantees that the spawned services are stopped even
        # when a start-up step or an assertion fails part-way through.
        try:
            services.start_scheduler(address(IP_ADDRESS, scheduler_port))
            services.start_objstore(address(IP_ADDRESS, objstore_port))

            # presumably gives the freshly started services time to come up
            time.sleep(0.2)

            worker1 = worker.Worker()
            worker.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)

            # One value factory per type under test; each is run for i in 0..99.
            value_factories = [
                lambda i: i * 10 ** 6,
                lambda i: i * 10 ** 6 * 1.0,
                lambda i: "h" * i,
                lambda i: [1] * i,
            ]
            for make_value in value_factories:
                for i in range(100):
                    value_before = make_value(i)
                    objref = worker.push(value_before, worker1)
                    value_after = worker.pull(objref, worker1)
                    self.assertEqual(value_before, value_after)
        finally:
            services.cleanup()
|
||||
|
||||
if __name__ == "__main__":
    # Executed as a script: run the test cases defined in this module.
    unittest.main()
|
||||
|
|
|
@ -8,10 +8,9 @@ import orchestra_pb2
|
|||
import types_pb2
|
||||
|
||||
# Command-line options: one "host:port" address string per service.
parser = argparse.ArgumentParser(
    description='Parse addresses for the worker to connect to.',
)
parser.add_argument(
    "--scheduler-address",
    type=str,
    default="127.0.0.1:10001",
    help="the scheduler's address",
)
parser.add_argument(
    "--objstore-address",
    type=str,
    default="127.0.0.1:20001",
    help="the objstore's address",
)
parser.add_argument(
    "--worker-address",
    type=str,
    default="127.0.0.1:40001",
    help="the worker's address",
)
|
||||
|
||||
@worker.distributed([str], [str])
|
||||
def print_string(string):
|
||||
|
@ -24,8 +23,8 @@ def print_string(string):
|
|||
def handle_int(a, b):
    """Return the pair ``(a + 1, b + 1)``."""
    incremented_a = a + 1
    incremented_b = b + 1
    return incremented_a, incremented_b
|
||||
|
||||
def connect_to_scheduler(address):
    """Open an insecure channel to the scheduler and return a stub for it.

    ``address`` is a single ``"host:port"`` string. It is split on the last
    colon because ``implementations.insecure_channel`` was previously called
    with separate host and port arguments; passing the combined string as
    the only argument would leave the port argument missing.

    Raises ``ValueError`` if ``address`` contains no colon or the port part
    is not an integer.
    """
    # NOTE: the parameter shadows the module-level address() helper inside
    # this function only; the helper is not needed here.
    host, port = address.rsplit(":", 1)
    channel = implementations.insecure_channel(host, int(port))
    return orchestra_pb2.beta_create_Scheduler_stub(channel)
|
||||
|
||||
def address(host, port):
|
||||
|
@ -33,7 +32,7 @@ def address(host, port):
|
|||
|
||||
if __name__ == '__main__':
    args = parser.parse_args()
    # scheduler_stub is not used below; presumably it is kept so that it can
    # be used from the interactive session started at the end.
    scheduler_stub = connect_to_scheduler(args.scheduler_address)
    # The call previously ended with an unbalanced extra ")" which made the
    # whole script a syntax error.
    worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
    # Imported here so IPython is only required when run as a script.
    import IPython
    IPython.embed()
|
||||
|
|
|
@ -5,10 +5,9 @@ import orchpy.services as services
|
|||
import orchpy.worker as worker
|
||||
|
||||
# Command-line options: one "host:port" address string per service.
parser = argparse.ArgumentParser(
    description='Parse addresses for the worker to connect to.',
)
parser.add_argument(
    "--scheduler-address",
    type=str,
    default="127.0.0.1:10001",
    help="the scheduler's address",
)
parser.add_argument(
    "--objstore-address",
    type=str,
    default="127.0.0.1:20001",
    help="the objstore's address",
)
parser.add_argument(
    "--worker-address",
    type=str,
    default="127.0.0.1:40001",
    help="the worker's address",
)
|
||||
|
||||
@worker.distributed([str], [str])
|
||||
def print_string(string):
|
||||
|
@ -21,12 +20,9 @@ def print_string(string):
|
|||
def handle_int(a, b):
    """Return the pair ``(a + 1, b + 1)``."""
    incremented_a = a + 1
    incremented_b = b + 1
    return incremented_a, incremented_b
|
||||
|
||||
def address(host, port):
    """Join ``host`` and ``port`` into a single ``"host:port"`` string."""
    parts = (host, str(port))
    return ":".join(parts)
|
||||
|
||||
if __name__ == '__main__':
    args = parser.parse_args()
    # Connect using the three addresses given on the command line.
    worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)

    # Register the functions defined in this script, in definition order.
    for remote_function in (print_string, handle_int):
        worker.global_worker.register_function(remote_function)
|
||||
|
|
Loading…
Add table
Reference in a new issue