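"""Tests for the Ray global scheduler.

These tests start a single Redis server and global scheduler together with
NUM_CLUSTER_NODES plasma store / plasma manager / photon local scheduler
triples, then submit tasks and check their state transitions in Redis.
"""
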
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import random
import redis
import signal
import subprocess
import sys
import threading
import time
import unittest

import global_scheduler
import photon
import plasma
from plasma.utils import random_object_id, generate_metadata, write_to_data_buffer, create_object_with_id, create_object

from ray import services

USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
ID_SIZE = 20
NUM_CLUSTER_NODES = 2

NIL_ACTOR_ID = 20 * b"\xff"

# These constants must match the scheduling state enum in task.h.
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16

# These constants are an implementation detail of ray_redis_module.c, so they
# must be kept in sync with that file.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"

def random_driver_id():
    return photon.ObjectID(np.random.bytes(ID_SIZE))

def random_task_id():
    return photon.ObjectID(np.random.bytes(ID_SIZE))

def random_function_id():
    return photon.ObjectID(np.random.bytes(ID_SIZE))

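# Note: this redefinition shadows the random_object_id imported from
# plasma.utils above, so calls below get a photon.ObjectID rather than raw
# bytes.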
def random_object_id():
    return photon.ObjectID(np.random.bytes(ID_SIZE))

def new_port():
    return random.randint(10000, 65535)

class TestGlobalScheduler(unittest.TestCase):

    def setUp(self):
        # Start one Redis server and N pairs of (plasma, photon).
        node_ip_address = "127.0.0.1"
        redis_port, self.redis_process = services.start_redis(cleanup=False)
        redis_address = services.address(node_ip_address, redis_port)
        # Create a Redis client.
        self.redis_client = redis.StrictRedis(host=node_ip_address,
                                              port=redis_port)
        # Start one global scheduler.
        self.p1 = global_scheduler.start_global_scheduler(
            redis_address, use_valgrind=USE_VALGRIND)
        self.plasma_store_pids = []
        self.plasma_manager_pids = []
        self.local_scheduler_pids = []
        self.plasma_clients = []
        self.photon_clients = []

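        # Each cluster node gets its own plasma store, plasma manager, and
        # photon local scheduler; they all share the single Redis server and
        # global scheduler started above.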
        for i in range(NUM_CLUSTER_NODES):
            # Start the Plasma store. The store name is randomly generated.
            plasma_store_name, p2 = plasma.start_plasma_store()
            self.plasma_store_pids.append(p2)
            # Start the Plasma manager. The manager name and port are randomly
            # generated by the plasma module.
            plasma_manager_name, p3, plasma_manager_port = plasma.start_plasma_manager(
                plasma_store_name, redis_address)
            self.plasma_manager_pids.append(p3)
            plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
            plasma_client = plasma.PlasmaClient(plasma_store_name,
                                                plasma_manager_name)
            self.plasma_clients.append(plasma_client)
            # Start the local scheduler.
            local_scheduler_name, p4 = photon.start_local_scheduler(
                plasma_store_name,
                plasma_manager_name=plasma_manager_name,
                plasma_address=plasma_address,
                redis_address=redis_address,
                static_resource_list=[10, 0])
            # Connect to the scheduler.
            photon_client = photon.PhotonClient(local_scheduler_name, NIL_ACTOR_ID)
            self.photon_clients.append(photon_client)
            self.local_scheduler_pids.append(p4)

    def tearDown(self):
        # Check that the processes are still alive.
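        # (Popen.poll() returns None while a child process is still running,
        # so these assertions verify that nothing crashed during the test.)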
        self.assertEqual(self.p1.poll(), None)
        for p2 in self.plasma_store_pids:
            self.assertEqual(p2.poll(), None)
        for p3 in self.plasma_manager_pids:
            self.assertEqual(p3.poll(), None)
        for p4 in self.local_scheduler_pids:
            self.assertEqual(p4.poll(), None)

        self.assertEqual(self.redis_process.poll(), None)

        # Kill the global scheduler.
        if USE_VALGRIND:
            self.p1.send_signal(signal.SIGTERM)
            self.p1.wait()
            if self.p1.returncode != 0:
                os._exit(-1)
        else:
            self.p1.kill()
        # Kill local schedulers, plasma managers, and plasma stores.
        for p2 in self.local_scheduler_pids:
            p2.kill()
        for p3 in self.plasma_manager_pids:
            p3.kill()
        for p4 in self.plasma_store_pids:
            p4.kill()
        # Kill Redis. In the event that we are using valgrind, this needs to
        # happen after we kill the global scheduler.
        self.redis_process.kill()

    def get_plasma_manager_id(self):
        """Get the db_client_id with client_type equal to plasma_manager.

        Iterates over all of the client table keys and returns the
        db_client_id of the first client whose client_type is plasma_manager.
        The returned id still includes the client table prefix.
        TODO(atumanov): write a separate function to get all plasma manager
        client IDs.

        Returns:
            The db_client_id if one is found and otherwise None.
        """
        db_client_id = None

        client_list = self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))
        for client_id in client_list:
            response = self.redis_client.hget(client_id, b"client_type")
            if response == b"plasma_manager":
                db_client_id = client_id
                break

        return db_client_id

    def test_task_default_resources(self):
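        # The final list argument to photon.Task is the task's resource demand
        # (presumably [num_cpus, num_gpus], mirroring static_resource_list in
        # setUp). task1 omits it and falls back to the default of [1.0, 0.0];
        # task2 passes [1.0, 2.0] explicitly, as the assertions below check.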
        task1 = photon.Task(random_driver_id(), random_function_id(),
                            [random_object_id()], 0, random_task_id(), 0)
        self.assertEqual(task1.required_resources(), [1.0, 0.0])
        task2 = photon.Task(random_driver_id(), random_function_id(),
                            [random_object_id()], 0, random_task_id(), 0,
                            photon.ObjectID(NIL_ACTOR_ID), 0, [1.0, 2.0])
        self.assertEqual(task2.required_resources(), [1.0, 2.0])

    def test_redis_only_single_task(self):
        """Test global scheduler functionality by interacting with Redis and
        checking task state transitions in Redis only. TODO(atumanov):
        implement.
        """
        # Check the precondition for this test: there should be
        # 2 * NUM_CLUSTER_NODES + 1 db clients, i.e. the global scheduler plus
        # one photon and one plasma manager per node.
        self.assertEqual(
            len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
            2 * NUM_CLUSTER_NODES + 1)
        db_client_id = self.get_plasma_manager_id()
        assert db_client_id is not None
        assert db_client_id.startswith(b"CL:")
        db_client_id = db_client_id[len(b"CL:"):]  # Remove the CL: prefix.

    def test_integration_single_task(self):
        # There should be 2 * NUM_CLUSTER_NODES + 1 db clients: the global
        # scheduler plus a local scheduler and a plasma manager for each node.
        self.assertEqual(
            len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
            2 * NUM_CLUSTER_NODES + 1)

        num_return_vals = [0, 1, 2, 3, 5, 10]
        # Create and seal an object in the first plasma store; the task
        # submitted below depends on it.
        data_size = 0xf1f0
        metadata_size = 0x40
        plasma_client = self.plasma_clients[0]
        object_dep, memory_buffer, metadata = create_object(
            plasma_client, data_size, metadata_size, seal=True)

        # Sleep before submitting the task to photon.
        time.sleep(0.1)
        # Submit a task; it should show up in the Redis task table.
        task = photon.Task(random_driver_id(), random_function_id(),
                           [photon.ObjectID(object_dep)], num_return_vals[0],
                           random_task_id(), 0)
        self.photon_clients[0].submit(task)
        time.sleep(0.1)
        # There should now be a task in Redis, and it should get assigned to
        # the local scheduler.
        num_retries = 10
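        # Poll until the task reaches TASK_STATUS_QUEUED (it has been handed
        # to a local scheduler); WAITING and SCHEDULED are acceptable
        # intermediate states. Initialize task_status so the check after the
        # loop is defined even if the task never shows up in Redis.
        task_status = None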
        while num_retries > 0:
            task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
            self.assertLessEqual(len(task_entries), 1)
            if len(task_entries) == 1:
                task_contents = self.redis_client.hgetall(task_entries[0])
                task_status = int(task_contents[b"state"])
                self.assertTrue(task_status in [TASK_STATUS_WAITING,
                                                TASK_STATUS_SCHEDULED,
                                                TASK_STATUS_QUEUED])
                if task_status == TASK_STATUS_QUEUED:
                    break
                else:
                    print(task_status)
            print("The task has not been scheduled yet, trying again.")
            num_retries -= 1
            time.sleep(1)

        if num_retries <= 0 and task_status != TASK_STATUS_QUEUED:
            # Failed to submit and schedule a single task -- bail.
            self.tearDown()
            sys.exit(1)

    def integration_many_tasks_helper(self, timesync=True):
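        """Submit many tasks and check that they all reach the QUEUED state.

        If timesync is True, sleep briefly after creating each object so the
        object info notification can fire before the dependent task is
        submitted. With timesync=False, out-of-order object and task
        notifications are exercised instead.
        """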
        # There should be 2 * NUM_CLUSTER_NODES + 1 db clients: the global
        # scheduler plus a local scheduler and a plasma manager for each node.
        self.assertEqual(
            len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
            2 * NUM_CLUSTER_NODES + 1)
        num_return_vals = [0, 1, 2, 3, 5, 10]

        # Submit a bunch of tasks to Redis.
        num_tasks = 1000
        for _ in range(num_tasks):
            # Create a new object for each task.
            data_size = np.random.randint(1 << 20)
            metadata_size = np.random.randint(1 << 10)
            plasma_client = self.plasma_clients[0]
            object_dep, memory_buffer, metadata = create_object(
                plasma_client, data_size, metadata_size, seal=True)
            if timesync:
                # Give 10ms for the object info handler to fire (long enough
                # to yield the CPU).
                time.sleep(0.010)
            task = photon.Task(random_driver_id(), random_function_id(),
                               [photon.ObjectID(object_dep)],
                               num_return_vals[0], random_task_id(), 0)
            self.photon_clients[0].submit(task)
        # Check that there are the correct number of tasks in Redis and that
        # they all get assigned to the local scheduler.
        num_retries = 10
        num_tasks_done = 0
        while num_retries > 0:
            task_entries = self.redis_client.keys("{}*".format(TASK_PREFIX))
            self.assertLessEqual(len(task_entries), num_tasks)
            # First, check whether all tasks made it to Redis.
            if len(task_entries) == num_tasks:
                task_contents = [self.redis_client.hgetall(task_entries[i])
                                 for i in range(len(task_entries))]
                task_statuses = [int(contents[b"state"])
                                 for contents in task_contents]
                self.assertTrue(all([status in [TASK_STATUS_WAITING,
                                                TASK_STATUS_SCHEDULED,
                                                TASK_STATUS_QUEUED]
                                     for status in task_statuses]))
                num_tasks_done = task_statuses.count(TASK_STATUS_QUEUED)
                num_tasks_scheduled = task_statuses.count(TASK_STATUS_SCHEDULED)
                num_tasks_waiting = task_statuses.count(TASK_STATUS_WAITING)
                print("tasks in Redis = {}, tasks waiting = {}, "
                      "tasks scheduled = {}, tasks queued = {}, "
                      "retries left = {}"
                      .format(len(task_entries), num_tasks_waiting,
                              num_tasks_scheduled, num_tasks_done, num_retries))
                if all([status == TASK_STATUS_QUEUED for status in task_statuses]):
                    # We're done, so pass.
                    break
            num_retries -= 1
            time.sleep(0.1)

        if num_tasks_done != num_tasks:
            # At least one of the tasks failed to schedule.
            self.tearDown()
            sys.exit(2)

    def test_integration_many_tasks_handler_sync(self):
        self.integration_many_tasks_helper(timesync=True)

    def test_integration_many_tasks(self):
        # More realistic case: should handle out-of-order object and task
        # notifications.
        self.integration_many_tasks_helper(timesync=False)

if __name__ == "__main__":
|
|
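    # Passing "valgrind" as the last command-line argument runs the global
    # scheduler under valgrind (USE_VALGRIND is passed to
    # start_global_scheduler in setUp).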
    if len(sys.argv) > 1:
        # Pop the argument so we don't mess with unittest's own argument
        # parser.
        if sys.argv[-1] == "valgrind":
            arg = sys.argv.pop()
            USE_VALGRIND = True
            print("Using valgrind for tests")
    unittest.main(verbosity=2)