Clean up when a driver disconnects. (#462)

* Clean up state when drivers exit.

* Remove unnecessary field in ActorMapEntry struct.

* Have monitor release GPU resources in Redis when driver exits.

* Enable multiple drivers in multi-node tests and test driver cleanup.

* Make redis GPU allocation a redis transaction and small cleanups.

* Fix multi-node test.

* Small cleanups.

* Make global scheduler take node_ip_address so it appears in the right place in the client table.

* Cleanups.

* Fix linting and cleanups in local scheduler.

* Fix removed_driver_test.

* Fix bug related to vector -> list.

* Fix linting.

* Cleanup.

* Fix multi node tests.

* Fix jenkins tests.

* Add another multi node test with many drivers.

* Fix linting.

* Make the actor creation notification a flatbuffer message.

* Revert "Make the actor creation notification a flatbuffer message."

This reverts commit af99099c8084dbf9177fb4e34c0c9b1a12c78f39.

* Add comment explaining flatbuffer problems.

Authored by Robert Nishihara on 2017-04-24 18:10:21 -07:00, committed by Philipp Moritz
parent 8194b71f32
commit 0ac125e9b2
31 changed files with 1119 additions and 168 deletions
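
The heart of this change is the driver-death path: when a driver disconnects, its local scheduler publishes a DriverTableMessage on the Redis channel "driver_deaths", and the monitor reacts by releasing the driver's GPU reservations. As a rough standalone sketch (assuming a Redis server at 127.0.0.1:6379 and the generated Python bindings shown in the diff), a subscriber to that channel could look like this:

import redis

from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.utils import binary_to_hex


def watch_driver_deaths(redis_host="127.0.0.1", redis_port=6379):
    # Subscribe to the same channel that the local scheduler publishes to
    # when a driver disconnects (DRIVER_DEATH_CHANNEL in monitor.py).
    client = redis.StrictRedis(host=redis_host, port=redis_port)
    pubsub = client.pubsub()
    pubsub.subscribe("driver_deaths")
    for message in pubsub.listen():
        if message["type"] != "message":
            # Skip the subscription confirmation message.
            continue
        # The payload is a DriverTableMessage flatbuffer whose driver_id
        # field holds the binary ID of the driver that exited.
        notification = DriverTableMessage.GetRootAsDriverTableMessage(
            message["data"], 0)
        print("Driver {} exited.".format(
            binary_to_hex(notification.DriverId())))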


@ -7,13 +7,14 @@ import inspect
import json
import numpy as np
import random
import redis
import traceback
import ray.local_scheduler
import ray.pickling as pickling
import ray.signature as signature
import ray.worker
import ray.experimental.state as state
from ray.utils import binary_to_hex, hex_to_binary
# This is a variable used by each actor to indicate the IDs of the GPUs that
# the worker is currently allowed to use.
@ -105,6 +106,72 @@ def fetch_and_register_actor(key, worker):
# the actor.
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
"""Attempt to acquire GPUs on a particular local scheduler for an actor.
Args:
num_gpus: The number of GPUs to acquire.
driver_id: The ID of the driver responsible for creating the actor.
local_scheduler: Information about the local scheduler.
worker: The worker object whose Redis client is used to perform the reservation.
Returns:
A list of the GPU IDs that were successfully acquired. This should have
length either equal to num_gpus or equal to 0.
"""
local_scheduler_id = local_scheduler["DBClientID"]
local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
gpus_to_acquire = []
# Attempt to acquire GPU IDs atomically.
with worker.redis_client.pipeline() as pipe:
while True:
try:
# If this key is changed before the transaction below (the multi/exec
# block), then the transaction will not take place.
pipe.watch(local_scheduler_id)
# Figure out which GPUs are currently in use.
result = worker.redis_client.hget(local_scheduler_id, "gpus_in_use")
gpus_in_use = dict() if result is None else json.loads(result)
all_gpu_ids_in_use = []
for key in gpus_in_use:
all_gpu_ids_in_use += gpus_in_use[key]
assert len(all_gpu_ids_in_use) <= local_scheduler_total_gpus
assert len(set(all_gpu_ids_in_use)) == len(all_gpu_ids_in_use)
pipe.multi()
if local_scheduler_total_gpus - len(all_gpu_ids_in_use) >= num_gpus:
# There are enough available GPUs, so try to reserve some.
all_gpu_ids = set(range(local_scheduler_total_gpus))
for gpu_id in all_gpu_ids_in_use:
all_gpu_ids.remove(gpu_id)
gpus_to_acquire = list(all_gpu_ids)[:num_gpus]
# Use the hex driver ID so that the dictionary is JSON serializable.
driver_id_hex = binary_to_hex(driver_id)
if driver_id_hex not in gpus_in_use:
gpus_in_use[driver_id_hex] = []
gpus_in_use[driver_id_hex] += gpus_to_acquire
# Stick the updated GPU IDs back in Redis
pipe.hset(local_scheduler_id, "gpus_in_use", json.dumps(gpus_in_use))
pipe.execute()
# If a WatchError is not raised, then the operations should have gone
# through atomically.
break
except redis.WatchError:
# Another client must have changed the watched key between the time we
# started WATCHing it and the pipeline's execution. We should just
# retry.
gpus_to_acquire = []
continue
return gpus_to_acquire
def select_local_scheduler(local_schedulers, num_gpus, worker):
"""Select a local scheduler to assign this actor to.
@ -121,42 +188,33 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
Exception: An exception is raised if no local scheduler can be found with
sufficient resources.
"""
# TODO(rkn): We should change this method to have a list of GPU IDs that we
# pop from and push to. The current implementation is not compatible with
# actors releasing GPU resources.
driver_id = worker.task_driver_id.id()
if num_gpus == 0:
local_scheduler_id = random.choice(local_schedulers)[b"ray_client_id"]
gpu_ids = []
local_scheduler_id = hex_to_binary(
random.choice(local_schedulers)["DBClientID"])
gpus_aquired = []
else:
# All of this logic is for finding a local scheduler that has enough
# available GPUs.
local_scheduler_id = None
# Loop through all of the local schedulers.
for local_scheduler in local_schedulers:
# See if there are enough available GPUs on this local scheduler.
local_scheduler_total_gpus = int(float(
local_scheduler[b"num_gpus"].decode("ascii")))
gpus_in_use = worker.redis_client.hget(local_scheduler[b"ray_client_id"],
b"gpus_in_use")
gpus_in_use = 0 if gpus_in_use is None else int(gpus_in_use)
if gpus_in_use + num_gpus <= local_scheduler_total_gpus:
# Attempt to reserve some GPUs for this actor.
new_gpus_in_use = worker.redis_client.hincrby(
local_scheduler[b"ray_client_id"], b"gpus_in_use", num_gpus)
if new_gpus_in_use > local_scheduler_total_gpus:
# If we failed to reserve the GPUs, undo the increment.
worker.redis_client.hincrby(local_scheduler[b"ray_client_id"],
b"gpus_in_use", num_gpus)
else:
# We succeeded at reserving the GPUs, so we are done.
local_scheduler_id = local_scheduler[b"ray_client_id"]
gpu_ids = list(range(new_gpus_in_use - num_gpus, new_gpus_in_use))
break
# Try to reserve enough GPUs on this local scheduler.
gpus_aquired = attempt_to_reserve_gpus(num_gpus, driver_id,
local_scheduler, worker)
if len(gpus_aquired) == num_gpus:
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
break
else:
# We should have either acquired as many GPUs as we need or none.
assert len(gpus_aquired) == 0
if local_scheduler_id is None:
raise Exception("Could not find a node with enough GPUs to create this "
"actor. The local scheduler information is {}."
.format(local_schedulers))
return local_scheduler_id, gpu_ids
return local_scheduler_id, gpus_aquired
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
@ -183,13 +241,23 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
worker.function_properties[driver_id][function_id] = (1, num_cpus,
num_gpus)
# Get a list of the local schedulers from the client table.
client_table = ray.global_state.client_table()
local_schedulers = []
for ip_address, clients in client_table.items():
for client in clients:
if client["ClientType"] == "local_scheduler":
local_schedulers.append(client)
# Select a local scheduler for the actor.
local_schedulers = state.get_local_schedulers(worker)
local_scheduler_id, gpu_ids = select_local_scheduler(local_schedulers,
num_gpus, worker)
# Really we should encode this message as a flatbuffer object. However, we're
# having trouble getting that to work. It almost works, but in Python 2.7,
# builder.CreateString fails on byte strings that contain characters outside
# range(128).
worker.redis_client.publish("actor_notifications",
actor_id.id() + local_scheduler_id)
actor_id.id() + driver_id + local_scheduler_id)
d = {"driver_id": driver_id,
"actor_id": actor_id.id(),


@ -2,12 +2,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import binascii
import pickle
import redis
import sys
import ray.local_scheduler
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
# Import flatbuffer bindings.
from ray.core.generated.TaskInfo import TaskInfo
@ -36,40 +35,6 @@ task_state_mapping = {
}
def decode(byte_str):
"""Make this unicode in Python 3, otherwise leave it as bytes."""
if sys.version_info >= (3, 0):
return byte_str.decode("ascii")
else:
return byte_str
def binary_to_object_id(binary_object_id):
return ray.local_scheduler.ObjectID(binary_object_id)
def binary_to_hex(identifier):
hex_identifier = binascii.hexlify(identifier)
if sys.version_info >= (3, 0):
hex_identifier = hex_identifier.decode()
return hex_identifier
def hex_to_binary(hex_identifier):
return binascii.unhexlify(hex_identifier)
def get_local_schedulers(worker):
local_schedulers = []
for client in worker.redis_client.keys("CL:*"):
client_info = worker.redis_client.hgetall(client)
if b"client_type" not in client_info:
continue
if client_info[b"client_type"] == b"local_scheduler":
local_schedulers.append(client_info)
return local_schedulers
class GlobalState(object):
"""A class used to interface with the Ray control state.


@ -7,13 +7,15 @@ import subprocess
import time
def start_global_scheduler(redis_address, use_valgrind=False,
def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False,
use_profiler=False, stdout_file=None,
stderr_file=None):
"""Start a global scheduler process.
Args:
redis_address (str): The address of the Redis instance.
node_ip_address: The IP address of the node that this scheduler will run
on.
use_valgrind (bool): True if the global scheduler should be started inside
of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the global scheduler should be started inside
@ -31,7 +33,9 @@ def start_global_scheduler(redis_address, use_valgrind=False,
global_scheduler_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/global_scheduler/global_scheduler")
command = [global_scheduler_executable, "-r", redis_address]
command = [global_scheduler_executable,
"-r", redis_address,
"-h", node_ip_address]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",


@ -71,7 +71,7 @@ class TestGlobalScheduler(unittest.TestCase):
port=redis_port)
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(
redis_address, use_valgrind=USE_VALGRIND)
redis_address, node_ip_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []


@ -4,10 +4,12 @@ from __future__ import print_function
import argparse
from collections import Counter
import json
import logging
import redis
import time
import ray
from ray.services import get_ip_address
from ray.services import get_port
@ -15,6 +17,7 @@ from ray.services import get_port
from ray.core.generated.SubscribeToDBClientTableReply \
import SubscribeToDBClientTableReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.DriverTableMessage import DriverTableMessage
# These variables must be kept in sync with the C codebase.
# common/common.h
@ -26,6 +29,7 @@ NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
TASK_STATUS_LOST = 32
# common/state/redis.cc
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
# common/redis_module/ray_redis_module.cc
TASK_PREFIX = "TT:"
OBJECT_PREFIX = "OL:"
@ -215,6 +219,65 @@ class Monitor(object):
# manager.
self.live_plasma_managers[db_client_id] = 0
def driver_removed_handler(self, channel, data):
"""Handle a notification that a driver has been removed.
This releases any GPU resources that were reserved for that driver in
Redis.
"""
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
driver_id = message.DriverId()
log.info("Driver {} has been removed.".format(driver_id))
# Get a list of the local schedulers.
client_table = ray.global_state.client_table()
local_schedulers = []
for ip_address, clients in client_table.items():
for client in clients:
if client["ClientType"] == "local_scheduler":
local_schedulers.append(client)
# Release any GPU resources that have been reserved for this driver in
# Redis.
for local_scheduler in local_schedulers:
if int(local_scheduler["NumGPUs"]) > 0:
local_scheduler_id = local_scheduler["DBClientID"]
returned_gpu_ids = []
# Perform a transaction to return the GPUs.
with self.redis.pipeline() as pipe:
while True:
try:
# If this key is changed before the transaction below (the
# multi/exec block), then the transaction will not take place.
pipe.watch(local_scheduler_id)
result = pipe.hget(local_scheduler_id, "gpus_in_use")
gpus_in_use = dict() if result is None else json.loads(result)
driver_id_hex = ray.utils.binary_to_hex(driver_id)
if driver_id_hex in gpus_in_use:
returned_gpu_ids = gpus_in_use.pop(driver_id_hex)
pipe.multi()
pipe.hset(local_scheduler_id, "gpus_in_use",
json.dumps(gpus_in_use))
pipe.execute()
# If a WatchError is not raised, then the operations should have
# gone through atomically.
break
except redis.WatchError:
# Another client must have changed the watched key between the
# time we started WATCHing it and the pipeline's execution. We
# should just retry.
continue
log.info("Driver {} is returning GPU IDs {} to local scheduler {}."
.format(driver_id, returned_gpu_ids, local_scheduler_id))
def process_messages(self):
"""Process all messages ready in the subscription channels.
@ -244,6 +307,12 @@ class Monitor(object):
assert(self.subscribed[channel])
# The message was a notification from the db_client table.
message_handler = self.db_client_notification_handler
elif channel == DRIVER_DEATH_CHANNEL:
assert(self.subscribed[channel])
# The message was a notification that a driver was removed.
message_handler = self.driver_removed_handler
else:
raise Exception("This code should be unreachable.")
# Call the handler.
assert(message_handler is not None)
@ -258,6 +327,7 @@ class Monitor(object):
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel. This
@ -326,5 +396,8 @@ if __name__ == "__main__":
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
# Initialize the global state.
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
monitor = Monitor(redis_ip_address, redis_port)
monitor.run()
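
Both attempt_to_reserve_gpus in actor.py and driver_removed_handler above rely on the same redis-py optimistic-locking idiom: WATCH the key, read it, queue the writes inside MULTI, and retry on WatchError. Stripped of the GPU bookkeeping, the idiom reduces to roughly the following sketch (counter_key is a made-up placeholder key):

import redis


def atomic_increment(redis_client, counter_key, amount=1):
    # Retry until the transaction commits without the watched key changing.
    with redis_client.pipeline() as pipe:
        while True:
            try:
                pipe.watch(counter_key)
                # After watch(), the pipeline executes reads immediately.
                current = int(pipe.get(counter_key) or 0)
                pipe.multi()
                pipe.set(counter_key, current + amount)
                pipe.execute()
                return current + amount
            except redis.WatchError:
                # Another client modified counter_key; start over.
                continue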


@ -372,7 +372,7 @@ def start_global_scheduler(redis_address, node_ip_address, stdout_file=None,
this process will be killed by services.cleanup() when the Python process
that imported services exits.
"""
p = global_scheduler.start_global_scheduler(redis_address,
p = global_scheduler.start_global_scheduler(redis_address, node_ip_address,
stdout_file=stdout_file,
stderr_file=stderr_file)
if cleanup:
@ -767,7 +767,10 @@ def start_ray_processes(address_info=None,
if num_workers is not None:
workers_per_local_scheduler = num_local_schedulers * [num_workers]
else:
workers_per_local_scheduler = num_local_schedulers * [psutil.cpu_count()]
workers_per_local_scheduler = []
for cpus in num_cpus:
workers_per_local_scheduler.append(cpus if cpus is not None
else psutil.cpu_count())
if address_info is None:
address_info = {}


@ -0,0 +1,86 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import redis
import time
import ray
EVENT_KEY = "RAY_MULTI_NODE_TEST_KEY"
"""This key is used internally within this file for coordinating drivers."""
def _wait_for_nodes_to_join(num_nodes, timeout=20):
"""Wait until the nodes have joined the cluster.
This will wait until exactly num_nodes have joined the cluster and each node
has a local scheduler and a plasma manager.
Args:
num_nodes: The number of nodes to wait for.
timeout: The amount of time in seconds to wait before failing.
Raises:
Exception: An exception is raised if too many nodes join the cluster or if
the timeout expires while we are waiting.
"""
start_time = time.time()
while time.time() - start_time < timeout:
client_table = ray.global_state.client_table()
num_ready_nodes = len(client_table)
if num_ready_nodes == num_nodes:
ready = True
# Check that for each node, a local scheduler and a plasma manager are
# present.
for ip_address, clients in client_table.items():
client_types = [client["ClientType"] for client in clients]
if "local_scheduler" not in client_types:
ready = False
if "plasma_manager" not in client_types:
ready = False
if ready:
return
if num_ready_nodes > num_nodes:
# Too many nodes have joined. Something must be wrong.
raise Exception("{} nodes have joined the cluster, but we were "
"expecting {} nodes.".format(num_ready_nodes, num_nodes))
time.sleep(0.1)
# If we get here then we timed out.
raise Exception("Timed out while waiting for {} nodes to join. Only {} "
"nodes have joined so far.".format(num_ready_nodes,
num_nodes))
def _broadcast_event(event_name, redis_address):
"""Broadcast an event.
Args:
event_name: The name of the event to broadcast.
redis_address: The address of the Redis server to use for synchronization.
This is used to synchronize drivers for the multi-node tests.
"""
redis_host, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
redis_client.rpush(EVENT_KEY, event_name)
def _wait_for_event(event_name, redis_address, extra_buffer=1):
"""Block until an event has been broadcast.
Args:
event_name: The name of the event to wait for.
redis_address: The address of the Redis server to use for synchronization.
extra_buffer: An amount of time in seconds to wait after the event.
This is used to synchronize drivers for the multi-node tests.
"""
redis_host, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
while True:
event_names = redis_client.lrange(EVENT_KEY, 0, -1)
if event_name.encode("ascii") in event_names:
break
time.sleep(extra_buffer)
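
As a usage sketch of these two helpers (the event name and Redis address below are placeholders), one driver publishes an event and another blocks on it:

from ray.test.multi_node_tests import _broadcast_event, _wait_for_event

redis_address = "127.0.0.1:6379"  # placeholder address

# Driver A announces that its setup phase is finished.
_broadcast_event("SETUP_DONE", redis_address)

# Driver B, running as a separate process, blocks until the event appears.
_wait_for_event("SETUP_DONE", redis_address)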

python/ray/utils.py Normal file

@ -0,0 +1,31 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import binascii
import sys
import ray.local_scheduler
def decode(byte_str):
"""Make this unicode in Python 3, otherwise leave it as bytes."""
if sys.version_info >= (3, 0):
return byte_str.decode("ascii")
else:
return byte_str
def binary_to_object_id(binary_object_id):
return ray.local_scheduler.ObjectID(binary_object_id)
def binary_to_hex(identifier):
hex_identifier = binascii.hexlify(identifier)
if sys.version_info >= (3, 0):
hex_identifier = hex_identifier.decode()
return hex_identifier
def hex_to_binary(hex_identifier):
return binascii.unhexlify(hex_identifier)
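
These helpers convert Ray's binary IDs to hex strings (safe to use as JSON keys, as in the gpus_in_use mapping) and back. A quick round-trip sketch with a placeholder ID:

from ray.utils import binary_to_hex, hex_to_binary

raw_id = b"\x00" * 20  # placeholder byte string standing in for a Ray ID
hex_id = binary_to_hex(raw_id)
assert hex_to_binary(hex_id) == raw_id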


@ -61,6 +61,7 @@ add_library(common STATIC
state/object_table.cc
state/task_table.cc
state/db_client_table.cc
state/driver_table.cc
state/actor_notification_table.cc
state/local_scheduler_table.cc
state/error_table.cc


@ -47,6 +47,10 @@ bool DBClientID_equal(DBClientID first_id, DBClientID second_id) {
return UNIQUE_ID_EQ(first_id, second_id);
}
bool WorkerID_equal(WorkerID first_id, WorkerID second_id) {
return UNIQUE_ID_EQ(first_id, second_id);
}
char *ObjectID_to_string(ObjectID obj_id, char *id_string, int id_length) {
CHECK(id_length >= ID_STRING_SIZE);
static const char hex[] = "0123456789abcdef";


@ -153,7 +153,9 @@ extern const UniqueID NIL_ID;
UniqueID globally_unique_id(void);
#define NIL_OBJECT_ID NIL_ID
#define NIL_WORKER_ID NIL_ID
/** The object ID is the type used to identify objects. */
typedef UniqueID ObjectID;
#ifdef __cplusplus
@ -202,6 +204,18 @@ bool ObjectID_equal(ObjectID first_id, ObjectID second_id);
*/
bool ObjectID_is_nil(ObjectID id);
/** The worker ID is the ID of a worker or driver. */
typedef UniqueID WorkerID;
/**
* Compare two worker IDs.
*
* @param first_id The first worker ID to compare.
* @param second_id The second worker ID to compare.
* @return True if the worker IDs are the same and false otherwise.
*/
bool WorkerID_equal(WorkerID first_id, WorkerID second_id);
typedef UniqueID DBClientID;
/**


@ -139,3 +139,8 @@ table ResultTableReply {
}
root_type ResultTableReply;
table DriverTableMessage {
// The driver ID of the driver that died.
driver_id: string;
}


@ -5,20 +5,16 @@
#include "db.h"
#include "table.h"
typedef struct {
/** The ID of the actor. */
ActorID actor_id;
/** The ID of the local scheduler that is responsible for the actor. */
DBClientID local_scheduler_id;
} ActorInfo;
/*
* ==== Subscribing to the actor notification table ====
*/
/* Callback for subscribing to the local scheduler table. */
typedef void (*actor_notification_table_subscribe_callback)(ActorInfo info,
void *user_context);
typedef void (*actor_notification_table_subscribe_callback)(
ActorID actor_id,
WorkerID driver_id,
DBClientID local_scheduler_id,
void *user_context);
/**
* Register a callback to process actor notification events.


@ -0,0 +1,21 @@
#include "driver_table.h"
#include "redis.h"
void driver_table_subscribe(DBHandle *db_handle,
driver_table_subscribe_callback subscribe_callback,
void *subscribe_context,
RetryInfo *retry) {
DriverTableSubscribeData *sub_data =
(DriverTableSubscribeData *) malloc(sizeof(DriverTableSubscribeData));
sub_data->subscribe_callback = subscribe_callback;
sub_data->subscribe_context = subscribe_context;
init_table_callback(db_handle, NIL_ID, __func__, sub_data, retry, NULL,
redis_driver_table_subscribe, NULL);
}
void driver_table_send_driver_death(DBHandle *db_handle,
WorkerID driver_id,
RetryInfo *retry) {
init_table_callback(db_handle, driver_id, __func__, NULL, retry, NULL,
redis_driver_table_send_driver_death, NULL);
}


@ -0,0 +1,50 @@
#ifndef DRIVER_TABLE_H
#define DRIVER_TABLE_H
#include "db.h"
#include "table.h"
#include "task.h"
/*
* ==== Subscribing to the driver table ====
*/
/* Callback for subscribing to the driver table. */
typedef void (*driver_table_subscribe_callback)(WorkerID driver_id,
void *user_context);
/**
* Register a callback for a driver table event.
*
* @param db_handle Database handle.
* @param subscribe_callback Callback that will be called when the driver event
* happens.
* @param subscribe_context Context that will be passed into the
* subscribe_callback.
* @param retry Information about retrying the request to the database.
* @return Void.
*/
void driver_table_subscribe(DBHandle *db_handle,
driver_table_subscribe_callback subscribe_callback,
void *subscribe_context,
RetryInfo *retry);
/* Data that is needed to register driver table subscribe callbacks with the
* state database. */
typedef struct {
driver_table_subscribe_callback subscribe_callback;
void *subscribe_context;
} DriverTableSubscribeData;
/**
* Send driver death update to all subscribers.
*
* @param db_handle Database handle.
* @param driver_id The ID of the driver that died.
* @param retry Information about retrying the request to the database.
*/
void driver_table_send_driver_death(DBHandle *db_handle,
WorkerID driver_id,
RetryInfo *retry);
#endif /* DRIVER_TABLE_H */


@ -17,6 +17,7 @@ extern "C" {
#include "db.h"
#include "db_client_table.h"
#include "actor_notification_table.h"
#include "driver_table.h"
#include "local_scheduler_table.h"
#include "object_table.h"
#include "task.h"
@ -1089,6 +1090,90 @@ void redis_local_scheduler_table_send_info(TableCallbackData *callback_data) {
}
}
void redis_driver_table_subscribe_callback(redisAsyncContext *c,
void *r,
void *privdata) {
REDIS_CALLBACK_HEADER(db, callback_data, r);
redisReply *reply = (redisReply *) r;
CHECK(reply->type == REDIS_REPLY_ARRAY);
CHECK(reply->elements == 3);
redisReply *message_type = reply->element[0];
LOG_DEBUG("Driver table subscribe callback, message %s", message_type->str);
if (strcmp(message_type->str, "message") == 0) {
/* Handle a notification that a driver was removed. Parse the payload and
* call the subscribe callback. */
auto message =
flatbuffers::GetRoot<DriverTableMessage>(reply->element[2]->str);
/* Extract the driver ID. */
WorkerID driver_id = from_flatbuf(message->driver_id());
/* Call the subscribe callback. */
DriverTableSubscribeData *data =
(DriverTableSubscribeData *) callback_data->data;
if (data->subscribe_callback) {
data->subscribe_callback(driver_id, data->subscribe_context);
}
} else if (strcmp(message_type->str, "subscribe") == 0) {
/* The reply for the initial SUBSCRIBE command. */
CHECK(callback_data->done_callback == NULL);
/* If the initial SUBSCRIBE was successful, clean up the timer, but don't
* destroy the callback data. */
event_loop_remove_timer(db->loop, callback_data->timer_id);
} else {
LOG_FATAL("Unexpected reply type from driver subscribe.");
}
}
void redis_driver_table_subscribe(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
int status = redisAsyncCommand(
db->sub_context, redis_driver_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE driver_deaths");
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context, "error in redis_driver_table_subscribe");
}
}
void redis_driver_table_send_driver_death_callback(redisAsyncContext *c,
void *r,
void *privdata) {
REDIS_CALLBACK_HEADER(db, callback_data, r);
redisReply *reply = (redisReply *) r;
CHECK(reply->type == REDIS_REPLY_INTEGER);
LOG_DEBUG("%" PRId64 " subscribers received this publish.\n", reply->integer);
/* At the very least, the local scheduler that publishes this message should
* also receive it. */
CHECK(reply->integer >= 1);
CHECK(callback_data->done_callback == NULL);
/* Clean up the timer and callback. */
destroy_timer_callback(db->loop, callback_data);
}
void redis_driver_table_send_driver_death(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
WorkerID driver_id = callback_data->id;
/* Create a flatbuffer object to serialize and publish. */
flatbuffers::FlatBufferBuilder fbb;
/* Create the flatbuffers message. */
auto message = CreateDriverTableMessage(fbb, to_flatbuf(fbb, driver_id));
fbb.Finish(message);
int status = redisAsyncCommand(
db->context, redis_driver_table_send_driver_death_callback,
(void *) callback_data->timer_id, "PUBLISH driver_deaths %b",
fbb.GetBufferPointer(), fbb.GetSize());
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context,
"error in redis_driver_table_send_driver_death");
}
}
void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
/* NOTE(swang): We purposefully do not provide a callback, leaving the table
@ -1124,15 +1209,20 @@ void redis_actor_notification_table_subscribe_callback(redisAsyncContext *c,
redisReply *payload = reply->element[2];
ActorNotificationTableSubscribeData *data =
(ActorNotificationTableSubscribeData *) callback_data->data;
ActorInfo info;
/* The payload should be the concatenation of these two structs. */
CHECK(sizeof(info.actor_id) + sizeof(info.local_scheduler_id) ==
/* The payload should be the concatenation of three IDs. */
ActorID actor_id;
WorkerID driver_id;
DBClientID local_scheduler_id;
CHECK(sizeof(actor_id) + sizeof(driver_id) + sizeof(local_scheduler_id) ==
payload->len);
memcpy(&info.actor_id, payload->str, sizeof(info.actor_id));
memcpy(&info.local_scheduler_id, payload->str + sizeof(info.actor_id),
sizeof(info.local_scheduler_id));
memcpy(&actor_id, payload->str, sizeof(actor_id));
memcpy(&driver_id, payload->str + sizeof(actor_id), sizeof(driver_id));
memcpy(&local_scheduler_id,
payload->str + sizeof(actor_id) + sizeof(driver_id),
sizeof(local_scheduler_id));
if (data->subscribe_callback) {
data->subscribe_callback(info, data->subscribe_context);
data->subscribe_callback(actor_id, driver_id, local_scheduler_id,
data->subscribe_context);
}
} else if (strcmp(message_type->str, "subscribe") == 0) {
/* The reply for the initial SUBSCRIBE command. */


@ -253,6 +253,24 @@ void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data);
*/
void redis_local_scheduler_table_send_info(TableCallbackData *callback_data);
/**
* Subscribe to updates from the driver table.
*
* @param callback_data Data structure containing redis connection and timeout
* information.
* @return Void.
*/
void redis_driver_table_subscribe(TableCallbackData *callback_data);
/**
* Publish an update to the driver table.
*
* @param callback_data Data structure containing redis connection and timeout
* information.
* @return Void.
*/
void redis_driver_table_send_driver_death(TableCallbackData *callback_data);
void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data);
/**


@ -16,7 +16,6 @@ struct TaskBuilder;
#define NIL_TASK_ID NIL_ID
#define NIL_ACTOR_ID NIL_ID
#define NIL_FUNCTION_ID NIL_ID
#define NIL_WORKER_ID NIL_ID
typedef UniqueID FunctionID;


@ -61,14 +61,15 @@ void assign_task_to_local_scheduler(GlobalSchedulerState *state,
}
GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop,
const char *node_ip_address,
const char *redis_addr,
int redis_port) {
GlobalSchedulerState *state =
(GlobalSchedulerState *) malloc(sizeof(GlobalSchedulerState));
/* Must initialize state to 0. Sets hashmap head(s) to NULL. */
memset(state, 0, sizeof(GlobalSchedulerState));
state->db =
db_connect(redis_addr, redis_port, "global_scheduler", ":", 0, NULL);
state->db = db_connect(redis_addr, redis_port, "global_scheduler",
node_ip_address, 0, NULL);
db_attach(state->db, loop, false);
utarray_new(state->local_schedulers, &local_scheduler_icd);
state->policy_state = GlobalSchedulerPolicyState_init();
@ -416,9 +417,12 @@ int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) {
return HEARTBEAT_TIMEOUT_MILLISECONDS;
}
void start_server(const char *redis_addr, int redis_port) {
void start_server(const char *node_ip_address,
const char *redis_addr,
int redis_port) {
event_loop *loop = event_loop_create();
g_state = GlobalSchedulerState_init(loop, redis_addr, redis_port);
g_state =
GlobalSchedulerState_init(loop, node_ip_address, redis_addr, redis_port);
/* TODO(rkn): subscribe to notifications from the object table. */
/* Subscribe to notifications about new local schedulers. TODO(rkn): this
* needs to also get all of the clients that registered with the database
@ -456,12 +460,17 @@ int main(int argc, char *argv[]) {
signal(SIGTERM, signal_handler);
/* IP address and port of redis. */
char *redis_addr_port = NULL;
/* The IP address of the node that this global scheduler is running on. */
char *node_ip_address = NULL;
int c;
while ((c = getopt(argc, argv, "s:m:h:p:r:")) != -1) {
while ((c = getopt(argc, argv, "h:r:")) != -1) {
switch (c) {
case 'r':
redis_addr_port = optarg;
break;
case 'h':
node_ip_address = optarg;
break;
default:
LOG_ERROR("unknown option %c", c);
exit(-1);
@ -472,8 +481,11 @@ int main(int argc, char *argv[]) {
if (!redis_addr_port ||
parse_ip_addr_port(redis_addr_port, redis_addr, &redis_port) == -1) {
LOG_ERROR(
"need to specify redis address like 127.0.0.1:6379 with -r switch");
"specify the redis address like 127.0.0.1:6379 with the -r switch");
exit(-1);
}
start_server(redis_addr, redis_port);
if (!node_ip_address) {
LOG_FATAL("specify the node IP address with the -h switch");
}
start_server(node_ip_address, redis_addr, redis_port);
}


@ -18,6 +18,7 @@
#include "local_scheduler_algorithm.h"
#include "state/actor_notification_table.h"
#include "state/db.h"
#include "state/driver_table.h"
#include "state/task_table.h"
#include "state/object_table.h"
#include "state/error_table.h"
@ -67,7 +68,8 @@ int force_kill_worker(event_loop *loop, timer_id id, void *context) {
/**
* Kill a worker, if it is a child process, and clean up all of its associated
* state.
* state. Note that this function is also called on drivers, but it should not
* actually send a kill signal to drivers.
*
* @param worker A pointer to the worker we want to kill.
* @param cleanup A bool representing whether we're cleaning up the entire local
@ -89,7 +91,13 @@ void kill_worker(LocalSchedulerState *state,
CHECK(it == state->workers.end());
/* Erase the algorithm state's reference to the worker. */
handle_worker_removed(state, state->algorithm_state, worker);
if (ActorID_equal(worker->actor_id, NIL_ACTOR_ID)) {
handle_worker_removed(state, state->algorithm_state, worker);
} else {
/* Let the scheduling algorithm process the absence of this worker. */
handle_actor_worker_disconnect(state, state->algorithm_state,
worker->actor_id);
}
/* Remove the client socket from the event loop so that we don't process the
* SIGPIPE when the worker is killed. */
@ -99,6 +107,8 @@ void kill_worker(LocalSchedulerState *state,
* process, use it to send a kill signal. */
bool free_worker = true;
if (worker->is_child && worker->pid != 0) {
/* If worker is a driver, we should not enter this condition because
* worker->pid should be 0. */
if (cleanup) {
/* If we're exiting the local scheduler anyway, it's okay to force kill
* the worker immediately. Wait for the process to exit. */
@ -425,10 +435,18 @@ void update_dynamic_resources(LocalSchedulerState *state,
print_resource_info(state, spec);
}
bool is_driver_alive(LocalSchedulerState *state, WorkerID driver_id) {
return state->removed_drivers.count(driver_id) == 0;
}
void assign_task_to_worker(LocalSchedulerState *state,
TaskSpec *spec,
int64_t task_spec_size,
LocalSchedulerClient *worker) {
/* Make sure the driver for this task is still alive. */
WorkerID driver_id = TaskSpec_driver_id(spec);
CHECK(is_driver_alive(state, driver_id));
/* Construct a flatbuffer object to send to the worker. */
flatbuffers::FlatBufferBuilder fbb;
auto message =
@ -649,8 +667,15 @@ void send_client_register_reply(LocalSchedulerState *state,
void handle_client_register(LocalSchedulerState *state,
LocalSchedulerClient *worker,
const RegisterClientRequest *message) {
/* Make sure this worker hasn't already registered. */
CHECK(!worker->registered);
worker->registered = true;
worker->is_worker = message->is_worker();
CHECK(WorkerID_equal(worker->client_id, NIL_WORKER_ID));
worker->client_id = from_flatbuf(message->client_id());
/* Register the worker or driver. */
if (message->is_worker()) {
if (worker->is_worker) {
/* Update the actor mapping with the actor ID of the worker (if an actor is
* running on the worker). */
worker->pid = message->worker_pid();
@ -683,12 +708,74 @@ void handle_client_register(LocalSchedulerState *state,
state->child_pids.erase(it);
LOG_DEBUG("Found matching child pid %d", worker->pid);
}
/* If the worker is an actor that corresponds to a driver that has been
* removed, then kill the worker. */
if (!ActorID_equal(actor_id, NIL_ACTOR_ID)) {
WorkerID driver_id = state->actor_mapping[actor_id].driver_id;
if (state->removed_drivers.count(driver_id) == 1) {
kill_worker(state, worker, false);
}
}
} else {
/* Register the driver. Currently we don't do anything here. */
}
}
/* End of the cleanup code. */
void handle_driver_removed_callback(WorkerID driver_id, void *user_context) {
LocalSchedulerState *state = (LocalSchedulerState *) user_context;
/* Kill any actors that were created by the removed driver, and kill any
* workers that are currently running tasks from the dead driver. */
auto it = state->workers.begin();
while (it != state->workers.end()) {
/* Increment the iterator by one before calling kill_worker, because
* kill_worker will invalidate the iterator. Note that this requires
* knowledge of the particular container that we are iterating over (in this
* case it is a list). */
auto next_it = it;
next_it++;
ActorID actor_id = (*it)->actor_id;
Task *task = (*it)->task_in_progress;
if (!ActorID_equal(actor_id, NIL_ACTOR_ID)) {
/* This is an actor. */
CHECK(state->actor_mapping.count(actor_id) == 1);
if (WorkerID_equal(state->actor_mapping[actor_id].driver_id, driver_id)) {
/* This actor was created by the removed driver, so kill the actor. */
LOG_DEBUG("Killing an actor for a removed driver.");
kill_worker(state, *it, false);
break;
}
} else if (task != NULL) {
if (WorkerID_equal(TaskSpec_driver_id(Task_task_spec(task)), driver_id)) {
LOG_DEBUG("Killing a worker executing a task for a removed driver.");
kill_worker(state, *it, false);
break;
}
}
it = next_it;
}
/* Add the driver to a list of dead drivers. */
state->removed_drivers.insert(driver_id);
/* Notify the scheduling algorithm that the driver has been removed. It should
* remove tasks for that driver from its data structures. */
handle_driver_removed(state, state->algorithm_state, driver_id);
}
void handle_client_disconnect(LocalSchedulerState *state,
LocalSchedulerClient *worker) {
if (!worker->registered || worker->is_worker) {
} else {
/* In this case, a driver is disconnecting. */
driver_table_send_driver_death(state->db, worker->client_id, NULL);
}
kill_worker(state, worker, false);
}
void process_message(event_loop *loop,
int client_sock,
@ -786,12 +873,7 @@ void process_message(event_loop *loop,
} break;
case DISCONNECT_CLIENT: {
LOG_INFO("Disconnecting client on fd %d", client_sock);
kill_worker(state, worker, false);
if (!ActorID_equal(worker->actor_id, NIL_ACTOR_ID)) {
/* Let the scheduling algorithm process the absence of this worker. */
handle_actor_worker_disconnect(state, state->algorithm_state,
worker->actor_id);
}
handle_client_disconnect(state, worker);
} break;
case MessageType_NotifyUnblocked: {
if (worker->task_in_progress != NULL) {
@ -827,6 +909,11 @@ void new_client_connection(event_loop *loop,
* scheduler state. */
LocalSchedulerClient *worker = new LocalSchedulerClient();
worker->sock = new_socket;
worker->registered = false;
/* We don't yet know whether this client is a worker or a driver, so default
* is_worker to true until the client registers. */
worker->is_worker = true;
worker->client_id = NIL_WORKER_ID;
worker->task_in_progress = NULL;
worker->is_blocked = false;
worker->pid = 0;
@ -857,10 +944,21 @@ void signal_handler(int signal) {
}
}
/* End of the cleanup code. */
void handle_task_scheduled_callback(Task *original_task,
void *subscribe_context) {
LocalSchedulerState *state = (LocalSchedulerState *) subscribe_context;
TaskSpec *spec = Task_task_spec(original_task);
/* If the driver for this task has been removed, then don't bother telling the
* scheduling algorithm. */
WorkerID driver_id = TaskSpec_driver_id(spec);
if (!is_driver_alive(state, driver_id)) {
LOG_DEBUG("Ignoring scheduled task for removed driver.");
return;
}
if (ActorID_equal(TaskSpec_actor_id(spec), NIL_ACTOR_ID)) {
/* This task does not involve an actor. Handle it normally. */
handle_task_scheduled(state, state->algorithm_state, spec,
@ -884,10 +982,17 @@ void handle_task_scheduled_callback(Task *original_task,
* for creating the actor.
* @return Void.
*/
void handle_actor_creation_callback(ActorInfo info, void *context) {
ActorID actor_id = info.actor_id;
DBClientID local_scheduler_id = info.local_scheduler_id;
void handle_actor_creation_callback(ActorID actor_id,
WorkerID driver_id,
DBClientID local_scheduler_id,
void *context) {
LocalSchedulerState *state = (LocalSchedulerState *) context;
/* If the driver has been removed, don't bother doing anything. */
if (state->removed_drivers.count(driver_id) == 1) {
return;
}
/* Make sure the actor entry is not already present in the actor map table.
* TODO(rkn): We will need to remove this check to handle the case where the
* corresponding publish is retried and the case in which a task that creates
@ -897,9 +1002,10 @@ void handle_actor_creation_callback(ActorInfo info, void *context) {
* Currently this is never removed (except when the local scheduler state is
* deleted). */
ActorMapEntry entry;
entry.actor_id = actor_id;
entry.local_scheduler_id = local_scheduler_id;
entry.driver_id = driver_id;
state->actor_mapping[actor_id] = entry;
/* If this local scheduler is responsible for the actor, then start a new
* worker for the actor. */
if (DBClientID_equal(local_scheduler_id, get_db_client_id(state->db))) {
@ -960,6 +1066,11 @@ void start_server(const char *node_ip_address,
actor_notification_table_subscribe(
g_state->db, handle_actor_creation_callback, g_state, NULL);
}
/* Subscribe to notifications about removed drivers. */
if (g_state->db != NULL) {
driver_table_subscribe(g_state->db, handle_driver_removed_callback, g_state,
NULL);
}
/* Create a timer for publishing information about the load on the local
* scheduler to the local scheduler table. This message also serves as a
* heartbeat. */


@ -26,6 +26,14 @@ void new_client_connection(event_loop *loop,
void *context,
int events);
/**
* Check if a driver is still alive.
*
* @param driver_id The ID of the driver.
* @return True if the driver is still alive and false otherwise.
*/
bool is_driver_alive(WorkerID driver_id);
/**
* This function can be called by the scheduling algorithm to assign a task
* to a worker.


@ -964,6 +964,9 @@ void handle_worker_available(LocalSchedulerState *state,
void handle_worker_removed(LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
LocalSchedulerClient *worker) {
/* Make sure this is not an actor. */
CHECK(ActorID_equal(worker->actor_id, NIL_ACTOR_ID));
/* Make sure that we remove the worker at most once. */
int num_times_removed = 0;
@ -1141,6 +1144,59 @@ void handle_object_removed(LocalSchedulerState *state,
}
}
void handle_driver_removed(LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
WorkerID driver_id) {
/* Loop over fetch requests. This must be done before we clean up the waiting
* task queue and the dispatch task queue because this map contains iterators
* for those lists, which will be invalidated when we clean up those lists.*/
for (auto it = algorithm_state->remote_objects.begin();
it != algorithm_state->remote_objects.end();) {
/* Loop over the tasks that are waiting for this object and remove the tasks
* for the removed driver. */
auto task_it_it = it->second.dependent_tasks.begin();
while (task_it_it != it->second.dependent_tasks.end()) {
/* If the dependent task was a task for the removed driver, remove it from
* this vector. */
TaskSpec *spec = (*task_it_it)->spec;
if (WorkerID_equal(TaskSpec_driver_id(spec), driver_id)) {
task_it_it = it->second.dependent_tasks.erase(task_it_it);
} else {
task_it_it++;
}
}
/* If there are no more dependent tasks for this object, then remove the
* ObjectEntry. */
if (it->second.dependent_tasks.size() == 0) {
it = algorithm_state->remote_objects.erase(it);
} else {
it++;
}
}
/* Remove this driver's tasks from the waiting task queue. */
auto it = algorithm_state->waiting_task_queue->begin();
while (it != algorithm_state->waiting_task_queue->end()) {
if (WorkerID_equal(TaskSpec_driver_id(it->spec), driver_id)) {
it = algorithm_state->waiting_task_queue->erase(it);
} else {
it++;
}
}
/* Remove this driver's tasks from the dispatch task queue. */
it = algorithm_state->dispatch_task_queue->begin();
while (it != algorithm_state->dispatch_task_queue->end()) {
if (WorkerID_equal(TaskSpec_driver_id(it->spec), driver_id)) {
it = algorithm_state->dispatch_task_queue->erase(it);
} else {
it++;
}
}
/* TODO(rkn): Should we clean up the actor data structures? */
}
int num_waiting_tasks(SchedulingAlgorithmState *algorithm_state) {
return algorithm_state->waiting_task_queue->size();
}


@ -230,6 +230,20 @@ void handle_worker_unblocked(LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
LocalSchedulerClient *worker);
/**
* Process the fact that a driver has been removed. This will remove all of the
* tasks for that driver from the scheduling algorithm's internal data
* structures.
*
* @param state The state of the local scheduler.
* @param algorithm_state State maintained by the scheduling algorithm.
* @param driver_id The ID of the driver that was removed.
* @return Void.
*/
void handle_driver_removed(LocalSchedulerState *state,
SchedulingAlgorithmState *algorithm_state,
WorkerID driver_id);
/**
* This function fetches queued task's missing object dependencies. It is
* called every LOCAL_SCHEDULER_FETCH_TIMEOUT_MILLISECONDS.


@ -8,7 +8,9 @@
#include "utarray.h"
#include "uthash.h"
#include <list>
#include <unordered_map>
#include <unordered_set>
#include <vector>
/* These are needed to define the UT_arrays. */
@ -17,8 +19,8 @@ extern UT_icd task_ptr_icd;
/** This struct is used to maintain a mapping from actor IDs to the ID of the
* local scheduler that is responsible for the actor. */
struct ActorMapEntry {
/** The ID of the actor. This is used as a key in the hash table. */
ActorID actor_id;
/** The ID of the driver that created the actor. */
WorkerID driver_id;
/** The ID of the local scheduler that is responsible for the actor. */
DBClientID local_scheduler_id;
};
@ -47,7 +49,10 @@ struct LocalSchedulerState {
/** List of workers available to this node. This is used to free the worker
* structs when we free the scheduler state and also to access the worker
* structs in the tests. */
std::vector<LocalSchedulerClient *> workers;
std::list<LocalSchedulerClient *> workers;
/** A set of driver IDs corresponding to drivers that have been removed. This
* is used to make sure we don't execute any tasks belonging to dead drivers. */
std::unordered_set<WorkerID, UniqueIDHasher> removed_drivers;
/** List of the process IDs for child processes (workers) started by the
* local scheduler that have not sent a REGISTER_PID message yet. */
std::vector<pid_t> child_pids;
@ -75,6 +80,13 @@ struct LocalSchedulerState {
struct LocalSchedulerClient {
/** The socket used to communicate with the client. */
int sock;
/** True if the client has registered and false otherwise. */
bool registered;
/** True if the client is a worker and false if it is a driver. */
bool is_worker;
/** The worker ID if the client is a worker and the driver ID if the client is
* a driver. */
WorkerID client_id;
/** A pointer to the task object that is currently running on this client. If
* no task is running on the worker, this will be NULL. This is used to
* update the task table. */


@ -64,9 +64,7 @@ static void register_clients(int num_mock_workers, LocalSchedulerMock *mock) {
for (int i = 0; i < num_mock_workers; ++i) {
new_client_connection(mock->loop, mock->local_scheduler_fd,
(void *) mock->local_scheduler_state, 0);
LocalSchedulerClient *worker = mock->local_scheduler_state->workers[i];
LocalSchedulerClient *worker = mock->local_scheduler_state->workers.back();
process_message(mock->local_scheduler_state->loop, worker->sock, worker, 0);
}
}
@ -644,7 +642,7 @@ TEST start_kill_workers_test(void) {
ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
num_workers);
/* Make sure that the new worker registers its process ID. */
worker = local_scheduler->local_scheduler_state->workers[num_workers - 1];
worker = local_scheduler->local_scheduler_state->workers.back();
process_message(local_scheduler->local_scheduler_state->loop, worker->sock,
worker, 0);
ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0);


@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import argparse
import numpy as np
import os
import re
import subprocess
@ -85,8 +86,8 @@ class DockerRunner(object):
else:
return m.group(1)
def _start_head_node(self, docker_image, mem_size, shm_size,
development_mode):
def _start_head_node(self, docker_image, mem_size, shm_size, num_cpus,
num_gpus, development_mode):
"""Start the Ray head node inside a docker container."""
mem_arg = ["--memory=" + mem_size] if mem_size else []
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
@ -94,10 +95,15 @@ class DockerRunner(object):
"{}:{}".format(os.path.dirname(os.path.realpath(__file__)),
"/ray/test/jenkins_tests")]
if development_mode else [])
proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg +
volume_arg +
[docker_image, "/ray/scripts/start_ray.sh",
"--head", "--redis-port=6379"],
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
[docker_image, "/ray/scripts/start_ray.sh", "--head",
"--redis-port=6379",
"--num-cpus={}".format(num_cpus),
"--num-gpus={}".format(num_gpus)])
print("Starting head node with command:{}".format(command))
proc = subprocess.Popen(command,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
container_id = self._get_container_id(stdout_data)
@ -105,33 +111,34 @@ class DockerRunner(object):
raise RuntimeError("Failed to find container ID.")
self.head_container_id = container_id
self.head_container_ip = self._get_container_ip(container_id)
print("start_node", {"container_id": container_id,
"is_head": True,
"shm_size": shm_size,
"ip_address": self.head_container_ip})
return container_id
def _start_worker_node(self, docker_image, mem_size, shm_size):
def _start_worker_node(self, docker_image, mem_size, shm_size, num_cpus,
num_gpus, development_mode):
"""Start a Ray worker node inside a docker container."""
mem_arg = ["--memory=" + mem_size] if mem_size else []
shm_arg = ["--shm-size=" + shm_size] if shm_size else []
proc = subprocess.Popen(["docker", "run", "-d"] + mem_arg + shm_arg +
["--shm-size=" + shm_size, docker_image,
"/ray/scripts/start_ray.sh",
"--redis-address={:s}:6379".format(
self.head_container_ip)],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
volume_arg = (["-v",
"{}:{}".format(os.path.dirname(os.path.realpath(__file__)),
"/ray/test/jenkins_tests")]
if development_mode else [])
command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg +
["--shm-size=" + shm_size, docker_image,
"/ray/scripts/start_ray.sh",
"--redis-address={:s}:6379".format(self.head_container_ip),
"--num-cpus={}".format(num_cpus),
"--num-gpus={}".format(num_gpus)])
print("Starting worker node with command:{}".format(command))
proc = subprocess.Popen(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout_data, _ = wait_for_output(proc)
container_id = self._get_container_id(stdout_data)
if container_id is None:
raise RuntimeError("Failed to find container id")
self.worker_container_ids.append(container_id)
print("start_node", {"container_id": container_id,
"is_head": False,
"shm_size": shm_size})
def start_ray(self, docker_image, mem_size, shm_size, num_nodes,
development_mode):
def start_ray(self, docker_image=None, mem_size=None, shm_size=None,
num_nodes=None, num_cpus=None, num_gpus=None,
development_mode=None):
"""Start a Ray cluster within docker.
This starts one docker container running the head node and num_nodes - 1
@ -146,15 +153,23 @@ class DockerRunner(object):
with. This will be passed into `docker run` as the `--shm-size` flag.
num_nodes: The number of nodes to use in the cluster (this counts the
head node as well).
num_cpus: A list of the number of CPUs to start each node with.
num_gpus: A list of the number of GPUs to start each node with.
development_mode: True if you want to mount the local copy of
test/jenkins_test on the head node so we can avoid rebuilding docker
images during development.
"""
assert len(num_cpus) == num_nodes
assert len(num_gpus) == num_nodes
# Launch the head node.
self._start_head_node(docker_image, mem_size, shm_size, development_mode)
self._start_head_node(docker_image, mem_size, shm_size, num_cpus[0],
num_gpus[0], development_mode)
# Start the worker nodes.
for _ in range(num_nodes - 1):
self._start_worker_node(docker_image, mem_size, shm_size)
for i in range(num_nodes - 1):
self._start_worker_node(docker_image, mem_size, shm_size,
num_cpus[1 + i], num_gpus[1 + i],
development_mode)
def _stop_node(self, container_id):
"""Stop a node in the Ray cluster."""
@ -181,32 +196,51 @@ class DockerRunner(object):
for container_id in self.worker_container_ids:
self._stop_node(container_id)
def run_test(self, test_script, run_in_docker=False):
def run_test(self, test_script, num_drivers, driver_locations=None):
"""Run a test script.
Run a test using the Ray cluster.
Args:
test_script: The test script to run.
run_in_docker: If true then the test script will be run in a docker
container. If false, it will be run regularly.
num_drivers: The number of copies of the test script to run.
driver_locations: A list of the indices of the containers that the
different copies of the test script should be run on. If this is None,
then the containers will be chosen randomly.
Returns:
A dictionary with information about the test script run.
"""
print("Starting to run test script {}.".format(test_script))
proc = subprocess.Popen(["docker", "exec", self.head_container_id,
"/bin/bash", "-c",
"RAY_REDIS_ADDRESS={}:6379 "
"python {}".format(self.head_container_ip,
test_script)],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout_data, stderr_data = wait_for_output(proc)
print("STDOUT:")
print(stdout_data)
print("STDERR:")
print(stderr_data)
return {"success": proc.returncode == 0, "return_code": proc.returncode}
all_container_ids = [self.head_container_id] + self.worker_container_ids
if driver_locations is None:
driver_locations = [np.random.randint(0, len(all_container_ids))
for _ in range(num_drivers)]
# Start the different drivers.
driver_processes = []
for i in range(len(driver_locations)):
# Get the container ID to run the ith driver in.
container_id = all_container_ids[driver_locations[i]]
command = ["docker", "exec", container_id, "/bin/bash", "-c",
("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python {}"
.format(self.head_container_ip, i, test_script))]
print("Starting driver with command {}.".format(test_script))
# Start the driver.
p = subprocess.Popen(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
driver_processes.append(p)
# Wait for the drivers to finish.
results = []
for p in driver_processes:
stdout_data, stderr_data = wait_for_output(p)
print("STDOUT:")
print(stdout_data)
print("STDERR:")
print(stderr_data)
results.append({"success": p.returncode == 0,
"return_code": p.returncode})
return results
if __name__ == "__main__":
@ -218,23 +252,49 @@ if __name__ == "__main__":
parser.add_argument("--shm-size", default="1G", help="shared memory size")
parser.add_argument("--num-nodes", default=1, type=int,
help="number of nodes to use in the cluster")
parser.add_argument("--num-cpus", type=str,
help=("a comma separated list of values representing "
"the number of CPUs to start each node with"))
parser.add_argument("--num-gpus", type=str,
help=("a comma separated list of values representing "
"the number of GPUs to start each node with"))
parser.add_argument("--num-drivers", default=1, type=int,
help="number of drivers to run")
parser.add_argument("--driver-locations", type=str,
help=("a comma separated list of indices of the "
"containers to run the drivers in"))
parser.add_argument("--test-script", required=True, help="test script")
parser.add_argument("--development-mode", action="store_true",
help="use local copies of the test scripts")
args = parser.parse_args()
# Parse the number of CPUs and GPUs to use for each worker.
num_nodes = args.num_nodes
num_cpus = ([int(i) for i in args.num_cpus.split(",")]
if args.num_cpus is not None else num_nodes * [10])
num_gpus = ([int(i) for i in args.num_gpus.split(",")]
if args.num_gpus is not None else num_nodes * [0])
d = DockerRunner()
d.start_ray(mem_size=args.mem_size, shm_size=args.shm_size,
num_nodes=args.num_nodes, docker_image=args.docker_image,
d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size,
shm_size=args.shm_size, num_nodes=num_nodes,
num_cpus=num_cpus, num_gpus=num_gpus,
development_mode=args.development_mode)
try:
run_result = d.run_test(args.test_script)
run_results = d.run_test(args.test_script, args.num_drivers,
driver_locations=args.driver_locations)
finally:
d.stop_ray()
if "success" in run_result and run_result["success"]:
print("RESULT: Test {} succeeded.".format(args.test_script))
sys.exit(0)
else:
print("RESULT: Test {} failed.".format(args.test_script))
any_failed = False
for run_result in run_results:
if "success" in run_result and run_result["success"]:
print("RESULT: Test {} succeeded.".format(args.test_script))
else:
print("RESULT: Test {} failed.".format(args.test_script))
any_failed = True
if any_failed:
sys.exit(1)
else:
sys.exit(0)


@ -0,0 +1,80 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import ray
from ray.test.multi_node_tests import (_wait_for_nodes_to_join,
_broadcast_event,
_wait_for_event)
# This test should be run with 5 nodes, which have 0, 0, 5, 6, and 50 GPUs for
# a total of 61 GPUs. It should be run with a large number of drivers (e.g.,
# 100). At most 15 drivers will run at a time, and each driver will use at
# most 5 GPUs (this is ceil(61 / 15), which guarantees that we will always be
# able to make progress).
total_num_nodes = 5
max_concurrent_drivers = 15
num_gpus_per_driver = 5
@ray.actor(num_gpus=1)
class Actor1(object):
def __init__(self):
assert len(ray.get_gpu_ids()) == 1
def check_ids(self):
assert len(ray.get_gpu_ids()) == 1
def driver(redis_address, driver_index):
"""The script for driver 0.
This driver should create five actors that each use one GPU and some actors
that use no GPUs. After a while, it should exit.
"""
ray.init(redis_address=redis_address)
# Wait for all the nodes to join the cluster.
_wait_for_nodes_to_join(total_num_nodes)
# Limit the number of drivers running concurrently.
for i in range(driver_index - max_concurrent_drivers + 1):
_wait_for_event("DRIVER_{}_DONE".format(i), redis_address)
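# (For example, the driver with index 20 waits here for DRIVER_0_DONE through
# DRIVER_5_DONE, so that at most max_concurrent_drivers drivers run at once.)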
def try_to_create_actor(actor_class, timeout=100):
# Try to create an actor, but allow failures while we wait for the monitor
# to release the resources for the removed drivers.
start_time = time.time()
while time.time() - start_time < timeout:
try:
actor = actor_class()
except Exception as e:
time.sleep(0.1)
else:
return actor
# If we are here, then we timed out while looping.
raise Exception("Timed out while trying to create actor.")
# Create some actors that require one GPU.
actors_one_gpu = []
for _ in range(num_gpus_per_driver):
actors_one_gpu.append(try_to_create_actor(Actor1))
for _ in range(100):
ray.get([actor.check_ids() for actor in actors_one_gpu])
_broadcast_event("DRIVER_{}_DONE".format(driver_index), redis_address)
if __name__ == "__main__":
driver_index = int(os.environ["RAY_DRIVER_INDEX"])
redis_address = os.environ["RAY_REDIS_ADDRESS"]
print("Driver {} started at {}.".format(driver_index, time.time()))
# In this test, all drivers will run the same script.
driver(redis_address, driver_index)
print("Driver {} finished at {}.".format(driver_index, time.time()))

@ -0,0 +1,153 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import ray
from ray.test.multi_node_tests import (_wait_for_nodes_to_join,
_broadcast_event,
_wait_for_event)
# This test should be run with 5 nodes, which have 0, 1, 2, 3, and 4 GPUs for a
# total of 10 GPUs. It should be run with 3 drivers.
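# (These per-node GPU counts match the --num-gpus=0,1,2,3,4 flag passed to
# multi_node_docker_test.py in the Jenkins script.)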
total_num_nodes = 5
@ray.actor
class Actor0(object):
def __init__(self):
assert len(ray.get_gpu_ids()) == 0
def check_ids(self):
assert len(ray.get_gpu_ids()) == 0
@ray.actor(num_gpus=1)
class Actor1(object):
def __init__(self):
assert len(ray.get_gpu_ids()) == 1
def check_ids(self):
assert len(ray.get_gpu_ids()) == 1
@ray.actor(num_gpus=2)
class Actor2(object):
def __init__(self):
assert len(ray.get_gpu_ids()) == 2
def check_ids(self):
assert len(ray.get_gpu_ids()) == 2
def driver_0(redis_address):
"""The script for driver 0.
This driver should create five actors that each use one GPU and some actors
that use no GPUs. After a while, it should exit.
"""
ray.init(redis_address=redis_address)
# Wait for all the nodes to join the cluster.
_wait_for_nodes_to_join(total_num_nodes)
# Create some actors that require one GPU.
actors_one_gpu = [Actor1() for _ in range(5)]
# Create some actors that don't require any GPUs.
actors_no_gpus = [Actor0() for _ in range(5)]
for _ in range(1000):
ray.get([actor.check_ids() for actor in actors_one_gpu])
ray.get([actor.check_ids() for actor in actors_no_gpus])
_broadcast_event("DRIVER_0_DONE", redis_address)
def driver_1(redis_address):
"""The script for driver 1.
This driver should create one actor that uses two GPUs, three actors that
each use one GPU (the one requiring two must be created first), and some
actors that don't use any GPUs. After a while, it should exit.
"""
ray.init(redis_address=redis_address)
# Wait for all the nodes to join the cluster.
_wait_for_nodes_to_join(total_num_nodes)
# Create an actor that requires two GPUs.
actors_two_gpus = [Actor2() for _ in range(1)]
# Create some actors that require one GPU.
actors_one_gpu = [Actor1() for _ in range(3)]
# Create some actors that don't require any GPUs.
actors_no_gpus = [Actor0() for _ in range(5)]
for _ in range(1000):
ray.get([actor.check_ids() for actor in actors_two_gpus])
ray.get([actor.check_ids() for actor in actors_one_gpu])
ray.get([actor.check_ids() for actor in actors_no_gpus])
_broadcast_event("DRIVER_1_DONE", redis_address)
def driver_2(redis_address):
"""The script for driver 2.
This driver should wait for the first two drivers to finish. Then it should
create some actors that use a total of ten GPUs.
"""
ray.init(redis_address=redis_address)
_wait_for_event("DRIVER_0_DONE", redis_address)
_wait_for_event("DRIVER_1_DONE", redis_address)
def try_to_create_actor(actor_class, timeout=20):
# Try to create an actor, but allow failures while we wait for the monitor
# to release the resources for the removed drivers.
start_time = time.time()
while time.time() - start_time < timeout:
try:
actor = actor_class()
except Exception as e:
time.sleep(0.1)
else:
return actor
# If we are here, then we timed out while looping.
raise Exception("Timed out while trying to create actor.")
# Create some actors that require two GPUs.
actors_two_gpus = []
for _ in range(3):
actors_two_gpus.append(try_to_create_actor(Actor2))
# Create some actors that require one GPU.
actors_one_gpu = []
for _ in range(4):
actors_one_gpu.append(try_to_create_actor(Actor1))
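# Together the actors above use 3 * 2 + 4 * 1 = 10 GPUs, i.e. every GPU in
# the cluster, so they can only be created once the monitor has released the
# GPUs that drivers 0 and 1 were using.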
# Create some actors that don't require any GPUs.
actors_no_gpus = [Actor0() for _ in range(5)]
for _ in range(1000):
ray.get([actor.check_ids() for actor in actors_two_gpus])
ray.get([actor.check_ids() for actor in actors_one_gpu])
ray.get([actor.check_ids() for actor in actors_no_gpus])
_broadcast_event("DRIVER_2_DONE", redis_address)
if __name__ == "__main__":
driver_index = int(os.environ["RAY_DRIVER_INDEX"])
redis_address = os.environ["RAY_REDIS_ADDRESS"]
print("Driver {} started at {}.".format(driver_index, time.time()))
if driver_index == 0:
driver_0(redis_address)
elif driver_index == 1:
driver_1(redis_address)
elif driver_index == 2:
driver_2(redis_address)
else:
raise Exception("This code should be unreachable.")
print("Driver {} finished at {}.".format(driver_index, time.time()))

@ -15,7 +15,11 @@ def f():
if __name__ == "__main__":
ray.init(redis_address=os.environ["RAY_REDIS_ADDRESS"])
driver_index = int(os.environ["RAY_DRIVER_INDEX"])
redis_address = os.environ["RAY_REDIS_ADDRESS"]
print("Driver {} started at {}.".format(driver_index, time.time()))
ray.init(redis_address=redis_address)
# Check that tasks are scheduled on all nodes.
num_attempts = 30
for i in range(num_attempts):
@ -26,3 +30,5 @@ if __name__ == "__main__":
if len(counts) == 5:
break
assert len(counts) == 5
print("Driver {} finished at {}.".format(driver_index, time.time()))

@ -13,6 +13,20 @@ python $ROOT_DIR/multi_node_docker_test.py \
--num-nodes=5 \
--test-script=/ray/test/jenkins_tests/multi_node_tests/test_0.py
python $ROOT_DIR/multi_node_docker_test.py \
--docker-image=$DOCKER_SHA \
--num-nodes=5 \
--num-gpus=0,1,2,3,4 \
--num-drivers=3 \
--test-script=/ray/test/jenkins_tests/multi_node_tests/remove_driver_test.py
python $ROOT_DIR/multi_node_docker_test.py \
--docker-image=$DOCKER_SHA \
--num-nodes=5 \
--num-gpus=0,0,5,6,50 \
--num-drivers=100 \
--test-script=/ray/test/jenkins_tests/multi_node_tests/many_drivers_test.py
python $ROOT_DIR/multi_node_docker_test.py \
--docker-image=$DOCKER_SHA \
--num-nodes=1 \

@ -1433,8 +1433,7 @@ class GlobalStateAPI(unittest.TestCase):
client_table = ray.global_state.client_table()
node_ip_address = ray.worker.global_worker.node_ip_address
self.assertEqual(len(client_table[node_ip_address]), 2)
self.assertEqual(len(client_table[":"]), 1)
self.assertEqual(len(client_table[node_ip_address]), 3)
manager_client = [c for c in client_table[node_ip_address]
if c["ClientType"] == "plasma_manager"][0]