2016-12-11 12:25:31 -08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
2016-11-18 19:57:51 -08:00
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import os
|
|
|
|
import random
|
|
|
|
import redis
|
|
|
|
import signal
|
|
|
|
import subprocess
|
|
|
|
import sys
|
|
|
|
import threading
|
|
|
|
import time
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
import global_scheduler
|
|
|
|
import photon
|
|
|
|
import plasma
|
|
|
|
|
|
|
|
# Whether the global scheduler is started under valgrind. Switched on by the
# "valgrind" command-line argument when this file is run as a script.
USE_VALGRIND = False
# NOTE(review): presumably the Plasma store memory budget in bytes; it is not
# referenced in the visible part of this file -- confirm before relying on it.
PLASMA_STORE_MEMORY = 10 ** 9
# Number of random bytes used to build object, task and function IDs.
ID_SIZE = 20

# These constants must match the scheduling state enum in task.h. The values
# are consecutive powers of two.
TASK_STATUS_WAITING = 1 << 0
TASK_STATUS_SCHEDULED = 1 << 1
TASK_STATUS_RUNNING = 1 << 2
TASK_STATUS_DONE = 1 << 3
|
|
|
|
|
|
|
|
def random_object_id():
    """Return a photon.ObjectID built from ID_SIZE random bytes."""
    id_bytes = np.random.bytes(ID_SIZE)
    return photon.ObjectID(id_bytes)
|
|
|
|
|
|
|
|
def random_task_id():
    """Return a random task ID, represented as a photon.ObjectID.

    The ID is built from ID_SIZE random bytes.
    """
    task_id_bytes = np.random.bytes(ID_SIZE)
    return photon.ObjectID(task_id_bytes)
|
|
|
|
|
|
|
|
def random_function_id():
    """Return a random function ID, represented as a photon.ObjectID.

    The ID is built from ID_SIZE random bytes.
    """
    function_id_bytes = np.random.bytes(ID_SIZE)
    return photon.ObjectID(function_id_bytes)
|
|
|
|
|
|
|
|
def new_port():
    """Return a random port number between 10000 and 65535, both inclusive.

    The port is only drawn at random; nothing checks that it is free.
    """
    lowest_port, highest_port = 10000, 65535
    return random.randint(lowest_port, highest_port)
|
|
|
|
|
|
|
|
class TestGlobalScheduler(unittest.TestCase):
    """End-to-end test of the global scheduler against live processes.

    Every test starts a Redis server, a global scheduler, a Plasma store, a
    Plasma manager and a local scheduler, and connects a Photon client to the
    local scheduler. All of these processes are stopped again in tearDown.
    """

    # Upper bound, in seconds, on how long a test waits for submitted tasks to
    # become scheduled. Without it a scheduling failure would hang the test
    # forever instead of failing it.
    SCHEDULING_TIMEOUT_SECONDS = 120

    # Pause, in seconds, between two consecutive polls of Redis.
    POLL_INTERVAL_SECONDS = 0.1

    def setUp(self):
        """Start all the processes and connect the Redis and Photon clients."""
        # Start a Redis server.
        redis_path = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "../../common/thirdparty/redis/src/redis-server")
        node_ip_address = "127.0.0.1"
        redis_port = new_port()
        redis_address = "{}:{}".format(node_ip_address, redis_port)
        self.redis_process = subprocess.Popen(
            [redis_path, "--port", str(redis_port), "--loglevel", "warning"])
        time.sleep(0.1)
        # Create a Redis client.
        self.redis_client = redis.StrictRedis(host=node_ip_address,
                                              port=redis_port)
        # Start the global scheduler.
        self.p1 = global_scheduler.start_global_scheduler(
            redis_address, use_valgrind=USE_VALGRIND)
        # Start the Plasma store.
        plasma_store_name, self.p2 = plasma.start_plasma_store()
        # Start the Plasma manager.
        plasma_manager_name, self.p3, plasma_manager_port = (
            plasma.start_plasma_manager(plasma_store_name, redis_address))
        plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
        # Start the local scheduler.
        local_scheduler_name, self.p4 = photon.start_local_scheduler(
            plasma_store_name,
            plasma_manager_name=plasma_manager_name,
            plasma_address=plasma_address,
            redis_address=redis_address)
        # Connect to the scheduler.
        self.photon_client = photon.PhotonClient(local_scheduler_name)

    def tearDown(self):
        """Check that every process survived the test, then stop them all.

        When running under valgrind, the global scheduler is terminated with
        SIGTERM and waited for. If it exits with a nonzero code, the whole
        test run is aborted with that code, but only after the other
        processes have been stopped.
        """
        valgrind_exit_code = 0
        try:
            # Check that the processes are still alive.
            self.assertEqual(self.p1.poll(), None)
            self.assertEqual(self.p2.poll(), None)
            self.assertEqual(self.p3.poll(), None)
            self.assertEqual(self.p4.poll(), None)
            self.assertEqual(self.redis_process.poll(), None)
        finally:
            # The cleanup runs even when one of the checks above fails, so
            # that no child process outlives the test. Processes that already
            # exited are skipped, because signalling them could raise.

            # Kill the global scheduler.
            if self.p1.poll() is None:
                if USE_VALGRIND:
                    self.p1.send_signal(signal.SIGTERM)
                    self.p1.wait()
                    valgrind_exit_code = self.p1.returncode
                else:
                    self.p1.kill()
            for process in (self.p2, self.p3, self.p4):
                if process.poll() is None:
                    process.kill()
            # Kill Redis. In the event that we are using valgrind, this needs
            # to happen after we kill the global scheduler.
            if self.redis_process.poll() is None:
                self.redis_process.kill()
        # Only abort the run when the global scheduler reported a failure. A
        # clean exit lets unittest finish and report normally.
        if valgrind_exit_code != 0:
            os._exit(valgrind_exit_code)

    def _wait_until_scheduled(self, expected_num_tasks):
        """Poll Redis until all expected task entries are scheduled.

        Fails the test if more than expected_num_tasks task entries show up,
        if a task is in a state other than waiting or scheduled, or if the
        tasks are not all scheduled within SCHEDULING_TIMEOUT_SECONDS.

        Args:
            expected_num_tasks: Number of "task*" keys expected in Redis.
        """
        if expected_num_tasks == 1:
            retry_message = (
                "The task has not been scheduled yet, trying again.")
        else:
            retry_message = (
                "The tasks have not been scheduled yet, trying again.")
        deadline = time.time() + self.SCHEDULING_TIMEOUT_SECONDS
        while True:
            task_entries = self.redis_client.keys("task*")
            self.assertLessEqual(len(task_entries), expected_num_tasks)
            if len(task_entries) == expected_num_tasks:
                # Fetch only the "state" field by name. This works whether
                # the client returns text or bytes, unlike indexing the
                # result of hgetall with a text key.
                task_statuses = [
                    int(self.redis_client.hget(task_entry, "state"))
                    for task_entry in task_entries]
                for task_status in task_statuses:
                    self.assertIn(
                        task_status,
                        [TASK_STATUS_WAITING, TASK_STATUS_SCHEDULED])
                if all(task_status == TASK_STATUS_SCHEDULED
                       for task_status in task_statuses):
                    return
            if time.time() > deadline:
                self.fail(
                    "Timed out after {} seconds waiting for {} task(s) to be "
                    "scheduled.".format(self.SCHEDULING_TIMEOUT_SECONDS,
                                        expected_num_tasks))
            print(retry_message)
            time.sleep(self.POLL_INTERVAL_SECONDS)

    def test_redis_contents(self):
        """Check the initial Redis contents and that submitted tasks get
        scheduled."""
        # There should be three db clients: the global scheduler, the local
        # scheduler, and the plasma manager.
        self.assertEqual(len(self.redis_client.keys("db_clients*")), 3)
        # There should not be anything else in Redis yet.
        self.assertEqual(len(self.redis_client.keys("*")), 3)

        # Submit a task to Redis.
        task = photon.Task(random_function_id(), [], 0, random_task_id(), 0)
        self.photon_client.submit(task)
        # There should now be a task in Redis, and it should get assigned to
        # the local scheduler.
        self._wait_until_scheduled(1)

        # Submit a bunch of tasks to Redis.
        num_tasks = 1000
        for _ in range(num_tasks):
            task = photon.Task(random_function_id(), [], 0, random_task_id(),
                               0)
            self.photon_client.submit(task)
        # Check that there are the correct number of tasks in Redis and that
        # they all get assigned to the local scheduler. The extra one is the
        # single task submitted above.
        self._wait_until_scheduled(num_tasks + 1)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Only consume the trailing argument when it is the "valgrind" flag, so we
    # don't mess with unittest's own argument parser. Any other argument (for
    # example a test name) is left in place for unittest to handle.
    if len(sys.argv) > 1 and sys.argv[-1] == "valgrind":
        sys.argv.pop()
        USE_VALGRIND = True
        print("Using valgrind for tests")
    unittest.main(verbosity=2)
|