Ray scheduler spillback plumbing + mechanism (#1362)

* spillback mechanism and plumbing : adding spillback counter + timestamp

* linting fix

* documentation

* Fix argument name.
This commit is contained in:
Alexey Tumanov 2018-01-23 20:18:12 -08:00 committed by Robert Nishihara
parent 21a916009e
commit f1303291b4
11 changed files with 153 additions and 48 deletions

View file

@ -322,16 +322,18 @@ class TestGlobalStateStore(unittest.TestCase):
def check_task_reply(message, task_args, updated=False):
(task_status, local_scheduler_id, execution_dependencies_string,
task_spec) = task_args
spillback_count, task_spec) = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(),
local_scheduler_id)
self.assertEqual(task_reply_object.SpillbackCount(),
spillback_count)
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
self.assertEqual(task_reply_object.Updated(), updated)
# Check that task table adds, updates, and lookups work correctly.
task_args = [TASK_STATUS_WAITING, b"node_id", b"", b"task_spec"]
task_args = [TASK_STATUS_WAITING, b"node_id", b"", 0, b"task_spec"]
response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
*task_args)
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
@ -339,7 +341,7 @@ class TestGlobalStateStore(unittest.TestCase):
task_args[0] = TASK_STATUS_SCHEDULED
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id",
*task_args[:3])
*task_args[:4])
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(response, task_args)
@ -408,7 +410,7 @@ class TestGlobalStateStore(unittest.TestCase):
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"", b"task_spec"]
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the data.
message = get_next_message(p)["data"]
@ -420,7 +422,7 @@ class TestGlobalStateStore(unittest.TestCase):
task_args[2])
self.assertEqual(notification_object.ExecutionDependencies(),
task_args[3])
self.assertEqual(notification_object.TaskSpec(), task_args[4])
self.assertEqual(notification_object.TaskSpec(), task_args[-1])
def testTaskTableSubscribe(self):
scheduling_state = 1

View file

@ -273,6 +273,8 @@ class GlobalState(object):
task_table_message.LocalSchedulerId()),
"ExecutionDependenciesString":
task_table_message.ExecutionDependencies(),
"SpillbackCount":
task_table_message.SpillbackCount(),
"TaskSpec": task_spec_info}
def task_table(self, task_id=None):

View file

@ -198,7 +198,8 @@ class Monitor(object):
key, "RAY.TASK_TABLE_UPDATE",
hex_to_binary(task_id),
ray.experimental.state.TASK_STATUS_LOST, NIL_ID,
task["ExecutionDependenciesString"])
task["ExecutionDependenciesString"],
task["SpillbackCount"])
if ok != b"OK":
log.warn("Failed to update lost task for dead scheduler.")
num_tasks_updated += 1

View file

@ -1906,6 +1906,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
TASK_STATUS_RUNNING,
NIL_LOCAL_SCHEDULER_ID,
driver_task.execution_dependencies_string(),
0,
ray.local_scheduler.task_to_string(driver_task))
# Set the driver's current task ID to the task ID assigned to the
# driver task.

View file

@ -101,6 +101,8 @@ table TaskReply {
execution_dependencies: string;
// A string of bytes representing the task specification.
task_spec: string;
// The number of times the task was spilled back by local schedulers.
spillback_count: long;
// A boolean representing whether the update was successful. This field
// should only be used for test-and-set operations.
updated: bool;

View file

@ -762,12 +762,14 @@ int ReplyWithTask(RedisModuleCtx *ctx,
RedisModuleString *local_scheduler_id = NULL;
RedisModuleString *execution_dependencies = NULL;
RedisModuleString *task_spec = NULL;
RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &state,
"local_scheduler_id", &local_scheduler_id,
"execution_dependencies", &execution_dependencies,
"TaskSpec", &task_spec, NULL);
RedisModuleString *spillback_count = NULL;
RedisModule_HashGet(
key, REDISMODULE_HASH_CFIELDS, "state", &state, "local_scheduler_id",
&local_scheduler_id, "execution_dependencies", &execution_dependencies,
"TaskSpec", &task_spec, "spillback_count", &spillback_count, NULL);
if (state == NULL || local_scheduler_id == NULL ||
execution_dependencies == NULL || task_spec == NULL) {
execution_dependencies == NULL || task_spec == NULL ||
spillback_count == NULL) {
/* We must have either all fields or no fields. */
RedisModule_CloseKey(key);
return RedisModule_ReplyWithError(
@ -775,22 +777,29 @@ int ReplyWithTask(RedisModuleCtx *ctx,
}
long long state_integer;
if (RedisModule_StringToLongLong(state, &state_integer) != REDISMODULE_OK ||
state_integer < 0) {
long long spillback_count_val;
if ((RedisModule_StringToLongLong(state, &state_integer) !=
REDISMODULE_OK) ||
(state_integer < 0) ||
(RedisModule_StringToLongLong(spillback_count, &spillback_count_val) !=
REDISMODULE_OK) ||
(spillback_count_val < 0)) {
RedisModule_CloseKey(key);
RedisModule_FreeString(ctx, state);
RedisModule_FreeString(ctx, local_scheduler_id);
RedisModule_FreeString(ctx, execution_dependencies);
RedisModule_FreeString(ctx, task_spec);
return RedisModule_ReplyWithError(ctx, "Found invalid scheduling state.");
RedisModule_FreeString(ctx, spillback_count);
return RedisModule_ReplyWithError(
ctx, "Found invalid scheduling state or spillback count.");
}
flatbuffers::FlatBufferBuilder fbb;
auto message =
CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, task_id), state_integer,
RedisStringToFlatbuf(fbb, local_scheduler_id),
RedisStringToFlatbuf(fbb, execution_dependencies),
RedisStringToFlatbuf(fbb, task_spec), updated);
auto message = CreateTaskReply(
fbb, RedisStringToFlatbuf(fbb, task_id), state_integer,
RedisStringToFlatbuf(fbb, local_scheduler_id),
RedisStringToFlatbuf(fbb, execution_dependencies),
RedisStringToFlatbuf(fbb, task_spec), spillback_count_val, updated);
fbb.Finish(message);
RedisModuleString *reply = RedisModule_CreateString(
@ -801,6 +810,7 @@ int ReplyWithTask(RedisModuleCtx *ctx,
RedisModule_FreeString(ctx, local_scheduler_id);
RedisModule_FreeString(ctx, execution_dependencies);
RedisModule_FreeString(ctx, task_spec);
RedisModule_FreeString(ctx, spillback_count);
} else {
/* If the key does not exist, return nil. */
RedisModule_ReplyWithNull(ctx);
@ -911,12 +921,19 @@ int TaskTableWrite(RedisModuleCtx *ctx,
RedisModuleString *state,
RedisModuleString *local_scheduler_id,
RedisModuleString *execution_dependencies,
RedisModuleString *spillback_count,
RedisModuleString *task_spec) {
/* Extract the scheduling state. */
long long state_value;
if (RedisModule_StringToLongLong(state, &state_value) != REDISMODULE_OK) {
return RedisModule_ReplyWithError(ctx, "scheduling state must be integer");
}
long long spillback_count_value;
if (RedisModule_StringToLongLong(spillback_count, &spillback_count_value) !=
REDISMODULE_OK) {
return RedisModule_ReplyWithError(ctx, "spillback count must be integer");
}
/* Add the task to the task table. If no spec was provided, get the existing
* spec out of the task table so we can publish it. */
RedisModuleString *existing_task_spec = NULL;
@ -925,7 +942,8 @@ int TaskTableWrite(RedisModuleCtx *ctx,
if (task_spec == NULL) {
RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state,
"local_scheduler_id", local_scheduler_id,
"execution_dependencies", execution_dependencies, NULL);
"execution_dependencies", execution_dependencies,
"spillback_count", spillback_count, NULL);
RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "TaskSpec",
&existing_task_spec, NULL);
if (existing_task_spec == NULL) {
@ -934,10 +952,10 @@ int TaskTableWrite(RedisModuleCtx *ctx,
ctx, "Cannot update a task that doesn't exist yet");
}
} else {
RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state,
"local_scheduler_id", local_scheduler_id,
"execution_dependencies", execution_dependencies,
"TaskSpec", task_spec, NULL);
RedisModule_HashSet(
key, REDISMODULE_HASH_CFIELDS, "state", state, "local_scheduler_id",
local_scheduler_id, "execution_dependencies", execution_dependencies,
"TaskSpec", task_spec, "spillback_count", spillback_count, NULL);
}
RedisModule_CloseKey(key);
@ -959,11 +977,12 @@ int TaskTableWrite(RedisModuleCtx *ctx,
task_spec_to_use = existing_task_spec;
}
/* Create the flatbuffers message. */
auto message =
CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, task_id), state_value,
RedisStringToFlatbuf(fbb, local_scheduler_id),
RedisStringToFlatbuf(fbb, execution_dependencies),
RedisStringToFlatbuf(fbb, task_spec_to_use));
auto message = CreateTaskReply(
fbb, RedisStringToFlatbuf(fbb, task_id), state_value,
RedisStringToFlatbuf(fbb, local_scheduler_id),
RedisStringToFlatbuf(fbb, execution_dependencies),
RedisStringToFlatbuf(fbb, task_spec_to_use), spillback_count_value,
true); // The updated field is not used.
fbb.Finish(message);
RedisModuleString *publish_message = RedisModule_CreateString(
@ -1023,11 +1042,12 @@ int TaskTableWrite(RedisModuleCtx *ctx,
int TaskTableAddTask_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString **argv,
int argc) {
if (argc != 6) {
if (argc != 7) {
return RedisModule_WrongArity(ctx);
}
return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4], argv[5]);
return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4], argv[5],
argv[6]);
}
/**
@ -1051,11 +1071,11 @@ int TaskTableAddTask_RedisCommand(RedisModuleCtx *ctx,
int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString **argv,
int argc) {
if (argc != 5) {
if (argc != 6) {
return RedisModule_WrongArity(ctx);
}
return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4], NULL);
return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4], argv[5], NULL);
}
/**

View file

@ -943,10 +943,12 @@ void redis_task_table_add_task(TableCallbackData *callback_data) {
int status = redisAsyncCommand(
context, redis_task_table_add_task_callback,
(void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b %b",
(void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b %d %b",
task_id.data(), sizeof(task_id), state, local_scheduler_id.data(),
sizeof(local_scheduler_id), fbb.GetBufferPointer(),
(size_t) fbb.GetSize(), spec, execution_spec->SpecSize());
(size_t) fbb.GetSize(),
static_cast<int>(execution_spec->SpillbackCount()), spec,
execution_spec->SpecSize());
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task");
}
@ -1005,10 +1007,11 @@ void redis_task_table_update(TableCallbackData *callback_data) {
int status = redisAsyncCommand(
context, redis_task_table_update_callback,
(void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b %b",
(void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b %b %d",
task_id.data(), sizeof(task_id), state, local_scheduler_id.data(),
sizeof(local_scheduler_id), fbb.GetBufferPointer(),
(size_t) fbb.GetSize());
(size_t) fbb.GetSize(),
static_cast<int>(execution_spec->SpillbackCount()));
if ((status == REDIS_ERR) || context->err) {
LOG_REDIS_DEBUG(context, "error in redis_task_table_update");
}
@ -1114,10 +1117,16 @@ void redis_task_table_subscribe_callback(redisAsyncContext *c,
/* Extract the task spec. */
TaskSpec *spec = (TaskSpec *) message->task_spec()->data();
int64_t task_spec_size = message->task_spec()->size();
/* Extract the spillback information. */
int spillback_count = message->spillback_count();
/* Create a task. */
Task *task = Task_alloc(
spec, task_spec_size, state, local_scheduler_id,
from_flatbuf(*execution_dependencies->execution_dependencies()));
/* Allocate the task execution spec on the stack and use it to construct
* the task.
*/
TaskExecutionSpec execution_spec(
from_flatbuf(*execution_dependencies->execution_dependencies()), spec,
task_spec_size, spillback_count);
Task *task = Task_alloc(execution_spec, state, local_scheduler_id);
/* Call the subscribe callback if there is one. */
TaskTableSubscribeData *data =

View file

@ -367,17 +367,28 @@ void TaskSpec_free(TaskSpec *spec) {
TaskExecutionSpec::TaskExecutionSpec(
const std::vector<ObjectID> &execution_dependencies,
TaskSpec *spec,
int64_t task_spec_size) {
execution_dependencies_ = execution_dependencies;
task_spec_size_ = task_spec_size;
int64_t task_spec_size,
int spillback_count)
: execution_dependencies_(execution_dependencies),
task_spec_size_(task_spec_size),
last_timestamp_(0),
spillback_count_(spillback_count) {
TaskSpec *spec_copy = new TaskSpec[task_spec_size_];
memcpy(spec_copy, spec, task_spec_size);
spec_ = std::unique_ptr<TaskSpec[]>(spec_copy);
}
TaskExecutionSpec::TaskExecutionSpec(TaskExecutionSpec *other) {
execution_dependencies_ = other->execution_dependencies_;
task_spec_size_ = other->task_spec_size_;
TaskExecutionSpec::TaskExecutionSpec(
const std::vector<ObjectID> &execution_dependencies,
TaskSpec *spec,
int64_t task_spec_size)
: TaskExecutionSpec(execution_dependencies, spec, task_spec_size, 0) {}
TaskExecutionSpec::TaskExecutionSpec(TaskExecutionSpec *other)
: execution_dependencies_(other->execution_dependencies_),
task_spec_size_(other->task_spec_size_),
last_timestamp_(other->last_timestamp_),
spillback_count_(other->spillback_count_) {
TaskSpec *spec_copy = new TaskSpec[task_spec_size_];
memcpy(spec_copy, other->spec_.get(), task_spec_size_);
spec_ = std::unique_ptr<TaskSpec[]>(spec_copy);
@ -392,10 +403,26 @@ void TaskExecutionSpec::SetExecutionDependencies(
execution_dependencies_ = dependencies;
}
int64_t TaskExecutionSpec::SpecSize() {
int64_t TaskExecutionSpec::SpecSize() const {
return task_spec_size_;
}
int TaskExecutionSpec::SpillbackCount() const {
return spillback_count_;
}
void TaskExecutionSpec::IncrementSpillbackCount() {
++spillback_count_;
}
int64_t TaskExecutionSpec::LastTimeStamp() const {
return last_timestamp_;
}
void TaskExecutionSpec::SetLastTimeStamp(int64_t new_timestamp) {
last_timestamp_ = new_timestamp;
}
TaskSpec *TaskExecutionSpec::Spec() {
return spec_.get();
}

View file

@ -20,6 +20,10 @@ class TaskExecutionSpec {
TaskExecutionSpec(const std::vector<ObjectID> &execution_dependencies,
TaskSpec *spec,
int64_t task_spec_size);
TaskExecutionSpec(const std::vector<ObjectID> &execution_dependencies,
TaskSpec *spec,
int64_t task_spec_size,
int spillback_count);
TaskExecutionSpec(TaskExecutionSpec *execution_spec);
/// Get the task's execution dependencies.
@ -37,7 +41,31 @@ class TaskExecutionSpec {
/// Get the task spec size.
///
/// @return The size of the immutable task spec.
int64_t SpecSize();
int64_t SpecSize() const;
/// Get the task's spillback count, which tracks the number of times
/// this task was spilled back from local to the global scheduler.
///
/// @return The spillback count for this task.
int SpillbackCount() const;
/// Increment the spillback count for this task.
///
/// @return Void.
void IncrementSpillbackCount();
/// Get the task's last timestamp.
///
/// @return The timestamp when this task was last received for scheduling.
int64_t LastTimeStamp() const;
/// Set the task's last timestamp to the specified value.
///
/// @param new_timestamp The new timestamp in millisecond to set the task's
/// time stamp to. Tracks the last time this task entered a local
/// scheduler.
/// @return Void.
void SetLastTimeStamp(int64_t new_timestamp);
/// Get the task spec.
///
@ -84,6 +112,10 @@ class TaskExecutionSpec {
std::vector<ObjectID> execution_dependencies_;
/** The size of the task specification for this task. */
int64_t task_spec_size_;
/** Last time this task was received for scheduling. */
int64_t last_timestamp_;
/** Number of times this task was spilled back by local schedulers. */
int spillback_count_;
/** The task specification for this task. */
std::unique_ptr<TaskSpec[]> spec_;
};

View file

@ -1007,6 +1007,8 @@ void process_message(event_loop *loop,
TaskExecutionSpec(from_flatbuf(*message->execution_dependencies()),
(TaskSpec *) message->task_spec()->data(),
message->task_spec()->size());
/* Set the tasks's local scheduler entrypoint time. */
execution_spec.SetLastTimeStamp(current_time_ms());
TaskSpec *spec = execution_spec.Spec();
/* Update the result table, which holds mappings of object ID -> ID of the
* task that created it. */
@ -1197,6 +1199,9 @@ void handle_task_scheduled_callback(Task *original_task,
TaskExecutionSpec *execution_spec = Task_task_execution_spec(original_task);
TaskSpec *spec = execution_spec->Spec();
/* Set the tasks's local scheduler entrypoint time. */
execution_spec->SetLastTimeStamp(current_time_ms());
/* If the driver for this task has been removed, then don't bother telling the
* scheduling algorithm. */
WorkerID driver_id = TaskSpec_driver_id(spec);

View file

@ -1067,6 +1067,10 @@ void give_task_to_global_scheduler(LocalSchedulerState *state,
}
/* Pass on the task to the global scheduler. */
DCHECK(state->config.global_scheduler_exists);
/* Increment the task's spillback count before forwarding it to the global
* scheduler.
*/
execution_spec.IncrementSpillbackCount();
Task *task =
Task_alloc(execution_spec, TASK_STATUS_WAITING, DBClientID::nil());
DCHECK(state->db != NULL);