diff --git a/python/ray/actor.py b/python/ray/actor.py index 6688a57df..deae7ba4b 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -7,13 +7,14 @@ import inspect import json import numpy as np import random +import redis import traceback import ray.local_scheduler import ray.pickling as pickling import ray.signature as signature import ray.worker -import ray.experimental.state as state +from ray.utils import binary_to_hex, hex_to_binary # This is a variable used by each actor to indicate the IDs of the GPUs that # the worker is currently allowed to use. @@ -105,6 +106,72 @@ def fetch_and_register_actor(key, worker): # the actor. +def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker): + """Attempt to acquire GPUs on a particular local scheduler for an actor. + + Args: + num_gpus: The number of GPUs to acquire. + driver_id: The ID of the driver responsible for creating the actor. + local_scheduler: Information about the local scheduler. + + Returns: + A list of the GPU IDs that were successfully acquired. This should have + length either equal to num_gpus or equal to 0. + """ + local_scheduler_id = local_scheduler["DBClientID"] + local_scheduler_total_gpus = int(local_scheduler["NumGPUs"]) + + gpus_to_acquire = [] + + # Attempt to acquire GPU IDs atomically. + with worker.redis_client.pipeline() as pipe: + while True: + try: + # If this key is changed before the transaction below (the multi/exec + # block), then the transaction will not take place. + pipe.watch(local_scheduler_id) + + # Figure out which GPUs are currently in use. + result = worker.redis_client.hget(local_scheduler_id, "gpus_in_use") + gpus_in_use = dict() if result is None else json.loads(result) + all_gpu_ids_in_use = [] + for key in gpus_in_use: + all_gpu_ids_in_use += gpus_in_use[key] + assert len(all_gpu_ids_in_use) <= local_scheduler_total_gpus + assert len(set(all_gpu_ids_in_use)) == len(all_gpu_ids_in_use) + + pipe.multi() + + if local_scheduler_total_gpus - len(all_gpu_ids_in_use) >= num_gpus: + # There are enough available GPUs, so try to reserve some. + all_gpu_ids = set(range(local_scheduler_total_gpus)) + for gpu_id in all_gpu_ids_in_use: + all_gpu_ids.remove(gpu_id) + gpus_to_acquire = list(all_gpu_ids)[:num_gpus] + + # Use the hex driver ID so that the dictionary is JSON serializable. + driver_id_hex = binary_to_hex(driver_id) + if driver_id_hex not in gpus_in_use: + gpus_in_use[driver_id_hex] = [] + gpus_in_use[driver_id_hex] += gpus_to_acquire + + # Stick the updated GPU IDs back in Redis + pipe.hset(local_scheduler_id, "gpus_in_use", json.dumps(gpus_in_use)) + + pipe.execute() + # If a WatchError is not raised, then the operations should have gone + # through atomically. + break + except redis.WatchError: + # Another client must have changed the watched key between the time we + # started WATCHing it and the pipeline's execution. We should just + # retry. + gpus_to_acquire = [] + continue + + return gpus_to_acquire + + def select_local_scheduler(local_schedulers, num_gpus, worker): """Select a local scheduler to assign this actor to. @@ -121,42 +188,33 @@ def select_local_scheduler(local_schedulers, num_gpus, worker): Exception: An exception is raised if no local scheduler can be found with sufficient resources. """ - # TODO(rkn): We should change this method to have a list of GPU IDs that we - # pop from and push to. The current implementation is not compatible with - # actors releasing GPU resources. 
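For readers unfamiliar with redis-py's optimistic locking, the attempt_to_reserve_gpus helper above follows the standard WATCH/MULTI/EXEC retry pattern. Below is a minimal, self-contained sketch of that pattern, assuming a Redis server reachable at 127.0.0.1:6379; the key, field, and ID values are illustrative and not the ones Ray actually stores.

import json
import redis

def reserve_ids(client, key, field, candidate_ids, num_to_take):
    """Atomically move num_to_take IDs into a JSON-encoded hash field."""
    with client.pipeline() as pipe:
        while True:
            try:
                # Abort the MULTI/EXEC block below if 'key' changes under us.
                pipe.watch(key)
                raw = pipe.hget(key, field)  # Immediate mode after watch().
                taken = json.loads(raw) if raw is not None else []
                available = [i for i in candidate_ids if i not in taken]
                if len(available) < num_to_take:
                    return []
                acquired = available[:num_to_take]
                pipe.multi()
                pipe.hset(key, field, json.dumps(taken + acquired))
                pipe.execute()
                return acquired
            except redis.WatchError:
                # Another client modified 'key'; re-read and retry.
                continue

# Illustrative usage against a local Redis server.
client = redis.StrictRedis(host="127.0.0.1", port=6379)
print(reserve_ids(client, "local_scheduler:0", "gpus_in_use", list(range(4)), 2))

The essential property is that the HSET only commits if the watched key was untouched between WATCH and EXEC; otherwise redis.WatchError forces a re-read and a retry.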
+ driver_id = worker.task_driver_id.id() + if num_gpus == 0: - local_scheduler_id = random.choice(local_schedulers)[b"ray_client_id"] - gpu_ids = [] + local_scheduler_id = hex_to_binary( + random.choice(local_schedulers)["DBClientID"]) + gpus_aquired = [] else: # All of this logic is for finding a local scheduler that has enough # available GPUs. local_scheduler_id = None # Loop through all of the local schedulers. for local_scheduler in local_schedulers: - # See if there are enough available GPUs on this local scheduler. - local_scheduler_total_gpus = int(float( - local_scheduler[b"num_gpus"].decode("ascii"))) - gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"], - b"gpus_in_use") - gpus_in_use = 0 if gpus_in_use is None else int(gpus_in_use) - if gpus_in_use + num_gpus <= local_scheduler_total_gpus: - # Attempt to reserve some GPUs for this actor. - new_gpus_in_use = worker.redis_client.hincrby( - local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus) - if new_gpus_in_use > local_scheduler_total_gpus: - # If we failed to reserve the GPUs, undo the increment. - worker.redis_client.hincrby(local_scheduler[b"ray_client_id"], - b"gpus_in_use", num_gpus) - else: - # We succeeded at reserving the GPUs, so we are done. - local_scheduler_id = local_scheduler[b"ray_client_id"] - gpu_ids = list(range(new_gpus_in_use - num_gpus, new_gpus_in_use)) - break + # Try to reserve enough GPUs on this local scheduler. + gpus_aquired = attempt_to_reserve_gpus(num_gpus, driver_id, + local_scheduler, worker) + if len(gpus_aquired) == num_gpus: + local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"]) + break + else: + # We should have either acquired as many GPUs as we need or none. + assert len(gpus_aquired) == 0 + if local_scheduler_id is None: raise Exception("Could not find a node with enough GPUs to create this " "actor. The local scheduler information is {}." .format(local_schedulers)) - return local_scheduler_id, gpu_ids + return local_scheduler_id, gpus_aquired def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, @@ -183,13 +241,23 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus, worker.function_properties[driver_id][function_id] = (1, num_cpus, num_gpus) + # Get a list of the local schedulers from the client table. + client_table = ray.global_state.client_table() + local_schedulers = [] + for ip_address, clients in client_table.items(): + for client in clients: + if client["ClientType"] == "local_scheduler": + local_schedulers.append(client) # Select a local scheduler for the actor. - local_schedulers = state.get_local_schedulers(worker) local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers, num_gpus, worker) + # Really we should encode this message as a flatbuffer object. However, we're + # having trouble getting that to work. It almost works, but in Python 2.7, + # builder.CreateString fails on byte strings that contain characters outside + # range(128). 
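As an aside on the ray.utils helpers imported at the top of this file: binary_to_hex and hex_to_binary are thin binascii wrappers, and the hex form is what makes driver IDs usable as JSON keys in the gpus_in_use mapping above. An illustrative round-trip; the 20-byte length is an assumption matching Ray's fixed-size binary IDs.

from ray.utils import binary_to_hex, hex_to_binary

binary_id = b"\x0b" * 20               # Assumed 20-byte binary ID.
hex_id = binary_to_hex(binary_id)      # "0b0b..." -- safe to use as a JSON key.
assert hex_to_binary(hex_id) == binary_id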
worker.redis_client.publish("actor_notifications", - actor_id.id() + local_scheduler_id) + actor_id.id() + driver_id + local_scheduler_id) d = {"driver_id": driver_id, "actor_id": actor_id.id(), diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index b417de225..3de63ebaf 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -2,12 +2,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import binascii import pickle import redis -import sys -import ray.local_scheduler +from ray.utils import (decode, binary_to_object_id, binary_to_hex, + hex_to_binary) # Import flatbuffer bindings. from ray.core.generated.TaskInfo import TaskInfo @@ -36,40 +35,6 @@ task_state_mapping = { } -def decode(byte_str): - """Make this unicode in Python 3, otherwise leave it as bytes.""" - if sys.version_info >= (3, 0): - return byte_str.decode("ascii") - else: - return byte_str - - -def binary_to_object_id(binary_object_id): - return ray.local_scheduler.ObjectID(binary_object_id) - - -def binary_to_hex(identifier): - hex_identifier = binascii.hexlify(identifier) - if sys.version_info >= (3, 0): - hex_identifier = hex_identifier.decode() - return hex_identifier - - -def hex_to_binary(hex_identifier): - return binascii.unhexlify(hex_identifier) - - -def get_local_schedulers(worker): - local_schedulers = [] - for client in worker.redis_client.keys("CL:*"): - client_info = worker.redis_client.hgetall(client) - if b"client_type" not in client_info: - continue - if client_info[b"client_type"] == b"local_scheduler": - local_schedulers.append(client_info) - return local_schedulers - - class GlobalState(object): """A class used to interface with the Ray control state. diff --git a/python/ray/global_scheduler/global_scheduler_services.py b/python/ray/global_scheduler/global_scheduler_services.py index 36913801e..eb7ec7ee9 100644 --- a/python/ray/global_scheduler/global_scheduler_services.py +++ b/python/ray/global_scheduler/global_scheduler_services.py @@ -7,13 +7,15 @@ import subprocess import time -def start_global_scheduler(redis_address, use_valgrind=False, +def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None): """Start a global scheduler process. Args: redis_address (str): The address of the Redis instance. + node_ip_address: The IP address of the node that this scheduler will run + on. use_valgrind (bool): True if the global scheduler should be started inside of valgrind. If this is True, use_profiler must be False. use_profiler (bool): True if the global scheduler should be started inside @@ -31,7 +33,9 @@ def start_global_scheduler(redis_address, use_valgrind=False, global_scheduler_executable = os.path.join( os.path.abspath(os.path.dirname(__file__)), "../core/src/global_scheduler/global_scheduler") - command = [global_scheduler_executable, "-r", redis_address] + command = [global_scheduler_executable, + "-r", redis_address, + "-h", node_ip_address] if use_valgrind: pid = subprocess.Popen(["valgrind", "--track-origins=yes", diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index 4c223a287..f8219e22b 100644 --- a/python/ray/global_scheduler/test/test.py +++ b/python/ray/global_scheduler/test/test.py @@ -71,7 +71,7 @@ class TestGlobalScheduler(unittest.TestCase): port=redis_port) # Start one global scheduler. 
self.p1 = global_scheduler.start_global_scheduler( - redis_address, use_valgrind=USE_VALGRIND) + redis_address, node_ip_address, use_valgrind=USE_VALGRIND) self.plasma_store_pids = [] self.plasma_manager_pids = [] self.local_scheduler_pids = [] diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 7c9897558..7247cb263 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -4,10 +4,12 @@ from __future__ import print_function import argparse from collections import Counter +import json import logging import redis import time +import ray from ray.services import get_ip_address from ray.services import get_port @@ -15,6 +17,7 @@ from ray.services import get_port from ray.core.generated.SubscribeToDBClientTableReply \ import SubscribeToDBClientTableReply from ray.core.generated.TaskReply import TaskReply +from ray.core.generated.DriverTableMessage import DriverTableMessage # These variables must be kept in sync with the C codebase. # common/common.h @@ -26,6 +29,7 @@ NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE TASK_STATUS_LOST = 32 # common/state/redis.cc PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" +DRIVER_DEATH_CHANNEL = b"driver_deaths" # common/redis_module/ray_redis_module.cc TASK_PREFIX = "TT:" OBJECT_PREFIX = "OL:" @@ -215,6 +219,65 @@ class Monitor(object): # manager. self.live_plasma_managers[db_client_id] = 0 + def driver_removed_handler(self, channel, data): + """Handle a notification that a driver has been removed. + + This releases any GPU resources that were reserved for that driver in + Redis. + """ + message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0) + driver_id = message.DriverId() + log.info("Driver {} has been removed.".format(driver_id)) + + # Get a list of the local schedulers. + client_table = ray.global_state.client_table() + local_schedulers = [] + for ip_address, clients in client_table.items(): + for client in clients: + if client["ClientType"] == "local_scheduler": + local_schedulers.append(client) + + # Release any GPU resources that have been reserved for this driver in + # Redis. + for local_scheduler in local_schedulers: + if int(local_scheduler["NumGPUs"]) > 0: + local_scheduler_id = local_scheduler["DBClientID"] + + returned_gpu_ids = [] + + # Perform a transaction to return the GPUs. + with self.redis.pipeline() as pipe: + while True: + try: + # If this key is changed before the transaction below (the + # multi/exec block), then the transaction will not take place. + pipe.watch(local_scheduler_id) + + result = pipe.hget(local_scheduler_id, "gpus_in_use") + gpus_in_use = dict() if result is None else json.loads(result) + + driver_id_hex = ray.utils.binary_to_hex(driver_id) + if driver_id_hex in gpus_in_use: + returned_gpu_ids = gpus_in_use.pop(driver_id_hex) + + pipe.multi() + + pipe.hset(local_scheduler_id, "gpus_in_use", + json.dumps(gpus_in_use)) + + pipe.execute() + # If a WatchError is not raise, then the operations should have + # gone through atomically. + break + except redis.WatchError: + # Another client must have changed the watched key between the + # time we started WATCHing it and the pipeline's execution. We + # should just retry. + continue + + log.info("Driver {} is returning GPU IDs {} to local scheduler {}." + .format(driver_id, returned_gpu_ids, local_scheduler_id)) + def process_messages(self): """Process all messages ready in the subscription channels. @@ -244,6 +307,12 @@ class Monitor(object): assert(self.subscribed[channel]) # The message was a notification from the db_client table. 
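To make the cleanup in driver_removed_handler above concrete: the value stored under each local scheduler's gpus_in_use field is a JSON object mapping hex driver IDs to the GPU IDs reserved for that driver, so releasing a dead driver's resources amounts to popping one key and writing the mapping back. A hypothetical example of that value, with the IDs shortened for readability:

import json

# Hypothetical decoded value of HGET <local_scheduler_id> gpus_in_use.
gpus_in_use = {
    "8f3b...": [0, 1],      # GPUs reserved for one driver (hex driver ID).
    "42aa...": [2, 3, 4],   # GPUs reserved for another driver.
}

# Releasing the resources of the removed driver "8f3b...".
returned_gpu_ids = gpus_in_use.pop("8f3b...", [])
updated_value = json.dumps(gpus_in_use)  # Written back with HSET inside the transaction.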
message_handler = self.db_client_notification_handler + elif channel == DRIVER_DEATH_CHANNEL: + assert(self.subscribed[channel]) + # The message was a notification that a driver was removed. + message_handler = self.driver_removed_handler + else: + raise Exception("This code should be unreachable.") # Call the handler. assert(message_handler is not None) @@ -258,6 +327,7 @@ class Monitor(object): # Initialize the subscription channel. self.subscribe(DB_CLIENT_TABLE_NAME) self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL) + self.subscribe(DRIVER_DEATH_CHANNEL) # Scan the database table for dead database clients. NOTE: This must be # called before reading any messages from the subscription channel. This @@ -326,5 +396,8 @@ if __name__ == "__main__": redis_ip_address = get_ip_address(args.redis_address) redis_port = get_port(args.redis_address) + # Initialize the global state. + ray.global_state._initialize_global_state(redis_ip_address, redis_port) + monitor = Monitor(redis_ip_address, redis_port) monitor.run() diff --git a/python/ray/services.py b/python/ray/services.py index 5af23e217..cf16c6cbc 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -372,7 +372,7 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None, this process will be killed by services.cleanup() when the Python process that imported services exits. """ - p = global_scheduler.start_global_scheduler(redis_address, + p = global_scheduler.start_global_scheduler(redis_address, node_ip_address, stdout_file=stdout_file, stderr_file=stderr_file) if cleanup: @@ -767,7 +767,10 @@ def start_ray_processes(address_info=None, if num_workers is not None: workers_per_local_scheduler = num_local_schedulers * [num_workers] else: - workers_per_local_scheduler = num_local_schedulers * [psutil.cpu_count()] + workers_per_local_scheduler = [] + for cpus in num_cpus: + workers_per_local_scheduler.append(cpus if cpus is not None + else psutil.cpu_count()) if address_info is None: address_info = {} diff --git a/python/ray/test/multi_node_tests.py b/python/ray/test/multi_node_tests.py new file mode 100644 index 000000000..6b0277541 --- /dev/null +++ b/python/ray/test/multi_node_tests.py @@ -0,0 +1,86 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import redis +import time + +import ray + +EVENT_KEY = "RAY_MULTI_NODE_TEST_KEY" +"""This key is used internally within this file for coordinating drivers.""" + + +def _wait_for_nodes_to_join(num_nodes, timeout=20): + """Wait until the nodes have joined the cluster. + + This will wait until exactly num_nodes have joined the cluster and each node + has a local scheduler and a plasma manager. + + Args: + num_nodes: The number of nodes to wait for. + timeout: The amount of time in seconds to wait before failing. + + Raises: + Exception: An exception is raised if too many nodes join the cluster or if + the timeout expires while we are waiting. + """ + start_time = time.time() + while time.time() - start_time < timeout: + client_table = ray.global_state.client_table() + num_ready_nodes = len(client_table) + if num_ready_nodes == num_nodes: + ready = True + # Check that for each node, a local scheduler and a plasma manager are + # present. 
+ for ip_address, clients in client_table.items(): + client_types = [client["ClientType"] for client in clients] + if "local_scheduler" not in client_types: + ready = False + if "plasma_manager" not in client_types: + ready = False + if ready: + return + if num_ready_nodes > num_nodes: + # Too many nodes have joined. Something must be wrong. + raise Exception("{} nodes have joined the cluster, but we were " + "expecting {} nodes.".format(num_ready_nodes, num_nodes)) + time.sleep(0.1) + + # If we get here then we timed out. + raise Exception("Timed out while waiting for {} nodes to join. Only {} " + "nodes have joined so far.".format(num_ready_nodes, + num_nodes)) + + +def _broadcast_event(event_name, redis_address): + """Broadcast an event. + + Args: + event_name: The name of the event to wait for. + redis_address: The address of the Redis server to use for synchronization. + + This is used to synchronize drivers for the multi-node tests. + """ + redis_host, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port)) + redis_client.rpush(EVENT_KEY, event_name) + + +def _wait_for_event(event_name, redis_address, extra_buffer=1): + """Block until an event has been broadcast. + + Args: + event_name: The name of the event to wait for. + redis_address: The address of the Redis server to use for synchronization. + extra_buffer: An amount of time in seconds to wait after the event. + + This is used to synchronize drivers for the multi-node tests. + """ + redis_host, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port)) + while True: + event_names = redis_client.lrange(EVENT_KEY, 0, -1) + if event_name.encode("ascii") in event_names: + break + time.sleep(extra_buffer) diff --git a/python/ray/utils.py b/python/ray/utils.py new file mode 100644 index 000000000..ea86cf814 --- /dev/null +++ b/python/ray/utils.py @@ -0,0 +1,31 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import binascii +import sys + +import ray.local_scheduler + + +def decode(byte_str): + """Make this unicode in Python 3, otherwise leave it as bytes.""" + if sys.version_info >= (3, 0): + return byte_str.decode("ascii") + else: + return byte_str + + +def binary_to_object_id(binary_object_id): + return ray.local_scheduler.ObjectID(binary_object_id) + + +def binary_to_hex(identifier): + hex_identifier = binascii.hexlify(identifier) + if sys.version_info >= (3, 0): + hex_identifier = hex_identifier.decode() + return hex_identifier + + +def hex_to_binary(hex_identifier): + return binascii.unhexlify(hex_identifier) diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 04fd71bed..6b77693b6 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -61,6 +61,7 @@ add_library(common STATIC state/object_table.cc state/task_table.cc state/db_client_table.cc + state/driver_table.cc state/actor_notification_table.cc state/local_scheduler_table.cc state/error_table.cc diff --git a/src/common/common.cc b/src/common/common.cc index d5601d5e5..91b36797c 100644 --- a/src/common/common.cc +++ b/src/common/common.cc @@ -47,6 +47,10 @@ bool DBClientID_equal(DBClientID first_id, DBClientID second_id) { return UNIQUE_ID_EQ(first_id, second_id); } +bool WorkerID_equal(WorkerID first_id, WorkerID second_id) { + return UNIQUE_ID_EQ(first_id, second_id); +} + char *ObjectID_to_string(ObjectID obj_id, char *id_string, int id_length) { 
CHECK(id_length >= ID_STRING_SIZE); static const char hex[] = "0123456789abcdef"; diff --git a/src/common/common.h b/src/common/common.h index a04fa8f38..ca68692fa 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -153,7 +153,9 @@ extern const UniqueID NIL_ID; UniqueID globally_unique_id(void); #define NIL_OBJECT_ID NIL_ID +#define NIL_WORKER_ID NIL_ID +/** The object ID is the type used to identify objects. */ typedef UniqueID ObjectID; #ifdef __cplusplus @@ -202,6 +204,18 @@ bool ObjectID_equal(ObjectID first_id, ObjectID second_id); */ bool ObjectID_is_nil(ObjectID id); +/** The worker ID is the ID of a worker or driver. */ +typedef UniqueID WorkerID; + +/** + * Compare two worker IDs. + * + * @param first_id The first worker ID to compare. + * @param second_id The first worker ID to compare. + * @return True if the worker IDs are the same and false otherwise. + */ +bool WorkerID_equal(WorkerID first_id, WorkerID second_id); + typedef UniqueID DBClientID; /** diff --git a/src/common/format/common.fbs b/src/common/format/common.fbs index cde6b1df5..78c5ec2c2 100644 --- a/src/common/format/common.fbs +++ b/src/common/format/common.fbs @@ -139,3 +139,8 @@ table ResultTableReply { } root_type ResultTableReply; + +table DriverTableMessage { + // The driver ID of the driver that died. + driver_id: string; +} diff --git a/src/common/state/actor_notification_table.h b/src/common/state/actor_notification_table.h index cb4adafdc..87d7b7633 100644 --- a/src/common/state/actor_notification_table.h +++ b/src/common/state/actor_notification_table.h @@ -5,20 +5,16 @@ #include "db.h" #include "table.h" -typedef struct { - /** The ID of the actor. */ - ActorID actor_id; - /** The ID of the local scheduler that is responsible for the actor. */ - DBClientID local_scheduler_id; -} ActorInfo; - /* * ==== Subscribing to the actor notification table ==== */ /* Callback for subscribing to the local scheduler table. */ -typedef void (*actor_notification_table_subscribe_callback)(ActorInfo info, - void *user_context); +typedef void (*actor_notification_table_subscribe_callback)( + ActorID actor_id, + WorkerID driver_id, + DBClientID local_scheduler_id, + void *user_context); /** * Register a callback to process actor notification events. 
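Since the subscribe callback now receives three IDs, it may help to spell out the wire format: the Python side publishes the raw concatenation actor_id + driver_id + local_scheduler_id on the actor_notifications channel, and the C callback splits the payload by offset. A hypothetical Python sketch of that layout; the 20-byte ID size is an assumption and must match the fixed binary ID size used by the C code.

ID_SIZE = 20  # Assumed; must equal the binary ID size used in src/common.

def pack_actor_notification(actor_id, driver_id, local_scheduler_id):
    # Payload published on the "actor_notifications" channel.
    assert len(actor_id) == len(driver_id) == len(local_scheduler_id) == ID_SIZE
    return actor_id + driver_id + local_scheduler_id

def unpack_actor_notification(payload):
    # Mirrors the memcpy-by-offset logic in the C subscribe callback.
    assert len(payload) == 3 * ID_SIZE
    return (payload[:ID_SIZE],
            payload[ID_SIZE:2 * ID_SIZE],
            payload[2 * ID_SIZE:])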
diff --git a/src/common/state/driver_table.cc b/src/common/state/driver_table.cc new file mode 100644 index 000000000..9e951b805 --- /dev/null +++ b/src/common/state/driver_table.cc @@ -0,0 +1,21 @@ +#include "driver_table.h" +#include "redis.h" + +void driver_table_subscribe(DBHandle *db_handle, + driver_table_subscribe_callback subscribe_callback, + void *subscribe_context, + RetryInfo *retry) { + DriverTableSubscribeData *sub_data = + (DriverTableSubscribeData *) malloc(sizeof(DriverTableSubscribeData)); + sub_data->subscribe_callback = subscribe_callback; + sub_data->subscribe_context = subscribe_context; + init_table_callback(db_handle, NIL_ID, __func__, sub_data, retry, NULL, + redis_driver_table_subscribe, NULL); +} + +void driver_table_send_driver_death(DBHandle *db_handle, + WorkerID driver_id, + RetryInfo *retry) { + init_table_callback(db_handle, driver_id, __func__, NULL, retry, NULL, + redis_driver_table_send_driver_death, NULL); +} diff --git a/src/common/state/driver_table.h b/src/common/state/driver_table.h new file mode 100644 index 000000000..c8c6a6c32 --- /dev/null +++ b/src/common/state/driver_table.h @@ -0,0 +1,50 @@ +#ifndef DRIVER_TABLE_H +#define DRIVER_TABLE_H + +#include "db.h" +#include "table.h" +#include "task.h" + +/* + * ==== Subscribing to the driver table ==== + */ + +/* Callback for subscribing to the driver table. */ +typedef void (*driver_table_subscribe_callback)(WorkerID driver_id, + void *user_context); + +/** + * Register a callback for a driver table event. + * + * @param db_handle Database handle. + * @param subscribe_callback Callback that will be called when the driver event + * happens. + * @param subscribe_context Context that will be passed into the + * subscribe_callback. + * @param retry Information about retrying the request to the database. + * @return Void. + */ +void driver_table_subscribe(DBHandle *db_handle, + driver_table_subscribe_callback subscribe_callback, + void *subscribe_context, + RetryInfo *retry); + +/* Data that is needed to register driver table subscribe callbacks with the + * state database. */ +typedef struct { + driver_table_subscribe_callback subscribe_callback; + void *subscribe_context; +} DriverTableSubscribeData; + +/** + * Send driver death update to all subscribers. + * + * @param db_handle Database handle. + * @param driver_id The ID of the driver that died. + * @param retry Information about retrying the request to the database. + */ +void driver_table_send_driver_death(DBHandle *db_handle, + WorkerID driver_id, + RetryInfo *retry); + +#endif /* DRIVER_TABLE_H */ diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc index 3bbfacf41..cdb29693c 100644 --- a/src/common/state/redis.cc +++ b/src/common/state/redis.cc @@ -17,6 +17,7 @@ extern "C" { #include "db.h" #include "db_client_table.h" #include "actor_notification_table.h" +#include "driver_table.h" #include "local_scheduler_table.h" #include "object_table.h" #include "task.h" @@ -1089,6 +1090,90 @@ void redis_local_scheduler_table_send_info(TableCallbackData *callback_data) { } } +void redis_driver_table_subscribe_callback(redisAsyncContext *c, + void *r, + void *privdata) { + REDIS_CALLBACK_HEADER(db, callback_data, r); + + redisReply *reply = (redisReply *) r; + CHECK(reply->type == REDIS_REPLY_ARRAY); + CHECK(reply->elements == 3); + redisReply *message_type = reply->element[0]; + LOG_DEBUG("Driver table subscribe callback, message %s", message_type->str); + + if (strcmp(message_type->str, "message") == 0) { + /* Handle a driver heartbeat. 
Parse the payload and call the subscribe + * callback. */ + auto message = + flatbuffers::GetRoot(reply->element[2]->str); + /* Extract the client ID. */ + WorkerID driver_id = from_flatbuf(message->driver_id()); + + /* Call the subscribe callback. */ + DriverTableSubscribeData *data = + (DriverTableSubscribeData *) callback_data->data; + if (data->subscribe_callback) { + data->subscribe_callback(driver_id, data->subscribe_context); + } + } else if (strcmp(message_type->str, "subscribe") == 0) { + /* The reply for the initial SUBSCRIBE command. */ + CHECK(callback_data->done_callback == NULL); + /* If the initial SUBSCRIBE was successful, clean up the timer, but don't + * destroy the callback data. */ + event_loop_remove_timer(db->loop, callback_data->timer_id); + + } else { + LOG_FATAL("Unexpected reply type from driver subscribe."); + } +} + +void redis_driver_table_subscribe(TableCallbackData *callback_data) { + DBHandle *db = callback_data->db_handle; + int status = redisAsyncCommand( + db->sub_context, redis_driver_table_subscribe_callback, + (void *) callback_data->timer_id, "SUBSCRIBE driver_deaths"); + if ((status == REDIS_ERR) || db->sub_context->err) { + LOG_REDIS_DEBUG(db->sub_context, "error in redis_driver_table_subscribe"); + } +} + +void redis_driver_table_send_driver_death_callback(redisAsyncContext *c, + void *r, + void *privdata) { + REDIS_CALLBACK_HEADER(db, callback_data, r); + + redisReply *reply = (redisReply *) r; + CHECK(reply->type == REDIS_REPLY_INTEGER); + LOG_DEBUG("%" PRId64 " subscribers received this publish.\n", reply->integer); + /* At the very least, the local scheduler that publishes this message should + * also receive it. */ + CHECK(reply->integer >= 1); + + CHECK(callback_data->done_callback == NULL); + /* Clean up the timer and callback. */ + destroy_timer_callback(db->loop, callback_data); +} + +void redis_driver_table_send_driver_death(TableCallbackData *callback_data) { + DBHandle *db = callback_data->db_handle; + WorkerID driver_id = callback_data->id; + + /* Create a flatbuffer object to serialize and publish. */ + flatbuffers::FlatBufferBuilder fbb; + /* Create the flatbuffers message. */ + auto message = CreateDriverTableMessage(fbb, to_flatbuf(fbb, driver_id)); + fbb.Finish(message); + + int status = redisAsyncCommand( + db->context, redis_driver_table_send_driver_death_callback, + (void *) callback_data->timer_id, "PUBLISH driver_deaths %b", + fbb.GetBufferPointer(), fbb.GetSize()); + if ((status == REDIS_ERR) || db->context->err) { + LOG_REDIS_DEBUG(db->context, + "error in redis_driver_table_send_driver_death"); + } +} + void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; /* NOTE(swang): We purposefully do not provide a callback, leaving the table @@ -1124,15 +1209,20 @@ void redis_actor_notification_table_subscribe_callback(redisAsyncContext *c, redisReply *payload = reply->element[2]; ActorNotificationTableSubscribeData *data = (ActorNotificationTableSubscribeData *) callback_data->data; - ActorInfo info; - /* The payload should be the concatenation of these two structs. */ - CHECK(sizeof(info.actor_id) + sizeof(info.local_scheduler_id) == + /* The payload should be the concatenation of three IDs. 
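For reference, the driver_deaths channel published to here can also be observed from Python, which is essentially what monitor.py does; a minimal sketch using redis-py and the generated flatbuffer bindings, with the Redis address being an illustrative assumption:

import redis
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.utils import binary_to_hex

client = redis.StrictRedis(host="127.0.0.1", port=6379)  # Illustrative address.
pubsub = client.pubsub()
pubsub.subscribe("driver_deaths")

for message in pubsub.listen():
    if message["type"] != "message":
        continue  # Skip the initial subscribe confirmation.
    payload = DriverTableMessage.GetRootAsDriverTableMessage(message["data"], 0)
    print("Driver {} died.".format(binary_to_hex(payload.DriverId())))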
*/ + ActorID actor_id; + WorkerID driver_id; + DBClientID local_scheduler_id; + CHECK(sizeof(actor_id) + sizeof(driver_id) + sizeof(local_scheduler_id) == payload->len); - memcpy(&info.actor_id, payload->str, sizeof(info.actor_id)); - memcpy(&info.local_scheduler_id, payload->str + sizeof(info.actor_id), - sizeof(info.local_scheduler_id)); + memcpy(&actor_id, payload->str, sizeof(actor_id)); + memcpy(&driver_id, payload->str + sizeof(actor_id), sizeof(driver_id)); + memcpy(&local_scheduler_id, + payload->str + sizeof(actor_id) + sizeof(driver_id), + sizeof(local_scheduler_id)); if (data->subscribe_callback) { - data->subscribe_callback(info, data->subscribe_context); + data->subscribe_callback(actor_id, driver_id, local_scheduler_id, + data->subscribe_context); } } else if (strcmp(message_type->str, "subscribe") == 0) { /* The reply for the initial SUBSCRIBE command. */ diff --git a/src/common/state/redis.h b/src/common/state/redis.h index e8af9c643..cada086ae 100644 --- a/src/common/state/redis.h +++ b/src/common/state/redis.h @@ -253,6 +253,24 @@ void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data); */ void redis_local_scheduler_table_send_info(TableCallbackData *callback_data); +/** + * Subscribe to updates from the driver table. + * + * @param callback_data Data structure containing redis connection and timeout + * information. + * @return Void. + */ +void redis_driver_table_subscribe(TableCallbackData *callback_data); + +/** + * Publish an update to the driver table. + * + * @param callback_data Data structure containing redis connection and timeout + * information. + * @return Void. + */ +void redis_driver_table_send_driver_death(TableCallbackData *callback_data); + void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data); /** diff --git a/src/common/task.h b/src/common/task.h index 1bcc3cc7f..986717860 100644 --- a/src/common/task.h +++ b/src/common/task.h @@ -16,7 +16,6 @@ struct TaskBuilder; #define NIL_TASK_ID NIL_ID #define NIL_ACTOR_ID NIL_ID #define NIL_FUNCTION_ID NIL_ID -#define NIL_WORKER_ID NIL_ID typedef UniqueID FunctionID; diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc index 6df8ea73a..e7956e289 100644 --- a/src/global_scheduler/global_scheduler.cc +++ b/src/global_scheduler/global_scheduler.cc @@ -61,14 +61,15 @@ void assign_task_to_local_scheduler(GlobalSchedulerState *state, } GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop, + const char *node_ip_address, const char *redis_addr, int redis_port) { GlobalSchedulerState *state = (GlobalSchedulerState *) malloc(sizeof(GlobalSchedulerState)); /* Must initialize state to 0. Sets hashmap head(s) to NULL. 
*/ memset(state, 0, sizeof(GlobalSchedulerState)); - state->db = - db_connect(redis_addr, redis_port, "global_scheduler", ":", 0, NULL); + state->db = db_connect(redis_addr, redis_port, "global_scheduler", + node_ip_address, 0, NULL); db_attach(state->db, loop, false); utarray_new(state->local_schedulers, &local_scheduler_icd); state->policy_state = GlobalSchedulerPolicyState_init(); @@ -416,9 +417,12 @@ int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) { return HEARTBEAT_TIMEOUT_MILLISECONDS; } -void start_server(const char *redis_addr, int redis_port) { +void start_server(const char *node_ip_address, + const char *redis_addr, + int redis_port) { event_loop *loop = event_loop_create(); - g_state = GlobalSchedulerState_init(loop, redis_addr, redis_port); + g_state = + GlobalSchedulerState_init(loop, node_ip_address, redis_addr, redis_port); /* TODO(rkn): subscribe to notifications from the object table. */ /* Subscribe to notifications about new local schedulers. TODO(rkn): this * needs to also get all of the clients that registered with the database @@ -456,12 +460,17 @@ int main(int argc, char *argv[]) { signal(SIGTERM, signal_handler); /* IP address and port of redis. */ char *redis_addr_port = NULL; + /* The IP address of the node that this global scheduler is running on. */ + char *node_ip_address = NULL; int c; - while ((c = getopt(argc, argv, "s:m:h:p:r:")) != -1) { + while ((c = getopt(argc, argv, "h:r:")) != -1) { switch (c) { case 'r': redis_addr_port = optarg; break; + case 'h': + node_ip_address = optarg; + break; default: LOG_ERROR("unknown option %c", c); exit(-1); @@ -472,8 +481,11 @@ int main(int argc, char *argv[]) { if (!redis_addr_port || parse_ip_addr_port(redis_addr_port, redis_addr, &redis_port) == -1) { LOG_ERROR( - "need to specify redis address like 127.0.0.1:6379 with -r switch"); + "specify the redis address like 127.0.0.1:6379 with the -r switch"); exit(-1); } - start_server(redis_addr, redis_port); + if (!node_ip_address) { + LOG_FATAL("specify the node IP address with the -h switch"); + } + start_server(node_ip_address, redis_addr, redis_port); } diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index 6a96df5cb..5ca04b371 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -18,6 +18,7 @@ #include "local_scheduler_algorithm.h" #include "state/actor_notification_table.h" #include "state/db.h" +#include "state/driver_table.h" #include "state/task_table.h" #include "state/object_table.h" #include "state/error_table.h" @@ -67,7 +68,8 @@ int force_kill_worker(event_loop *loop, timer_id id, void *context) { /** * Kill a worker, if it is a child process, and clean up all of its associated - * state. + * state. Note that this function is also called on drivers, but it should not + * actually send a kill signal to drivers. * * @param worker A pointer to the worker we want to kill. * @param cleanup A bool representing whether we're cleaning up the entire local @@ -89,7 +91,13 @@ void kill_worker(LocalSchedulerState *state, CHECK(it == state->workers.end()); /* Erase the algorithm state's reference to the worker. */ - handle_worker_removed(state, state->algorithm_state, worker); + if (ActorID_equal(worker->actor_id, NIL_ACTOR_ID)) { + handle_worker_removed(state, state->algorithm_state, worker); + } else { + /* Let the scheduling algorithm process the absence of this worker. 
*/ + handle_actor_worker_disconnect(state, state->algorithm_state, + worker->actor_id); + } /* Remove the client socket from the event loop so that we don't process the * SIGPIPE when the worker is killed. */ @@ -99,6 +107,8 @@ void kill_worker(LocalSchedulerState *state, * process, use it to send a kill signal. */ bool free_worker = true; if (worker->is_child && worker->pid != 0) { + /* If worker is a driver, we should not enter this condition because + * worker->pid should be 0. */ if (cleanup) { /* If we're exiting the local scheduler anyway, it's okay to force kill * the worker immediately. Wait for the process to exit. */ @@ -425,10 +435,18 @@ void update_dynamic_resources(LocalSchedulerState *state, print_resource_info(state, spec); } +bool is_driver_alive(LocalSchedulerState *state, WorkerID driver_id) { + return state->removed_drivers.count(driver_id) == 0; +} + void assign_task_to_worker(LocalSchedulerState *state, TaskSpec *spec, int64_t task_spec_size, LocalSchedulerClient *worker) { + /* Make sure the driver for this task is still alive. */ + WorkerID driver_id = TaskSpec_driver_id(spec); + CHECK(is_driver_alive(state, driver_id)); + /* Construct a flatbuffer object to send to the worker. */ flatbuffers::FlatBufferBuilder fbb; auto message = @@ -649,8 +667,15 @@ void send_client_register_reply(LocalSchedulerState *state, void handle_client_register(LocalSchedulerState *state, LocalSchedulerClient *worker, const RegisterClientRequest *message) { + /* Make sure this worker hasn't already registered. */ + CHECK(!worker->registered); + worker->registered = true; + worker->is_worker = message->is_worker(); + CHECK(WorkerID_equal(worker->client_id, NIL_WORKER_ID)); + worker->client_id = from_flatbuf(message->client_id()); + /* Register the worker or driver. */ - if (message->is_worker()) { + if (worker->is_worker) { /* Update the actor mapping with the actor ID of the worker (if an actor is * running on the worker). */ worker->pid = message->worker_pid(); @@ -683,12 +708,74 @@ void handle_client_register(LocalSchedulerState *state, state->child_pids.erase(it); LOG_DEBUG("Found matching child pid %d", worker->pid); } + + /* If the worker is an actor that corresponds to a driver that has been + * removed, then kill the worker. */ + if (!ActorID_equal(actor_id, NIL_ACTOR_ID)) { + WorkerID driver_id = state->actor_mapping[actor_id].driver_id; + if (state->removed_drivers.count(driver_id) == 1) { + kill_worker(state, worker, false); + } + } } else { /* Register the driver. Currently we don't do anything here. */ } } -/* End of the cleanup code. */ +void handle_driver_removed_callback(WorkerID driver_id, void *user_context) { + LocalSchedulerState *state = (LocalSchedulerState *) user_context; + + /* Kill any actors that were created by the removed driver, and kill any + * workers that are currently running tasks from the dead driver. */ + auto it = state->workers.begin(); + while (it != state->workers.end()) { + /* Increment the iterator by one before calling kill_worker, because + * kill_worker will invalidate the iterator. Note that this requires + * knowledge of the particular container that we are iterating over (in this + * case it is a list). */ + auto next_it = it; + next_it++; + + ActorID actor_id = (*it)->actor_id; + Task *task = (*it)->task_in_progress; + + if (!ActorID_equal(actor_id, NIL_ACTOR_ID)) { + /* This is an actor. 
*/ + CHECK(state->actor_mapping.count(actor_id) == 1); + if (WorkerID_equal(state->actor_mapping[actor_id].driver_id, driver_id)) { + /* This actor was created by the removed driver, so kill the actor. */ + LOG_DEBUG("Killing an actor for a removed driver."); + kill_worker(state, *it, false); + break; + } + } else if (task != NULL) { + if (WorkerID_equal(TaskSpec_driver_id(Task_task_spec(task)), driver_id)) { + LOG_DEBUG("Killing a worker executing a task for a removed driver."); + kill_worker(state, *it, false); + break; + } + } + + it = next_it; + } + + /* Add the driver to a list of dead drivers. */ + state->removed_drivers.insert(driver_id); + + /* Notify the scheduling algorithm that the driver has been removed. It should + * remove tasks for that driver from its data structures. */ + handle_driver_removed(state, state->algorithm_state, driver_id); +} + +void handle_client_disconnect(LocalSchedulerState *state, + LocalSchedulerClient *worker) { + if (!worker->registered || worker->is_worker) { + } else { + /* In this case, a driver is disconecting. */ + driver_table_send_driver_death(state->db, worker->client_id, NULL); + } + kill_worker(state, worker, false); +} void process_message(event_loop *loop, int client_sock, @@ -786,12 +873,7 @@ void process_message(event_loop *loop, } break; case DISCONNECT_CLIENT: { LOG_INFO("Disconnecting client on fd %d", client_sock); - kill_worker(state, worker, false); - if (!ActorID_equal(worker->actor_id, NIL_ACTOR_ID)) { - /* Let the scheduling algorithm process the absence of this worker. */ - handle_actor_worker_disconnect(state, state->algorithm_state, - worker->actor_id); - } + handle_client_disconnect(state, worker); } break; case MessageType_NotifyUnblocked: { if (worker->task_in_progress != NULL) { @@ -827,6 +909,11 @@ void new_client_connection(event_loop *loop, * scheduler state. */ LocalSchedulerClient *worker = new LocalSchedulerClient(); worker->sock = new_socket; + worker->registered = false; + /* We don't know whether this is a worker or not, so just initialize is_worker + * to false. */ + worker->is_worker = true; + worker->client_id = NIL_WORKER_ID; worker->task_in_progress = NULL; worker->is_blocked = false; worker->pid = 0; @@ -857,10 +944,21 @@ void signal_handler(int signal) { } } +/* End of the cleanup code. */ + void handle_task_scheduled_callback(Task *original_task, void *subscribe_context) { LocalSchedulerState *state = (LocalSchedulerState *) subscribe_context; TaskSpec *spec = Task_task_spec(original_task); + + /* If the driver for this task has been removed, then don't bother telling the + * scheduling algorithm. */ + WorkerID driver_id = TaskSpec_driver_id(spec); + if (!is_driver_alive(state, driver_id)) { + LOG_DEBUG("Ignoring scheduled task for removed driver.") + return; + } + if (ActorID_equal(TaskSpec_actor_id(spec), NIL_ACTOR_ID)) { /* This task does not involve an actor. Handle it normally. */ handle_task_scheduled(state, state->algorithm_state, spec, @@ -884,10 +982,17 @@ void handle_task_scheduled_callback(Task *original_task, * for creating the actor. * @return Void. */ -void handle_actor_creation_callback(ActorInfo info, void *context) { - ActorID actor_id = info.actor_id; - DBClientID local_scheduler_id = info.local_scheduler_id; +void handle_actor_creation_callback(ActorID actor_id, + WorkerID driver_id, + DBClientID local_scheduler_id, + void *context) { LocalSchedulerState *state = (LocalSchedulerState *) context; + + /* If the driver has been removed, don't bother doing anything. 
*/ + if (state->removed_drivers.count(driver_id) == 1) { + return; + } + /* Make sure the actor entry is not already present in the actor map table. * TODO(rkn): We will need to remove this check to handle the case where the * corresponding publish is retried and the case in which a task that creates @@ -897,9 +1002,10 @@ void handle_actor_creation_callback(ActorInfo info, void *context) { * Currently this is never removed (except when the local scheduler state is * deleted). */ ActorMapEntry entry; - entry.actor_id = actor_id; entry.local_scheduler_id = local_scheduler_id; + entry.driver_id = driver_id; state->actor_mapping[actor_id] = entry; + /* If this local scheduler is responsible for the actor, then start a new * worker for the actor. */ if (DBClientID_equal(local_scheduler_id, get_db_client_id(state->db))) { @@ -960,6 +1066,11 @@ void start_server(const char *node_ip_address, actor_notification_table_subscribe( g_state->db, handle_actor_creation_callback, g_state, NULL); } + /* Subscribe to notifications about removed drivers. */ + if (g_state->db != NULL) { + driver_table_subscribe(g_state->db, handle_driver_removed_callback, g_state, + NULL); + } /* Create a timer for publishing information about the load on the local * scheduler to the local scheduler table. This message also serves as a * heartbeat. */ diff --git a/src/local_scheduler/local_scheduler.h b/src/local_scheduler/local_scheduler.h index 600de58ac..5d1b40f4f 100644 --- a/src/local_scheduler/local_scheduler.h +++ b/src/local_scheduler/local_scheduler.h @@ -26,6 +26,14 @@ void new_client_connection(event_loop *loop, void *context, int events); +/** + * Check if a driver is still alive. + * + * @param driver_id The ID of the driver. + * @return True if the driver is still alive and false otherwise. + */ +bool is_driver_alive(WorkerID driver_id); + /** * This function can be called by the scheduling algorithm to assign a task * to a worker. diff --git a/src/local_scheduler/local_scheduler_algorithm.cc b/src/local_scheduler/local_scheduler_algorithm.cc index 6ea2d47d5..4b9832ee0 100644 --- a/src/local_scheduler/local_scheduler_algorithm.cc +++ b/src/local_scheduler/local_scheduler_algorithm.cc @@ -964,6 +964,9 @@ void handle_worker_available(LocalSchedulerState *state, void handle_worker_removed(LocalSchedulerState *state, SchedulingAlgorithmState *algorithm_state, LocalSchedulerClient *worker) { + /* Make sure this is not an actor. */ + CHECK(ActorID_equal(worker->actor_id, NIL_ACTOR_ID)); + /* Make sure that we remove the worker at most once. */ int num_times_removed = 0; @@ -1141,6 +1144,59 @@ void handle_object_removed(LocalSchedulerState *state, } } +void handle_driver_removed(LocalSchedulerState *state, + SchedulingAlgorithmState *algorithm_state, + WorkerID driver_id) { + /* Loop over fetch requests. This must be done before we clean up the waiting + * task queue and the dispatch task queue because this map contains iterators + * for those lists, which will be invalidated when we clean up those lists.*/ + for (auto it = algorithm_state->remote_objects.begin(); + it != algorithm_state->remote_objects.end();) { + /* Loop over the tasks that are waiting for this object and remove the tasks + * for the removed driver. */ + auto task_it_it = it->second.dependent_tasks.begin(); + while (task_it_it != it->second.dependent_tasks.end()) { + /* If the dependent task was a task for the removed driver, remove it from + * this vector. 
*/ + TaskSpec *spec = (*task_it_it)->spec; + if (WorkerID_equal(TaskSpec_driver_id(spec), driver_id)) { + task_it_it = it->second.dependent_tasks.erase(task_it_it); + } else { + task_it_it++; + } + } + /* If there are no more dependent tasks for this object, then remove the + * ObjectEntry. */ + if (it->second.dependent_tasks.size() == 0) { + it = algorithm_state->remote_objects.erase(it); + } else { + it++; + } + } + + /* Remove this driver's tasks from the waiting task queue. */ + auto it = algorithm_state->waiting_task_queue->begin(); + while (it != algorithm_state->waiting_task_queue->end()) { + if (WorkerID_equal(TaskSpec_driver_id(it->spec), driver_id)) { + it = algorithm_state->waiting_task_queue->erase(it); + } else { + it++; + } + } + + /* Remove this driver's tasks from the dispatch task queue. */ + it = algorithm_state->dispatch_task_queue->begin(); + while (it != algorithm_state->dispatch_task_queue->end()) { + if (WorkerID_equal(TaskSpec_driver_id(it->spec), driver_id)) { + it = algorithm_state->dispatch_task_queue->erase(it); + } else { + it++; + } + } + + /* TODO(rkn): Should we clean up the actor data structures? */ +} + int num_waiting_tasks(SchedulingAlgorithmState *algorithm_state) { return algorithm_state->waiting_task_queue->size(); } diff --git a/src/local_scheduler/local_scheduler_algorithm.h b/src/local_scheduler/local_scheduler_algorithm.h index 0c1025b9d..5bdadc670 100644 --- a/src/local_scheduler/local_scheduler_algorithm.h +++ b/src/local_scheduler/local_scheduler_algorithm.h @@ -230,6 +230,20 @@ void handle_worker_unblocked(LocalSchedulerState *state, SchedulingAlgorithmState *algorithm_state, LocalSchedulerClient *worker); +/** + * Process the fact that a driver has been removed. This will remove all of the + * tasks for that driver from the scheduling algorithm's internal data + * structures. + * + * @param state The state of the local scheduler. + * @param algorithm_state State maintained by the scheduling algorithm. + * @param driver_id The ID of the driver that was removed. + * @return Void. + */ +void handle_driver_removed(LocalSchedulerState *state, + SchedulingAlgorithmState *algorithm_state, + WorkerID driver_id); + /** * This function fetches queued task's missing object dependencies. It is * called every LOCAL_SCHEDULER_FETCH_TIMEOUT_MILLISECONDS. diff --git a/src/local_scheduler/local_scheduler_shared.h b/src/local_scheduler/local_scheduler_shared.h index bffa23fad..4dc3d5dfe 100644 --- a/src/local_scheduler/local_scheduler_shared.h +++ b/src/local_scheduler/local_scheduler_shared.h @@ -8,7 +8,9 @@ #include "utarray.h" #include "uthash.h" +#include #include +#include #include /* These are needed to define the UT_arrays. */ @@ -17,8 +19,8 @@ extern UT_icd task_ptr_icd; /** This struct is used to maintain a mapping from actor IDs to the ID of the * local scheduler that is responsible for the actor. */ struct ActorMapEntry { - /** The ID of the actor. This is used as a key in the hash table. */ - ActorID actor_id; + /** The ID of the driver that created the actor. */ + WorkerID driver_id; /** The ID of the local scheduler that is responsible for the actor. */ DBClientID local_scheduler_id; }; @@ -47,7 +49,10 @@ struct LocalSchedulerState { /** List of workers available to this node. This is used to free the worker * structs when we free the scheduler state and also to access the worker * structs in the tests. */ - std::vector workers; + std::list workers; + /** A set of driver IDs corresponding to drivers that have been removed. 
This + * is used to make sure we don't execute any tasks belong to dead drivers. */ + std::unordered_set removed_drivers; /** List of the process IDs for child processes (workers) started by the * local scheduler that have not sent a REGISTER_PID message yet. */ std::vector child_pids; @@ -75,6 +80,13 @@ struct LocalSchedulerState { struct LocalSchedulerClient { /** The socket used to communicate with the client. */ int sock; + /** True if the client has registered and false otherwise. */ + bool registered; + /** True if the client is a worker and false if it is a driver. */ + bool is_worker; + /** The worker ID if the client is a worker and the driver ID if the client is + * a driver. */ + WorkerID client_id; /** A pointer to the task object that is currently running on this client. If * no task is running on the worker, this will be NULL. This is used to * update the task table. */ diff --git a/src/local_scheduler/test/local_scheduler_tests.cc b/src/local_scheduler/test/local_scheduler_tests.cc index 8ace28221..60d2f87bc 100644 --- a/src/local_scheduler/test/local_scheduler_tests.cc +++ b/src/local_scheduler/test/local_scheduler_tests.cc @@ -64,9 +64,7 @@ static void register_clients(int num_mock_workers, LocalSchedulerMock *mock) { for (int i = 0; i < num_mock_workers; ++i) { new_client_connection(mock->loop, mock->local_scheduler_fd, (void *) mock->local_scheduler_state, 0); - - LocalSchedulerClient *worker = mock->local_scheduler_state->workers[i]; - + LocalSchedulerClient *worker = mock->local_scheduler_state->workers.back(); process_message(mock->local_scheduler_state->loop, worker->sock, worker, 0); } } @@ -644,7 +642,7 @@ TEST start_kill_workers_test(void) { ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(), num_workers); /* Make sure that the new worker registers its process ID. 
*/ - worker = local_scheduler->local_scheduler_state->workers[num_workers - 1]; + worker = local_scheduler->local_scheduler_state->workers.back(); process_message(local_scheduler->local_scheduler_state->loop, worker->sock, worker, 0); ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0); diff --git a/test/jenkins_tests/multi_node_docker_test.py b/test/jenkins_tests/multi_node_docker_test.py index 2a26df6c4..c7ccfac7d 100644 --- a/test/jenkins_tests/multi_node_docker_test.py +++ b/test/jenkins_tests/multi_node_docker_test.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import argparse +import numpy as np import os import re import subprocess @@ -85,8 +86,8 @@ class DockerRunner(object): else: return m.group(1) - def _start_head_node(self, docker_image, mem_size, shm_size, - development_mode): + def _start_head_node(self, docker_image, mem_size, shm_size, num_cpus, + num_gpus, development_mode): """Start the Ray head node inside a docker container.""" mem_arg = ["--memory=" + mem_size] if mem_size else [] shm_arg = ["--shm-size=" + shm_size] if shm_size else [] @@ -94,10 +95,15 @@ class DockerRunner(object): "{}:{}".format(os.path.dirname(os.path.realpath(__file__)), "/ray/test/jenkins_tests")] if development_mode else []) - proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg + - volume_arg + - [docker_image, "/ray/scripts/start_ray.sh", - "--head", "--redis-port=6379"], + + command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + + [docker_image, "/ray/scripts/start_ray.sh", "--head", + "--redis-port=6379", + "--num-cpus={}".format(num_cpus), + "--num-gpus={}".format(num_gpus)]) + print("Starting head node with command:{}".format(command)) + + proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout_data, _ = wait_for_output(proc) container_id = self._get_container_id(stdout_data) @@ -105,33 +111,34 @@ class DockerRunner(object): raise RuntimeError("Failed to find container ID.") self.head_container_id = container_id self.head_container_ip = self._get_container_ip(container_id) - print("start_node", {"container_id": container_id, - "is_head": True, - "shm_size": shm_size, - "ip_address": self.head_container_ip}) - return container_id - def _start_worker_node(self, docker_image, mem_size, shm_size): + def _start_worker_node(self, docker_image, mem_size, shm_size, num_cpus, + num_gpus, development_mode): """Start a Ray worker node inside a docker container.""" mem_arg = ["--memory=" + mem_size] if mem_size else [] shm_arg = ["--shm-size=" + shm_size] if shm_size else [] - proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg + - ["--shm-size=" + shm_size, docker_image, - "/ray/scripts/start_ray.sh", - "--redis-address={:s}:6379".format( - self.head_container_ip)], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + volume_arg = (["-v", + "{}:{}".format(os.path.dirname(os.path.realpath(__file__)), + "/ray/test/jenkins_tests")] + if development_mode else []) + command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + + ["--shm-size=" + shm_size, docker_image, + "/ray/scripts/start_ray.sh", + "--redis-address={:s}:6379".format(self.head_container_ip), + "--num-cpus={}".format(num_cpus), + "--num-gpus={}".format(num_gpus)]) + print("Starting worker node with command:{}".format(command)) + proc = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) stdout_data, _ = wait_for_output(proc) container_id = 
self._get_container_id(stdout_data) if container_id is None: raise RuntimeError("Failed to find container id") self.worker_container_ids.append(container_id) - print("start_node", {"container_id": container_id, - "is_head": False, - "shm_size": shm_size}) - def start_ray(self, docker_image, mem_size, shm_size, num_nodes, - development_mode): + def start_ray(self, docker_image=None, mem_size=None, shm_size=None, + num_nodes=None, num_cpus=None, num_gpus=None, + development_mode=None): """Start a Ray cluster within docker. This starts one docker container running the head node and num_nodes - 1 @@ -146,15 +153,23 @@ class DockerRunner(object): with. This will be passed into `docker run` as the `--shm-size` flag. num_nodes: The number of nodes to use in the cluster (this counts the head node as well). + num_cpus: A list of the number of CPUs to start each node with. + num_gpus: A list of the number of GPUs to start each node with. development_mode: True if you want to mount the local copy of test/jenkins_test on the head node so we can avoid rebuilding docker images during development. """ + assert len(num_cpus) == num_nodes + assert len(num_gpus) == num_nodes + # Launch the head node. - self._start_head_node(docker_image, mem_size, shm_size, development_mode) + self._start_head_node(docker_image, mem_size, shm_size, num_cpus[0], + num_gpus[0], development_mode) # Start the worker nodes. - for _ in range(num_nodes - 1): - self._start_worker_node(docker_image, mem_size, shm_size) + for i in range(num_nodes - 1): + self._start_worker_node(docker_image, mem_size, shm_size, + num_cpus[1 + i], num_gpus[1 + i], + development_mode) def _stop_node(self, container_id): """Stop a node in the Ray cluster.""" @@ -181,32 +196,51 @@ class DockerRunner(object): for container_id in self.worker_container_ids: self._stop_node(container_id) - def run_test(self, test_script, run_in_docker=False): + def run_test(self, test_script, num_drivers, driver_locations=None): """Run a test script. Run a test using the Ray cluster. Args: test_script: The test script to run. - run_in_docker: If true then the test script will be run in a docker - container. If false, it will be run regularly. + num_drivers: The number of copies of the test script to run. + driver_locations: A list of the indices of the containers that the + different copies of the test script should be run on. If this is None, + then the containers will be chosen randomly. Returns: A dictionary with information about the test script run. """ - print("Starting to run test script {}.".format(test_script)) - proc = subprocess.Popen(["docker", "exec", self.head_container_id, - "/bin/bash", "-c", - "RAY_REDIS_ADDRESS={}:6379 " - "python {}".format(self.head_container_ip, - test_script)], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout_data, stderr_data = wait_for_output(proc) - print("STDOUT:") - print(stdout_data) - print("STDERR:") - print(stderr_data) - return {"success": proc.returncode == 0, "return_code": proc.returncode} + all_container_ids = [self.head_container_id] + self.worker_container_ids + if driver_locations is None: + driver_locations = [np.random.randint(0, len(all_container_ids)) + for _ in range(num_drivers)] + + # Start the different drivers. + driver_processes = [] + for i in range(len(driver_locations)): + # Get the container ID to run the ith driver in. 
+      container_id = all_container_ids[driver_locations[i]]
+      command = ["docker", "exec", container_id, "/bin/bash", "-c",
+                 ("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python {}"
+                  .format(self.head_container_ip, i, test_script))]
+      print("Starting driver with command {}.".format(command))
+      # Start the driver.
+      p = subprocess.Popen(command, stdout=subprocess.PIPE,
+                           stderr=subprocess.PIPE)
+      driver_processes.append(p)
+
+    # Wait for the drivers to finish.
+    results = []
+    for p in driver_processes:
+      stdout_data, stderr_data = wait_for_output(p)
+      print("STDOUT:")
+      print(stdout_data)
+      print("STDERR:")
+      print(stderr_data)
+      results.append({"success": p.returncode == 0,
+                      "return_code": p.returncode})
+    return results


 if __name__ == "__main__":
@@ -218,23 +252,51 @@
   parser.add_argument("--shm-size", default="1G", help="shared memory size")
   parser.add_argument("--num-nodes", default=1, type=int,
                       help="number of nodes to use in the cluster")
+  parser.add_argument("--num-cpus", type=str,
+                      help=("a comma-separated list of values representing "
+                            "the number of CPUs to start each node with"))
+  parser.add_argument("--num-gpus", type=str,
+                      help=("a comma-separated list of values representing "
+                            "the number of GPUs to start each node with"))
+  parser.add_argument("--num-drivers", default=1, type=int,
+                      help="number of drivers to run")
+  parser.add_argument("--driver-locations", type=str,
+                      help=("a comma-separated list of indices of the "
+                            "containers to run the drivers in"))
   parser.add_argument("--test-script", required=True, help="test script")
   parser.add_argument("--development-mode", action="store_true",
                       help="use local copies of the test scripts")
   args = parser.parse_args()

+  # Parse the per-node CPU and GPU counts and the driver locations.
+  num_nodes = args.num_nodes
+  num_cpus = ([int(i) for i in args.num_cpus.split(",")]
+              if args.num_cpus is not None else num_nodes * [10])
+  num_gpus = ([int(i) for i in args.num_gpus.split(",")]
+              if args.num_gpus is not None else num_nodes * [0])
+  driver_locations = ([int(i) for i in args.driver_locations.split(",")]
+                      if args.driver_locations is not None else None)
+
   d = DockerRunner()
-  d.start_ray(mem_size=args.mem_size, shm_size=args.shm_size,
-              num_nodes=args.num_nodes, docker_image=args.docker_image,
+  d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size,
+              shm_size=args.shm_size, num_nodes=num_nodes,
+              num_cpus=num_cpus, num_gpus=num_gpus,
               development_mode=args.development_mode)
   try:
-    run_result = d.run_test(args.test_script)
+    run_results = d.run_test(args.test_script, args.num_drivers,
+                             driver_locations=driver_locations)
   finally:
     d.stop_ray()

-  if "success" in run_result and run_result["success"]:
-    print("RESULT: Test {} succeeded.".format(args.test_script))
-    sys.exit(0)
-  else:
-    print("RESULT: Test {} failed.".format(args.test_script))
+  any_failed = False
+  for run_result in run_results:
+    if "success" in run_result and run_result["success"]:
+      print("RESULT: Test {} succeeded.".format(args.test_script))
+    else:
+      print("RESULT: Test {} failed.".format(args.test_script))
+      any_failed = True
+
+  if any_failed:
     sys.exit(1)
+  else:
+    sys.exit(0)
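
For reference, the updated DockerRunner interface above can also be exercised directly from Python instead of through the command-line flags. A minimal sketch, assuming the module is importable as multi_node_docker_test and using a placeholder image name:

    # Sketch only: drives the DockerRunner API defined above.
    from multi_node_docker_test import DockerRunner

    test_script = "/ray/test/jenkins_tests/multi_node_tests/test_0.py"

    runner = DockerRunner()
    # num_cpus and num_gpus are per-node lists whose length must equal
    # num_nodes; element 0 describes the head node.
    runner.start_ray(docker_image="ray-test-image", mem_size=None,
                     shm_size="1G", num_nodes=3, num_cpus=[4, 4, 4],
                     num_gpus=[0, 1, 1], development_mode=False)
    try:
      # Run two copies of the driver script; with driver_locations=None the
      # containers are chosen at random via np.random.randint.
      results = runner.run_test(test_script, num_drivers=2,
                                driver_locations=None)
    finally:
      runner.stop_ray()

    assert all(result["success"] for result in results)
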
diff --git a/test/jenkins_tests/multi_node_tests/many_drivers_test.py b/test/jenkins_tests/multi_node_tests/many_drivers_test.py
new file mode 100644
index 000000000..09ef13ecd
--- /dev/null
+++ b/test/jenkins_tests/multi_node_tests/many_drivers_test.py
@@ -0,0 +1,80 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+import ray
+from ray.test.multi_node_tests import (_wait_for_nodes_to_join,
+                                       _broadcast_event,
+                                       _wait_for_event)
+
+# This test should be run with 5 nodes, which have 0, 0, 5, 6, and 50 GPUs for
+# a total of 61 GPUs. It should be run with a large number of drivers (e.g.,
+# 100). At most 15 drivers will run at a time, and each driver will use at
+# most 5 GPUs (this is ceil(61 / 15), which guarantees that we will always be
+# able to make progress).
+total_num_nodes = 5
+max_concurrent_drivers = 15
+num_gpus_per_driver = 5
+
+
+@ray.actor(num_gpus=1)
+class Actor1(object):
+  def __init__(self):
+    assert len(ray.get_gpu_ids()) == 1
+
+  def check_ids(self):
+    assert len(ray.get_gpu_ids()) == 1
+
+
+def driver(redis_address, driver_index):
+  """The script run by every driver in this test.
+
+  Each driver should create five actors that each use one GPU. After a while,
+  it should exit.
+  """
+  ray.init(redis_address=redis_address)
+
+  # Wait for all the nodes to join the cluster.
+  _wait_for_nodes_to_join(total_num_nodes)
+
+  # Limit the number of drivers running concurrently.
+  for i in range(driver_index - max_concurrent_drivers + 1):
+    _wait_for_event("DRIVER_{}_DONE".format(i), redis_address)
+
+  def try_to_create_actor(actor_class, timeout=100):
+    # Try to create an actor, but allow failures while we wait for the monitor
+    # to release the resources for the removed drivers.
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+      try:
+        actor = actor_class()
+      except Exception:
+        time.sleep(0.1)
+      else:
+        return actor
+    # If we are here, then we timed out while looping.
+    raise Exception("Timed out while trying to create actor.")
+
+  # Create some actors that require one GPU.
+  actors_one_gpu = []
+  for _ in range(num_gpus_per_driver):
+    actors_one_gpu.append(try_to_create_actor(Actor1))
+
+  for _ in range(100):
+    ray.get([actor.check_ids() for actor in actors_one_gpu])
+
+  _broadcast_event("DRIVER_{}_DONE".format(driver_index), redis_address)
+
+
+if __name__ == "__main__":
+  driver_index = int(os.environ["RAY_DRIVER_INDEX"])
+  redis_address = os.environ["RAY_REDIS_ADDRESS"]
+  print("Driver {} started at {}.".format(driver_index, time.time()))
+
+  # In this test, all drivers will run the same script.
+  driver(redis_address, driver_index)
+
+  print("Driver {} finished at {}.".format(driver_index, time.time()))
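
The gating loop in driver() above is what bounds concurrency: driver k waits for the DONE events of drivers 0 through k - max_concurrent_drivers, so at most max_concurrent_drivers copies of the script are active at once. A small illustration of that arithmetic (not code from the test itself):

    # Which DONE events a given driver waits for, mirroring the loop in driver().
    max_concurrent_drivers = 15

    def events_waited_on(driver_index):
      return ["DRIVER_{}_DONE".format(i)
              for i in range(driver_index - max_concurrent_drivers + 1)]

    assert events_waited_on(5) == []   # the first 15 drivers start immediately
    assert events_waited_on(20) == ["DRIVER_{}_DONE".format(i)
                                    for i in range(6)]
    # Driver 20 starts only after drivers 0-5 have finished, so no more than
    # drivers 6-20 (15 drivers) can be running at the same time.
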
diff --git a/test/jenkins_tests/multi_node_tests/remove_driver_test.py b/test/jenkins_tests/multi_node_tests/remove_driver_test.py
new file mode 100644
index 000000000..35ffc6d65
--- /dev/null
+++ b/test/jenkins_tests/multi_node_tests/remove_driver_test.py
@@ -0,0 +1,153 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+import ray
+from ray.test.multi_node_tests import (_wait_for_nodes_to_join,
+                                       _broadcast_event,
+                                       _wait_for_event)
+
+# This test should be run with 5 nodes, which have 0, 1, 2, 3, and 4 GPUs for a
+# total of 10 GPUs. It should be run with 3 drivers.
+total_num_nodes = 5
+
+
+@ray.actor
+class Actor0(object):
+  def __init__(self):
+    assert len(ray.get_gpu_ids()) == 0
+
+  def check_ids(self):
+    assert len(ray.get_gpu_ids()) == 0
+
+
+@ray.actor(num_gpus=1)
+class Actor1(object):
+  def __init__(self):
+    assert len(ray.get_gpu_ids()) == 1
+
+  def check_ids(self):
+    assert len(ray.get_gpu_ids()) == 1
+
+
+@ray.actor(num_gpus=2)
+class Actor2(object):
+  def __init__(self):
+    assert len(ray.get_gpu_ids()) == 2
+
+  def check_ids(self):
+    assert len(ray.get_gpu_ids()) == 2
+
+
+def driver_0(redis_address):
+  """The script for driver 0.
+
+  This driver should create five actors that each use one GPU and some actors
+  that use no GPUs. After a while, it should exit.
+  """
+  ray.init(redis_address=redis_address)
+
+  # Wait for all the nodes to join the cluster.
+  _wait_for_nodes_to_join(total_num_nodes)
+
+  # Create some actors that require one GPU.
+  actors_one_gpu = [Actor1() for _ in range(5)]
+  # Create some actors that don't require any GPUs.
+  actors_no_gpus = [Actor0() for _ in range(5)]
+
+  for _ in range(1000):
+    ray.get([actor.check_ids() for actor in actors_one_gpu])
+    ray.get([actor.check_ids() for actor in actors_no_gpus])
+
+  _broadcast_event("DRIVER_0_DONE", redis_address)
+
+
+def driver_1(redis_address):
+  """The script for driver 1.
+
+  This driver should create one actor that uses two GPUs, three actors that
+  each use one GPU (the one requiring two must be created first), and some
+  actors that don't use any GPUs. After a while, it should exit.
+  """
+  ray.init(redis_address=redis_address)
+
+  # Wait for all the nodes to join the cluster.
+  _wait_for_nodes_to_join(total_num_nodes)
+
+  # Create an actor that requires two GPUs.
+  actors_two_gpus = [Actor2() for _ in range(1)]
+  # Create some actors that require one GPU.
+  actors_one_gpu = [Actor1() for _ in range(3)]
+  # Create some actors that don't require any GPUs.
+  actors_no_gpus = [Actor0() for _ in range(5)]
+
+  for _ in range(1000):
+    ray.get([actor.check_ids() for actor in actors_two_gpus])
+    ray.get([actor.check_ids() for actor in actors_one_gpu])
+    ray.get([actor.check_ids() for actor in actors_no_gpus])
+
+  _broadcast_event("DRIVER_1_DONE", redis_address)
+
+
+def driver_2(redis_address):
+  """The script for driver 2.
+
+  This driver should wait for the first two drivers to finish. Then it should
+  create some actors that use a total of ten GPUs.
+  """
+  ray.init(redis_address=redis_address)
+
+  _wait_for_event("DRIVER_0_DONE", redis_address)
+  _wait_for_event("DRIVER_1_DONE", redis_address)
+
+  def try_to_create_actor(actor_class, timeout=20):
+    # Try to create an actor, but allow failures while we wait for the monitor
+    # to release the resources for the removed drivers.
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+      try:
+        actor = actor_class()
+      except Exception:
+        time.sleep(0.1)
+      else:
+        return actor
+    # If we are here, then we timed out while looping.
+    raise Exception("Timed out while trying to create actor.")
+
+  # Create some actors that require two GPUs.
+  actors_two_gpus = []
+  for _ in range(3):
+    actors_two_gpus.append(try_to_create_actor(Actor2))
+  # Create some actors that require one GPU.
+  actors_one_gpu = []
+  for _ in range(4):
+    actors_one_gpu.append(try_to_create_actor(Actor1))
+  # Create some actors that don't require any GPUs.
+  actors_no_gpus = [Actor0() for _ in range(5)]
+
+  for _ in range(1000):
+    ray.get([actor.check_ids() for actor in actors_two_gpus])
+    ray.get([actor.check_ids() for actor in actors_one_gpu])
+    ray.get([actor.check_ids() for actor in actors_no_gpus])
+
+  _broadcast_event("DRIVER_2_DONE", redis_address)
+
+
+if __name__ == "__main__":
+  driver_index = int(os.environ["RAY_DRIVER_INDEX"])
+  redis_address = os.environ["RAY_REDIS_ADDRESS"]
+  print("Driver {} started at {}.".format(driver_index, time.time()))
+
+  if driver_index == 0:
+    driver_0(redis_address)
+  elif driver_index == 1:
+    driver_1(redis_address)
+  elif driver_index == 2:
+    driver_2(redis_address)
+  else:
+    raise Exception("This code should be unreachable.")
+
+  print("Driver {} finished at {}.".format(driver_index, time.time()))
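
Both new test scripts coordinate through _wait_for_nodes_to_join, _broadcast_event, and _wait_for_event from ray.test.multi_node_tests, whose implementation is not part of this diff. The sketch below is only one plausible Redis-backed shape for the event helpers, with an assumed key format; it is not the actual module:

    # Hypothetical sketch of the event helpers; the key format is an assumption.
    import time

    import redis

    def _broadcast_event(event_name, redis_address):
      ip_address, port = redis_address.split(":")
      client = redis.StrictRedis(host=ip_address, port=int(port))
      client.set("event:" + event_name, "1")

    def _wait_for_event(event_name, redis_address, timeout=600):
      ip_address, port = redis_address.split(":")
      client = redis.StrictRedis(host=ip_address, port=int(port))
      start_time = time.time()
      while time.time() - start_time < timeout:
        if client.get("event:" + event_name) is not None:
          return
        time.sleep(0.1)
      raise Exception("Timed out waiting for event {}.".format(event_name))
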
diff --git a/test/jenkins_tests/multi_node_tests/test_0.py b/test/jenkins_tests/multi_node_tests/test_0.py
index 6eb27fdee..8710a57bd 100644
--- a/test/jenkins_tests/multi_node_tests/test_0.py
+++ b/test/jenkins_tests/multi_node_tests/test_0.py
@@ -15,7 +15,11 @@ def f():

 if __name__ == "__main__":
-  ray.init(redis_address=os.environ["RAY_REDIS_ADDRESS"])
+  driver_index = int(os.environ["RAY_DRIVER_INDEX"])
+  redis_address = os.environ["RAY_REDIS_ADDRESS"]
+  print("Driver {} started at {}.".format(driver_index, time.time()))
+
+  ray.init(redis_address=redis_address)

   # Check that tasks are scheduled on all nodes.
   num_attempts = 30
   for i in range(num_attempts):
@@ -26,3 +30,5 @@ if __name__ == "__main__":
     if len(counts) == 5:
       break
   assert len(counts) == 5
+
+  print("Driver {} finished at {}.".format(driver_index, time.time()))
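
With the run_test change earlier in this diff, every driver is launched with RAY_REDIS_ADDRESS and RAY_DRIVER_INDEX set in its environment, and test_0.py above now follows the same convention as the two new scripts. A minimal skeleton for any additional multi-driver test script would look roughly like this:

    # Rough skeleton for a script launched by multi_node_docker_test.py.
    import os
    import time

    import ray

    if __name__ == "__main__":
      driver_index = int(os.environ["RAY_DRIVER_INDEX"])
      redis_address = os.environ["RAY_REDIS_ADDRESS"]
      print("Driver {} started at {}.".format(driver_index, time.time()))

      ray.init(redis_address=redis_address)
      # Per-driver test logic goes here, possibly branching on driver_index.
      print("Driver {} finished at {}.".format(driver_index, time.time()))
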
diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh
index 9e7e6c161..ded001b71 100755
--- a/test/jenkins_tests/run_multi_node_tests.sh
+++ b/test/jenkins_tests/run_multi_node_tests.sh
@@ -13,6 +13,20 @@ python $ROOT_DIR/multi_node_docker_test.py \
     --num-nodes=5 \
     --test-script=/ray/test/jenkins_tests/multi_node_tests/test_0.py

+python $ROOT_DIR/multi_node_docker_test.py \
+    --docker-image=$DOCKER_SHA \
+    --num-nodes=5 \
+    --num-gpus=0,1,2,3,4 \
+    --num-drivers=3 \
+    --test-script=/ray/test/jenkins_tests/multi_node_tests/remove_driver_test.py
+
+python $ROOT_DIR/multi_node_docker_test.py \
+    --docker-image=$DOCKER_SHA \
+    --num-nodes=5 \
+    --num-gpus=0,0,5,6,50 \
+    --num-drivers=100 \
+    --test-script=/ray/test/jenkins_tests/multi_node_tests/many_drivers_test.py
+
 python $ROOT_DIR/multi_node_docker_test.py \
     --docker-image=$DOCKER_SHA \
     --num-nodes=1 \
diff --git a/test/runtest.py b/test/runtest.py
index c0775edc5..c1f79ae8f 100644
--- a/test/runtest.py
+++ b/test/runtest.py
@@ -1433,8 +1433,7 @@ class GlobalStateAPI(unittest.TestCase):
     client_table = ray.global_state.client_table()

     node_ip_address = ray.worker.global_worker.node_ip_address
-    self.assertEqual(len(client_table[node_ip_address]), 2)
-    self.assertEqual(len(client_table[":"]), 1)
+    self.assertEqual(len(client_table[node_ip_address]), 3)

     manager_client = [c for c in client_table[node_ip_address]
                       if c["ClientType"] == "plasma_manager"][0]
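
The runtest.py change reflects that three clients are now expected under the local node's IP address in the client table, rather than two plus a separate entry under the ":" key. A quick way to inspect what that assertion sees, as a sketch against a locally started cluster:

    # Sketch: inspect the client table that the updated assertion checks.
    from collections import Counter

    import ray

    ray.init()
    client_table = ray.global_state.client_table()
    node_ip_address = ray.worker.global_worker.node_ip_address

    clients = client_table[node_ip_address]
    print(Counter(client["ClientType"] for client in clients))

    # The test pulls out the plasma manager entry the same way:
    manager_client = [c for c in clients
                      if c["ClientType"] == "plasma_manager"][0]
    print(manager_client)
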