ray/test/runtest.py
2016-03-17 22:32:31 -07:00

247 lines
7.8 KiB
Python

import unittest
import orchpy
import orchpy.serialization as serialization
import orchpy.services as services
import orchpy.worker as worker
import numpy as np
import time
import subprocess32 as subprocess
import os
import arrays.single as single
from google.protobuf.text_format import *
from grpc.beta import implementations
import orchestra_pb2
import types_pb2
IP_ADDRESS = "127.0.0.1"
TIMEOUT_SECONDS = 5
def connect_to_scheduler(host, port):
channel = implementations.insecure_channel(host, port)
return orchestra_pb2.beta_create_Scheduler_stub(channel)
def connect_to_objstore(host, port):
channel = implementations.insecure_channel(host, port)
return orchestra_pb2.beta_create_ObjStore_stub(channel)
def address(host, port):
return host + ":" + str(port)
scheduler_port_counter = 0
def new_scheduler_port():
global scheduler_port_counter
scheduler_port_counter += 1
return 10000 + scheduler_port_counter
worker_port_counter = 0
def new_worker_port():
global worker_port_counter
worker_port_counter += 1
return 40000 + worker_port_counter
objstore_port_counter = 0
def new_objstore_port():
global objstore_port_counter
objstore_port_counter += 1
return 20000 + objstore_port_counter
class SerializationTest(unittest.TestCase):
def roundTripTest(self, data):
serialized = serialization.serialize(data)
result = serialization.deserialize(serialized)
self.assertEqual(data, result)
def numpyTypeTest(self, typ):
a = np.random.randint(0, 10, size=(100, 100)).astype(typ)
b = serialization.serialize(a)
c = serialization.deserialize(b)
self.assertTrue((a == c).all())
def testSerialize(self):
self.roundTripTest([1, "hello", 3.0])
self.roundTripTest(42)
self.roundTripTest("hello world")
self.roundTripTest(42.0)
self.roundTripTest((1.0, "hi"))
a = np.zeros((100, 100))
res = serialization.serialize(a)
b = serialization.deserialize(res)
self.assertTrue((a == b).all())
self.numpyTypeTest('int8')
self.numpyTypeTest('uint8')
# self.numpyTypeTest('int16') # TODO(pcm): implement this
# self.numpyTypeTest('int32') # TODO(pcm): implement this
self.numpyTypeTest('float32')
self.numpyTypeTest('float64')
a = np.array([[orchpy.lib.ObjRef(0), orchpy.lib.ObjRef(1)], [orchpy.lib.ObjRef(41), orchpy.lib.ObjRef(42)]])
capsule = serialization.serialize(a)
result = serialization.deserialize(capsule)
self.assertTrue((a == result).all())
class OrchPyLibTest(unittest.TestCase):
def testOrchPyLib(self):
scheduler_port = new_scheduler_port()
objstore_port = new_objstore_port()
worker_port = new_worker_port()
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
time.sleep(0.1)
services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port))
time.sleep(0.2)
w = worker.Worker()
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker_port), w)
w.put_object(orchpy.lib.ObjRef(0), 'hello world')
result = w.get_object(orchpy.lib.ObjRef(0))
self.assertEqual(result, 'hello world')
services.cleanup()
class ObjStoreTest(unittest.TestCase):
# Test setting up object stores, transfering data between them and retrieving data to a client
def testObjStore(self):
scheduler_port = new_scheduler_port()
objstore1_port = new_objstore_port()
objstore2_port = new_objstore_port()
worker1_port = new_worker_port()
worker2_port = new_worker_port()
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
time.sleep(0.1)
services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port))
services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port))
time.sleep(0.2)
scheduler_stub = connect_to_scheduler(IP_ADDRESS, scheduler_port)
objstore1_stub = connect_to_objstore(IP_ADDRESS, objstore1_port)
objstore2_stub = connect_to_objstore(IP_ADDRESS, objstore2_port)
worker1 = worker.Worker()
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore1_port), address(IP_ADDRESS, worker1_port), worker1)
worker2 = worker.Worker()
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore2_port), address(IP_ADDRESS, worker2_port), worker2)
# pushing and pulling an object shouldn't change it
for data in ["h", "h" * 10000, 0, 0.0]:
objref = orchpy.push(data, worker1)
result = orchpy.pull(objref, worker1)
self.assertEqual(result, data)
# pushing an object, shipping it to another worker, and pulling it shouldn't change it
# for data in ["h", "h" * 10000, 0, 0.0]:
# objref = worker.push(data, worker1)
# response = objstore1_stub.DeliverObj(orchestra_pb2.DeliverObjRequest(objref=objref.val, objstore_address=address(IP_ADDRESS, objstore2_port)), TIMEOUT_SECONDS)
# result = worker.pull(objref, worker2)
# self.assertEqual(result, data)
services.cleanup()
class SchedulerTest(unittest.TestCase):
def testCall(self):
scheduler_port = new_scheduler_port()
objstore_port = new_objstore_port()
worker1_port = new_worker_port()
worker2_port = new_worker_port()
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
time.sleep(0.1)
services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port))
time.sleep(0.2)
scheduler_stub = connect_to_scheduler(IP_ADDRESS, scheduler_port)
objstore_stub = connect_to_objstore(IP_ADDRESS, objstore_port)
time.sleep(0.2)
worker1 = worker.Worker()
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
test_dir = os.path.dirname(os.path.abspath(__file__))
test_path = os.path.join(test_dir, "testrecv.py")
services.start_worker(test_path, address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker2_port))
time.sleep(0.2)
value_before = "test_string"
objref = worker1.remote_call("__main__.print_string", [value_before])
time.sleep(0.2)
value_after = orchpy.pull(objref[0], worker1)
self.assertEqual(value_before, value_after)
time.sleep(0.1)
reply = scheduler_stub.SchedulerDebugInfo(orchestra_pb2.SchedulerDebugInfoRequest(), TIMEOUT_SECONDS)
services.cleanup()
class WorkerTest(unittest.TestCase):
def testPushPull(self):
scheduler_port = new_scheduler_port()
objstore_port = new_objstore_port()
worker1_port = new_worker_port()
services.start_scheduler(address(IP_ADDRESS, scheduler_port))
time.sleep(0.1)
services.start_objstore(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port))
time.sleep(0.2)
worker1 = worker.Worker()
orchpy.connect(address(IP_ADDRESS, scheduler_port), address(IP_ADDRESS, objstore_port), address(IP_ADDRESS, worker1_port), worker1)
for i in range(100):
value_before = i * 10 ** 6
objref = orchpy.push(value_before, worker1)
value_after = orchpy.pull(objref, worker1)
self.assertEqual(value_before, value_after)
for i in range(100):
value_before = i * 10 ** 6 * 1.0
objref = orchpy.push(value_before, worker1)
value_after = orchpy.pull(objref, worker1)
self.assertEqual(value_before, value_after)
for i in range(100):
value_before = "h" * i
objref = orchpy.push(value_before, worker1)
value_after = orchpy.pull(objref, worker1)
self.assertEqual(value_before, value_after)
for i in range(100):
value_before = [1] * i
objref = orchpy.push(value_before, worker1)
value_after = orchpy.pull(objref, worker1)
self.assertEqual(value_before, value_after)
services.cleanup()
if __name__ == '__main__':
unittest.main()