# ray/python/global_scheduler/test/test.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import random
import redis
import signal
import subprocess
import sys
import threading
import time
import unittest
import global_scheduler
import photon
import plasma
from plasma.utils import (generate_metadata, write_to_data_buffer,
                          create_object_with_id, create_object)
from ray import services

USE_VALGRIND = False
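# USE_VALGRIND can also be enabled by passing "valgrind" as the last command
# line argument (see the __main__ block at the bottom of this file).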
PLASMA_STORE_MEMORY = 1000000000
ID_SIZE = 20
NUM_CLUSTER_NODES = 2
NIL_ACTOR_ID = ID_SIZE * b"\xff"

# These constants must match the scheduling state enum in task.h.
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
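# Note that the status values are powers of two, which suggests they are used
# as a bitmask on the C side (e.g. to match tasks in any of several states).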
# These constants are an implementation detail of ray_redis_module.c, so they
# must be kept in sync with that file.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
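# For example, assuming raw ID_SIZE-byte binary IDs, a db client entry is
# keyed as b"CL:<20-byte client id>" and a task entry as b"TT:<20-byte task
# id>", which is why the tests below look keys up by prefix.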

def random_driver_id():
  return photon.ObjectID(np.random.bytes(ID_SIZE))

def random_task_id():
  return photon.ObjectID(np.random.bytes(ID_SIZE))

def random_function_id():
  return photon.ObjectID(np.random.bytes(ID_SIZE))

def random_object_id():
  return photon.ObjectID(np.random.bytes(ID_SIZE))

def new_port():
  # Pick a random port in the unprivileged range; collisions between
  # concurrently running tests are possible but unlikely.
  return random.randint(10000, 65535)

class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
    # Start one Redis server and NUM_CLUSTER_NODES pairs of (plasma, photon).
node_ip_address = "127.0.0.1"
redis_port, self.redis_process = services.start_redis(cleanup=False)
redis_address = services.address(node_ip_address, redis_port)
# Create a Redis client.
self.redis_client = redis.StrictRedis(host=node_ip_address, port=redis_port)
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(redis_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
self.plasma_clients = []
self.photon_clients = []
for i in range(NUM_CLUSTER_NODES):
# Start the Plasma store. Plasma store name is randomly generated.
plasma_store_name, p2 = plasma.start_plasma_store()
self.plasma_store_pids.append(p2)
# Start the Plasma manager.
# Assumption: Plasma manager name and port are randomly generated by the plasma module.
plasma_manager_name, p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
self.plasma_manager_pids.append(p3)
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
self.plasma_clients.append(plasma_client)
# Start the local scheduler.
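      # The static_resource_list below is assumed to be of the form
      # [num_cpus, num_gpus] for this node.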
local_scheduler_name, p4 = photon.start_local_scheduler(
plasma_store_name,
plasma_manager_name=plasma_manager_name,
plasma_address=plasma_address,
redis_address=redis_address,
static_resource_list=[10, 0])
# Connect to the scheduler.
photon_client = photon.PhotonClient(local_scheduler_name, NIL_ACTOR_ID)
self.photon_clients.append(photon_client)
self.local_scheduler_pids.append(p4)

  def tearDown(self):
    # Check that all the processes are still alive.
    self.assertIsNone(self.p1.poll())
    for p2 in self.plasma_store_pids:
      self.assertIsNone(p2.poll())
    for p3 in self.plasma_manager_pids:
      self.assertIsNone(p3.poll())
    for p4 in self.local_scheduler_pids:
      self.assertIsNone(p4.poll())
    self.assertIsNone(self.redis_process.poll())
# Kill the global scheduler.
if USE_VALGRIND:
self.p1.send_signal(signal.SIGTERM)
self.p1.wait()
if self.p1.returncode != 0:
os._exit(-1)
else:
self.p1.kill()
    # Kill the local schedulers, plasma managers, and plasma stores.
    for local_scheduler in self.local_scheduler_pids:
      local_scheduler.kill()
    for plasma_manager in self.plasma_manager_pids:
      plasma_manager.kill()
    for plasma_store in self.plasma_store_pids:
      plasma_store.kill()
# Kill Redis. In the event that we are using valgrind, this needs to happen
# after we kill the global scheduler.
self.redis_process.kill()

  def get_plasma_manager_id(self):
    """Get the db_client_id with client_type equal to plasma_manager.

    Iterates over all the client table keys and returns the key of the first
    client whose client_type is plasma_manager. Note that the returned key
    still carries the client table prefix; the caller is responsible for
    stripping it. TODO(atumanov): write a separate function to get all plasma
    manager client IDs.

    Returns:
      The db_client_id if one is found and otherwise None.
    """
db_client_id = None
client_list = self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))
for client_id in client_list:
response = self.redis_client.hget(client_id, b"client_type")
if response == b"plasma_manager":
db_client_id = client_id
break
return db_client_id

  def test_task_default_resources(self):
    task1 = photon.Task(random_driver_id(), random_function_id(),
                        [random_object_id()], 0, random_task_id(), 0)
    # By default a task requires one CPU and no GPUs.
    self.assertEqual(task1.required_resources(), [1.0, 0.0])
    # A task can also specify its resource requirements explicitly.
    task2 = photon.Task(random_driver_id(), random_function_id(),
                        [random_object_id()], 0, random_task_id(), 0,
                        photon.ObjectID(NIL_ACTOR_ID), 0, [1.0, 2.0])
    self.assertEqual(task2.required_resources(), [1.0, 2.0])

  def test_redis_only_single_task(self):
    """Test the global scheduler by interacting with Redis only.

    Checks task state transitions directly in Redis.
    TODO(atumanov): implement.
    """
    # Check the precondition for this test: there should be
    # 2 * NUM_CLUSTER_NODES + 1 db clients, i.e. the global scheduler plus one
    # photon and one plasma manager per node.
    self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
                     2 * NUM_CLUSTER_NODES + 1)
    db_client_id = self.get_plasma_manager_id()
    self.assertIsNotNone(db_client_id)
    self.assertTrue(db_client_id.startswith(b"CL:"))
    db_client_id = db_client_id[len(b"CL:"):]  # Remove the CL: prefix.
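    # What remains is the raw ID_SIZE-byte binary client ID.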

  def test_integration_single_task(self):
    # There should be 2 * NUM_CLUSTER_NODES + 1 db clients: the global
    # scheduler plus one local scheduler and one plasma manager per node.
    self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
                     2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
    # Create and seal an object in the first Plasma store; the object info
    # will be relayed to Redis, making the task's dependency available.
    data_size = 0xf1f0
    metadata_size = 0x40
    plasma_client = self.plasma_clients[0]
    object_dep, memory_buffer, metadata = create_object(
        plasma_client, data_size, metadata_size, seal=True)
# Sleep before submitting task to photon.
time.sleep(0.1)
    # Submit a task that depends on the object to the local scheduler.
    task = photon.Task(random_driver_id(), random_function_id(),
                       [photon.ObjectID(object_dep)], num_return_vals[0],
                       random_task_id(), 0)
self.photon_clients[0].submit(task)
time.sleep(0.1)
    # There should now be a task in Redis, and it should get assigned to the
    # local scheduler.
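    # Judging by the states checked below, the expected progression is
    # TASK_STATUS_WAITING (task created in Redis) -> TASK_STATUS_SCHEDULED
    # (assigned by the global scheduler) -> TASK_STATUS_QUEUED (received by
    # the local scheduler).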
    num_retries = 10
    # Initialize task_status so the check after the loop is well defined even
    # if the task never shows up in Redis.
    task_status = None
    while num_retries > 0:
      task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
      self.assertLessEqual(len(task_entries), 1)
      if len(task_entries) == 1:
        task_contents = self.redis_client.hgetall(task_entries[0])
        task_status = int(task_contents[b"state"])
        self.assertTrue(task_status in [TASK_STATUS_WAITING,
                                        TASK_STATUS_SCHEDULED,
                                        TASK_STATUS_QUEUED])
        if task_status == TASK_STATUS_QUEUED:
          break
        else:
          print("The task has state {} and has not been queued yet, trying "
                "again.".format(task_status))
      num_retries -= 1
      time.sleep(1)
    if num_retries <= 0 and task_status != TASK_STATUS_QUEUED:
      # Failed to submit and schedule a single task, so bail out.
      self.tearDown()
      sys.exit(1)

  def integration_many_tasks_helper(self, timesync=True):
    # There should be 2 * NUM_CLUSTER_NODES + 1 db clients: the global
    # scheduler plus one local scheduler and one plasma manager per node.
    self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
                     2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
    # Submit a bunch of tasks through the local scheduler; they should all
    # appear in the Redis task table.
num_tasks = 1000
for _ in range(num_tasks):
# Create a new object for each task.
data_size = np.random.randint(1 << 20)
metadata_size = np.random.randint(1 << 10)
plasma_client = self.plasma_clients[0]
      object_dep, memory_buffer, metadata = create_object(
          plasma_client, data_size, metadata_size, seal=True)
if timesync:
# Give 10ms for object info handler to fire (long enough to yield CPU).
time.sleep(0.010)
      task = photon.Task(random_driver_id(), random_function_id(),
                         [photon.ObjectID(object_dep)], num_return_vals[0],
                         random_task_id(), 0)
self.photon_clients[0].submit(task)
# Check that there are the correct number of tasks in Redis and that they
# all get assigned to the local scheduler.
num_retries = 10
num_tasks_done = 0
while num_retries > 0:
task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
        task_contents = [self.redis_client.hgetall(task_entry)
                         for task_entry in task_entries]
task_statuses = [int(contents[b"state"]) for contents in task_contents]
self.assertTrue(all([
status in [TASK_STATUS_WAITING,
TASK_STATUS_SCHEDULED,
TASK_STATUS_QUEUED] for status in task_statuses
]))
num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done, num_retries))
if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
# We're done, so pass.
break
num_retries -= 1
time.sleep(0.1)
if num_tasks_done != num_tasks:
# At least one of the tasks failed to schedule.
self.tearDown()
sys.exit(2)

  def test_integration_many_tasks_handler_sync(self):
    self.integration_many_tasks_helper(timesync=True)

  def test_integration_many_tasks(self):
    # More realistic case: the scheduler should handle out-of-order object and
    # task notifications.
    self.integration_many_tasks_helper(timesync=False)

if __name__ == "__main__":
  if len(sys.argv) > 1:
    # Pop the argument so we don't mess with unittest's own argument parser.
    if sys.argv[-1] == "valgrind":
      sys.argv.pop()
      USE_VALGRIND = True
      print("Using valgrind for tests")
  unittest.main(verbosity=2)