Properly import remote functions and reusable variables on workers that register late (#290)

This commit is contained in:
Robert Nishihara 2016-07-25 16:17:17 -07:00 committed by Philipp Moritz
parent 5591aa4665
commit 8e9f98c5ff
8 changed files with 145 additions and 76 deletions

View file

@ -686,31 +686,27 @@ def main_loop(worker=global_worker):
# above so that changes made to their state do not affect other tasks.
ray.reusables._reinitialize()
while True:
(task, function, reusable_variable) = ray.lib.wait_for_next_message(worker.handle)
command, command_args = ray.lib.wait_for_next_message(worker.handle)
try:
# Only one of task, function, and reusable_variable should be not None.
assert sum([obj is not None for obj in [task, function, reusable_variable]]) <= 1
if task is None and function is None and reusable_variable is None:
# We use this as a mechanism to allow the scheduler to kill workers. When
# the scheduler wants to kill a worker, it gives the worker a null task,
# causing the worker program to exit the main loop here.
if command == "die":
# We use this as a mechanism to allow the scheduler to kill workers.
break
if function is not None:
(function, arg_types, return_types) = pickling.loads(function)
elif command == "function":
(function, arg_types, return_types) = pickling.loads(command_args)
if function.__module__ is None: function.__module__ = "__main__"
worker.register_function(remote(arg_types, return_types, worker)(function))
if reusable_variable is not None:
name, initializer_str, reinitializer_str = reusable_variable
elif command == "reusable_variable":
name, initializer_str, reinitializer_str = command_args
initializer = pickling.loads(initializer_str)
reinitializer = pickling.loads(reinitializer_str)
reusables.__setattr__(name, Reusable(initializer, reinitializer))
if task is not None:
process_task(task)
elif command == "task":
process_task(command_args)
else:
assert False, "This code should be unreachable."
finally:
# Allow releasing the variables BEFORE we wait for the next message or exit the block
del task
del function
del reusable_variable
del command_args
def _submit_task(func_name, args, worker=global_worker):
"""This is a wrapper around worker.submit_task.

View file

@ -240,9 +240,7 @@ message ExportFunctionReply {
}
message ExportReusableVariableRequest {
string name = 1; // The name of the reusable variable.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
ReusableVar reusable_variable = 1; // The reusable variable to export.
}
// These messages are for getting information about the object store state
@ -280,9 +278,7 @@ message ImportFunctionReply {
}
message ImportReusableVariableRequest {
string name = 1; // The name of the reusable variable.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
ReusableVar reusable_variable = 1; // The reusable variable to export.
}
message DieRequest {
@ -290,3 +286,13 @@ message DieRequest {
message DieReply {
}
// This message is used by the worker service to send messages to the worker
// that are processed by the worker's main loop.
message WorkerMessage {
oneof worker_item {
Task task = 1; // A task for the worker to execute.
Function function = 2; // A remote function to import on the worker.
ReusableVar reusable_variable = 3; // A reusable variable to import on the worker.
}
}

View file

@ -37,6 +37,12 @@ message Function {
bytes implementation = 1;
}
message ReusableVar {
string name = 1; // The name of the reusable variable.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
}
// Union of possible object types
message Obj {
String string_data = 1;

View file

@ -607,13 +607,13 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) {
return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor);
}
static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) {
static PyObject* deserialize_task(PyObject* worker_capsule, const Task& task) {
std::vector<ObjRef> objrefs; // This is a vector of all the objrefs that were serialized in this task, including objrefs that are contained in Python objects that are passed by value.
PyObject* string = PyString_FromStringAndSize(task->name().c_str(), task->name().size());
int argsize = task->arg_size();
PyObject* string = PyString_FromStringAndSize(task.name().c_str(), task.name().size());
int argsize = task.arg_size();
PyObject* arglist = PyList_New(argsize);
for (int i = 0; i < argsize; ++i) {
const Value& val = task->arg(i);
const Value& val = task.arg(i);
if (!val.has_obj()) {
PyList_SetItem(arglist, i, make_pyobjref(worker_capsule, val.ref()));
objrefs.push_back(val.ref());
@ -624,12 +624,12 @@ static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) {
Worker* worker;
PyObjectToWorker(worker_capsule, &worker);
worker->decrement_reference_count(objrefs);
int resultsize = task->result_size();
int resultsize = task.result_size();
std::vector<ObjRef> result_objrefs;
PyObject* resultlist = PyList_New(resultsize);
for (int i = 0; i < resultsize; ++i) {
PyList_SetItem(resultlist, i, make_pyobjref(worker_capsule, task->result(i)));
result_objrefs.push_back(task->result(i));
PyList_SetItem(resultlist, i, make_pyobjref(worker_capsule, task.result(i)));
result_objrefs.push_back(task.result(i));
}
worker->decrement_reference_count(result_objrefs); // The corresponding increment is done in SubmitTask in the scheduler.
PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
@ -685,23 +685,32 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
Worker* worker;
PyObjectToWorker(worker_capsule, &worker);
if (std::unique_ptr<WorkerMessage> message = worker->receive_next_message()) {
PyObject* variable_info;
if (!message->reusable_variable.variable_name.empty()) {
variable_info = PyTuple_New(3);
PyTuple_SetItem(variable_info, 0, PyString_FromStringAndSize(message->reusable_variable.variable_name.data(), static_cast<ssize_t>(message->reusable_variable.variable_name.size())));
PyTuple_SetItem(variable_info, 1, PyString_FromStringAndSize(message->reusable_variable.initializer.data(), static_cast<ssize_t>(message->reusable_variable.initializer.size())));
PyTuple_SetItem(variable_info, 2, PyString_FromStringAndSize(message->reusable_variable.reinitializer.data(), static_cast<ssize_t>(message->reusable_variable.reinitializer.size())));
bool task_present = !message->task().name().empty();
bool function_present = !message->function().implementation().empty();
bool reusable_variable_present = !message->reusable_variable().name().empty();
RAY_CHECK(task_present + function_present + reusable_variable_present <= 1, "The worker message should contain at most one item.");
PyObject* t = PyTuple_New(2);
if (task_present) {
PyTuple_SetItem(t, 0, PyString_FromString("task"));
PyTuple_SetItem(t, 1, deserialize_task(worker_capsule, message->task()));
} else if (function_present) {
PyTuple_SetItem(t, 0, PyString_FromString("function"));
PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast<ssize_t>(message->function().implementation().size())));
} else if (reusable_variable_present) {
PyTuple_SetItem(t, 0, PyString_FromString("reusable_variable"));
PyObject* reusable_variable = PyTuple_New(3);
PyTuple_SetItem(reusable_variable, 0, PyString_FromStringAndSize(message->reusable_variable().name().data(), static_cast<ssize_t>(message->reusable_variable().name().size())));
PyTuple_SetItem(reusable_variable, 1, PyString_FromStringAndSize(message->reusable_variable().initializer().implementation().data(), static_cast<ssize_t>(message->reusable_variable().initializer().implementation().size())));
PyTuple_SetItem(reusable_variable, 2, PyString_FromStringAndSize(message->reusable_variable().reinitializer().implementation().data(), static_cast<ssize_t>(message->reusable_variable().reinitializer().implementation().size())));
PyTuple_SetItem(t, 1, reusable_variable);
} else {
PyTuple_SetItem(t, 0, PyString_FromString("die"));
Py_INCREF(Py_None);
PyTuple_SetItem(t, 1, Py_None);
}
// The tuple constructed below will take ownership of some None objects.
// When the tuple goes out of scope, the reference count for None will be
// decremented. Therefore, we need to increment the reference count for None
// every time we put a None in the tuple.
PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_INCREF(Py_None), Py_None : deserialize_task(worker_capsule, &message->task));
PyTuple_SetItem(t, 1, message->function.empty() ? Py_INCREF(Py_None), Py_None : PyString_FromStringAndSize(message->function.data(), static_cast<ssize_t>(message->function.size())));
PyTuple_SetItem(t, 2, message->reusable_variable.variable_name.empty() ? Py_INCREF(Py_None), Py_None : variable_info);
return t;
}
RAY_CHECK(false, "This code should be unreachable.");
Py_RETURN_NONE;
}

View file

@ -150,6 +150,20 @@ Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForN
OperationId operationid = (*workers_.get())[workerid].current_task;
RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task");
RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask.");
{
// Check if the worker has been initialized yet, and if not, then give it
// all of the exported functions and all of the exported reusable variables.
auto workers = workers_.get();
if (!(*workers)[workerid].initialized) {
// This should only happen once.
// Import all remote functions on the worker.
export_all_functions_to_worker(workerid, workers, exported_functions_.get());
// Import all reusable variables on the worker.
export_all_reusable_variables_to_worker(workerid, workers, exported_reusable_variables_.get());
// Mark the worker as initialized.
(*workers)[workerid].initialized = true;
}
}
if (request->has_previous_task_info()) {
RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION.");
std::string task_name;
@ -293,13 +307,12 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) {
auto workers = workers_.get();
auto exported_functions = exported_functions_.get();
// TODO(rkn): Does this do a deep copy?
exported_functions->push_back(std::unique_ptr<Function>(new Function(request->function())));
for (size_t i = 0; i < workers->size(); ++i) {
ClientContext import_context;
ImportFunctionRequest import_request;
import_request.mutable_function()->set_implementation(request->function().implementation());
if ((*workers)[i].current_task != ROOT_OPERATION) {
ImportFunctionReply import_reply;
(*workers)[i].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
export_function_to_worker(i, exported_functions->size() - 1, workers, exported_functions);
}
}
return Status::OK;
@ -307,15 +320,12 @@ Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunc
Status SchedulerService::ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) {
auto workers = workers_.get();
auto exported_reusable_variables = exported_reusable_variables_.get();
// TODO(rkn): Does this do a deep copy?
exported_reusable_variables->push_back(std::unique_ptr<ReusableVar>(new ReusableVar(request->reusable_variable())));
for (size_t i = 0; i < workers->size(); ++i) {
ClientContext import_context;
ImportReusableVariableRequest import_request;
import_request.set_name(request->name());
import_request.mutable_initializer()->set_implementation(request->initializer().implementation());
import_request.mutable_reinitializer()->set_implementation(request->reinitializer().implementation());
if ((*workers)[i].current_task != ROOT_OPERATION) {
AckReply import_reply;
(*workers)[i].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply);
export_reusable_variable_to_worker(i, exported_reusable_variables->size() - 1, workers, exported_reusable_variables);
}
}
return Status::OK;
@ -451,6 +461,7 @@ std::pair<WorkerId, ObjStoreId> SchedulerService::register_worker(const std::str
(*workers)[workerid].objstoreid = objstoreid;
(*workers)[workerid].worker_stub = WorkerService::NewStub(channel);
(*workers)[workerid].worker_address = worker_address;
(*workers)[workerid].initialized = false;
if (is_driver) {
(*workers)[workerid].current_task = ROOT_OPERATION; // We use this field to identify which workers are drivers.
} else {
@ -830,6 +841,37 @@ void SchedulerService::get_equivalent_objrefs(ObjRef objref, std::vector<ObjRef>
upstream_objrefs(downstream_objref, equivalent_objrefs, reverse_target_objrefs_.get());
}
void SchedulerService::export_function_to_worker(WorkerId workerid, int function_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid);
ClientContext import_context;
ImportFunctionRequest import_request;
import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get());
ImportFunctionReply import_reply;
(*workers)[workerid].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
}
void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
RAY_LOG(RAY_INFO, "exporting reusable variable with index " << reusable_variable_index << " to worker " << workerid);
ClientContext import_context;
ImportReusableVariableRequest import_request;
import_request.mutable_reusable_variable()->CopyFrom(*(*exported_reusable_variables)[reusable_variable_index].get());
AckReply import_reply;
(*workers)[workerid].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply);
}
void SchedulerService::export_all_functions_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
for (int i = 0; i < exported_functions->size(); ++i) {
export_function_to_worker(workerid, i, workers, exported_functions);
}
}
void SchedulerService::export_all_reusable_variables_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
for (int i = 0; i < exported_reusable_variables->size(); ++i) {
export_reusable_variable_to_worker(workerid, i, workers, exported_reusable_variables);
}
}
// This method defines the order in which locks should be acquired.
void SchedulerService::do_on_locks(bool lock) {
std::mutex *mutexes[] = {
@ -847,7 +889,9 @@ void SchedulerService::do_on_locks(bool lock) {
&objtable_.mutex(),
&objstores_.mutex(),
&target_objrefs_.mutex(),
&reverse_target_objrefs_.mutex()
&reverse_target_objrefs_.mutex(),
&exported_functions_.mutex(),
&exported_reusable_variables_.mutex(),
};
size_t n = sizeof(mutexes) / sizeof(*mutexes);
for (size_t i = 0; i != n; ++i) {

View file

@ -37,6 +37,10 @@ struct WorkerHandle {
std::unique_ptr<WorkerService::Stub> worker_stub; // If null, the worker has died
ObjStoreId objstoreid;
std::string worker_address;
// This field is initialized to false, and it is set to true after all of the
// exported functions and exported reusable variables have been shipped to
// this worker.
bool initialized;
OperationId current_task;
};
@ -129,6 +133,20 @@ private:
void upstream_objrefs(ObjRef objref, std::vector<ObjRef> &objrefs, const SynchronizedPtr<std::vector<std::vector<ObjRef> > > &reverse_target_objrefs);
// Find all of the object references that refer to the same object as objref (as best as we can determine at the moment). The information may be incomplete because not all of the aliases may be known.
void get_equivalent_objrefs(ObjRef objref, std::vector<ObjRef> &equivalent_objrefs);
// Export a remote function to a worker.
void export_function_to_worker(WorkerId workerid, int function_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions);
// Export a reusable variable to a worker
void export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables);
// Export all reusable variables to a worker. This is used when a new worker
// registers and is protected by the workers lock (which is passed in) to
// ensure that no other reusable variables are exported to the worker while
// this method is being called.
void export_all_functions_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions);
// Export all remote functions to a worker. This is used when a new worker
// registers and is protected by the workers lock (which is passed in) to
// ensure that no other remote functions are exported to the worker while this
// method is being called.
void export_all_reusable_variables_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables);
// acquires all locks, this should only be used by get_info and for fault tolerance
void acquire_all_locks();
// release all locks, this should only be used by get_info and for fault tolerance
@ -187,6 +205,10 @@ private:
Synchronized<std::vector<RefCount> > reference_counts_;
// contained_objrefs_[objref] is a vector of all of the objrefs contained inside the object referred to by objref
Synchronized<std::vector<std::vector<ObjRef> > > contained_objrefs_;
// All of the remote functions that have been exported to the workers.
Synchronized<std::vector<std::unique_ptr<Function> > > exported_functions_;
// All of the reusable variables that have been exported to the workers.
Synchronized<std::vector<std::unique_ptr<ReusableVar> > > exported_reusable_variables_;
// the scheduling algorithm that will be used
SchedulingAlgorithmType scheduling_algorithm_;
};

View file

@ -17,7 +17,7 @@ inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address)
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) {
RAY_LOG(RAY_INFO, "invoked task " << request->task().name());
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->task = request->task();
message->mutable_task()->CopyFrom(request->task());
{
WorkerMessage* message_ptr = message.get();
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
@ -28,7 +28,7 @@ Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskR
Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) {
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->function = request->function().implementation();
message->mutable_function()->CopyFrom(request->function());
RAY_LOG(RAY_INFO, "importing function");
{
WorkerMessage* message_ptr = message.get();
@ -40,9 +40,7 @@ Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFun
Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) {
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->reusable_variable.variable_name = request->name();
message->reusable_variable.initializer = request->initializer().implementation();
message->reusable_variable.reinitializer = request->reinitializer().implementation();
message->mutable_reusable_variable()->CopyFrom(request->reusable_variable());
RAY_LOG(RAY_INFO, "importing reusable variable");
{
WorkerMessage* message_ptr = message.get();
@ -360,9 +358,9 @@ void Worker::export_reusable_variable(const std::string& name, const std::string
RAY_CHECK(connected_, "Attempted to export reusable variable but failed.");
ClientContext context;
ExportReusableVariableRequest request;
request.set_name(name);
request.mutable_initializer()->set_implementation(initializer);
request.mutable_reinitializer()->set_implementation(reinitializer);
request.mutable_reusable_variable()->set_name(name);
request.mutable_reusable_variable()->mutable_initializer()->set_implementation(initializer);
request.mutable_reusable_variable()->mutable_reinitializer()->set_implementation(reinitializer);
AckReply reply;
Status status = scheduler_stub_->ExportReusableVariable(&context, request, &reply);
}

View file

@ -23,18 +23,6 @@ using grpc::Channel;
using grpc::ClientContext;
using grpc::ClientWriter;
struct ReusableVariable {
std::string variable_name;
std::string initializer;
std::string reinitializer;
};
struct WorkerMessage {
Task task;
std::string function; // Used for importing remote functions.
ReusableVariable reusable_variable; // Used for importing reusable variables.
};
class WorkerServiceImpl final : public WorkerService::Service {
public:
WorkerServiceImpl(const std::string& worker_address);