diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 19ebe5cc7..876dfdb16 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -686,31 +686,27 @@ def main_loop(worker=global_worker): # above so that changes made to their state do not affect other tasks. ray.reusables._reinitialize() while True: - (task, function, reusable_variable) = ray.lib.wait_for_next_message(worker.handle) + command, command_args = ray.lib.wait_for_next_message(worker.handle) try: - # Only one of task, function, and reusable_variable should be not None. - assert sum([obj is not None for obj in [task, function, reusable_variable]]) <= 1 - if task is None and function is None and reusable_variable is None: - # We use this as a mechanism to allow the scheduler to kill workers. When - # the scheduler wants to kill a worker, it gives the worker a null task, - # causing the worker program to exit the main loop here. + if command == "die": + # We use this as a mechanism to allow the scheduler to kill workers. break - if function is not None: - (function, arg_types, return_types) = pickling.loads(function) + elif command == "function": + (function, arg_types, return_types) = pickling.loads(command_args) if function.__module__ is None: function.__module__ = "__main__" worker.register_function(remote(arg_types, return_types, worker)(function)) - if reusable_variable is not None: - name, initializer_str, reinitializer_str = reusable_variable + elif command == "reusable_variable": + name, initializer_str, reinitializer_str = command_args initializer = pickling.loads(initializer_str) reinitializer = pickling.loads(reinitializer_str) reusables.__setattr__(name, Reusable(initializer, reinitializer)) - if task is not None: - process_task(task) + elif command == "task": + process_task(command_args) + else: + assert False, "This code should be unreachable." 
finally: # Allow releasing the variables BEFORE we wait for the next message or exit the block - del task - del function - del reusable_variable + del command_args def _submit_task(func_name, args, worker=global_worker): """This is a wrapper around worker.submit_task. diff --git a/protos/ray.proto b/protos/ray.proto index 6ef79ce24..05c95797e 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -240,9 +240,7 @@ message ExportFunctionReply { } message ExportReusableVariableRequest { - string name = 1; // The name of the reusable variable. - Function initializer = 2; // A serialized version of the function that initializes the reusable variable. - Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable. + ReusableVar reusable_variable = 1; // The reusable variable to export. } // These messages are for getting information about the object store state @@ -280,9 +278,7 @@ message ImportFunctionReply { } message ImportReusableVariableRequest { - string name = 1; // The name of the reusable variable. - Function initializer = 2; // A serialized version of the function that initializes the reusable variable. - Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable. + ReusableVar reusable_variable = 1; // The reusable variable to import. } message DieRequest { @@ -290,3 +286,13 @@ message DieRequest { message DieReply { } + +// This message is used by the worker service to send messages to the worker +// that are processed by the worker's main loop. +message WorkerMessage { + oneof worker_item { + Task task = 1; // A task for the worker to execute. + Function function = 2; // A remote function to import on the worker. + ReusableVar reusable_variable = 3; // A reusable variable to import on the worker.
+ } +} diff --git a/protos/types.proto b/protos/types.proto index 83ae413a7..4d61884a7 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -37,6 +37,12 @@ message Function { bytes implementation = 1; } +message ReusableVar { + string name = 1; // The name of the reusable variable. + Function initializer = 2; // A serialized version of the function that initializes the reusable variable. + Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable. +} + // Union of possible object types message Obj { String string_data = 1; diff --git a/src/raylib.cc b/src/raylib.cc index 43efe393e..6629d617c 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -607,13 +607,13 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) { return PyCapsule_New(static_cast(task), "task", &TaskCapsule_Destructor); } -static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) { +static PyObject* deserialize_task(PyObject* worker_capsule, const Task& task) { std::vector objrefs; // This is a vector of all the objrefs that were serialized in this task, including objrefs that are contained in Python objects that are passed by value. 
- PyObject* string = PyString_FromStringAndSize(task->name().c_str(), task->name().size()); - int argsize = task->arg_size(); + PyObject* string = PyString_FromStringAndSize(task.name().c_str(), task.name().size()); + int argsize = task.arg_size(); PyObject* arglist = PyList_New(argsize); for (int i = 0; i < argsize; ++i) { - const Value& val = task->arg(i); + const Value& val = task.arg(i); if (!val.has_obj()) { PyList_SetItem(arglist, i, make_pyobjref(worker_capsule, val.ref())); objrefs.push_back(val.ref()); @@ -624,12 +624,12 @@ static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) { Worker* worker; PyObjectToWorker(worker_capsule, &worker); worker->decrement_reference_count(objrefs); - int resultsize = task->result_size(); + int resultsize = task.result_size(); std::vector result_objrefs; PyObject* resultlist = PyList_New(resultsize); for (int i = 0; i < resultsize; ++i) { - PyList_SetItem(resultlist, i, make_pyobjref(worker_capsule, task->result(i))); - result_objrefs.push_back(task->result(i)); + PyList_SetItem(resultlist, i, make_pyobjref(worker_capsule, task.result(i))); + result_objrefs.push_back(task.result(i)); } worker->decrement_reference_count(result_objrefs); // The corresponding increment is done in SubmitTask in the scheduler. PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple. 
@@ -685,23 +685,32 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) { Worker* worker; PyObjectToWorker(worker_capsule, &worker); if (std::unique_ptr message = worker->receive_next_message()) { - PyObject* variable_info; - if (!message->reusable_variable.variable_name.empty()) { - variable_info = PyTuple_New(3); - PyTuple_SetItem(variable_info, 0, PyString_FromStringAndSize(message->reusable_variable.variable_name.data(), static_cast(message->reusable_variable.variable_name.size()))); - PyTuple_SetItem(variable_info, 1, PyString_FromStringAndSize(message->reusable_variable.initializer.data(), static_cast(message->reusable_variable.initializer.size()))); - PyTuple_SetItem(variable_info, 2, PyString_FromStringAndSize(message->reusable_variable.reinitializer.data(), static_cast(message->reusable_variable.reinitializer.size()))); + bool task_present = !message->task().name().empty(); + bool function_present = !message->function().implementation().empty(); + bool reusable_variable_present = !message->reusable_variable().name().empty(); + RAY_CHECK(task_present + function_present + reusable_variable_present <= 1, "The worker message should contain at most one item."); + PyObject* t = PyTuple_New(2); + if (task_present) { + PyTuple_SetItem(t, 0, PyString_FromString("task")); + PyTuple_SetItem(t, 1, deserialize_task(worker_capsule, message->task())); + } else if (function_present) { + PyTuple_SetItem(t, 0, PyString_FromString("function")); + PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast(message->function().implementation().size()))); + } else if (reusable_variable_present) { + PyTuple_SetItem(t, 0, PyString_FromString("reusable_variable")); + PyObject* reusable_variable = PyTuple_New(3); + PyTuple_SetItem(reusable_variable, 0, PyString_FromStringAndSize(message->reusable_variable().name().data(), static_cast(message->reusable_variable().name().size()))); + PyTuple_SetItem(reusable_variable, 
1, PyString_FromStringAndSize(message->reusable_variable().initializer().implementation().data(), static_cast(message->reusable_variable().initializer().implementation().size()))); + PyTuple_SetItem(reusable_variable, 2, PyString_FromStringAndSize(message->reusable_variable().reinitializer().implementation().data(), static_cast(message->reusable_variable().reinitializer().implementation().size()))); + PyTuple_SetItem(t, 1, reusable_variable); + } else { + PyTuple_SetItem(t, 0, PyString_FromString("die")); + Py_INCREF(Py_None); + PyTuple_SetItem(t, 1, Py_None); } - // The tuple constructed below will take ownership of some None objects. - // When the tuple goes out of scope, the reference count for None will be - // decremented. Therefore, we need to increment the reference count for None - // every time we put a None in the tuple. - PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple. - PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_INCREF(Py_None), Py_None : deserialize_task(worker_capsule, &message->task)); - PyTuple_SetItem(t, 1, message->function.empty() ? Py_INCREF(Py_None), Py_None : PyString_FromStringAndSize(message->function.data(), static_cast(message->function.size()))); - PyTuple_SetItem(t, 2, message->reusable_variable.variable_name.empty() ? 
Py_INCREF(Py_None), Py_None : variable_info); return t; } + RAY_CHECK(false, "This code should be unreachable."); Py_RETURN_NONE; } diff --git a/src/scheduler.cc b/src/scheduler.cc index 5e64cbe71..74f527964 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -150,6 +150,20 @@ Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForN OperationId operationid = (*workers_.get())[workerid].current_task; RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task"); RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask."); + { + // Check if the worker has been initialized yet, and if not, then give it + // all of the exported functions and all of the exported reusable variables. + auto workers = workers_.get(); + if (!(*workers)[workerid].initialized) { + // This should only happen once. + // Import all remote functions on the worker. + export_all_functions_to_worker(workerid, workers, exported_functions_.get()); + // Import all reusable variables on the worker. + export_all_reusable_variables_to_worker(workerid, workers, exported_reusable_variables_.get()); + // Mark the worker as initialized. + (*workers)[workerid].initialized = true; + } + } if (request->has_previous_task_info()) { RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION."); std::string task_name; @@ -293,13 +307,12 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) { auto workers = workers_.get(); + auto exported_functions = exported_functions_.get(); + // TODO(rkn): Does this do a deep copy? 
+ exported_functions->push_back(std::unique_ptr(new Function(request->function()))); for (size_t i = 0; i < workers->size(); ++i) { - ClientContext import_context; - ImportFunctionRequest import_request; - import_request.mutable_function()->set_implementation(request->function().implementation()); if ((*workers)[i].current_task != ROOT_OPERATION) { - ImportFunctionReply import_reply; - (*workers)[i].worker_stub->ImportFunction(&import_context, import_request, &import_reply); + export_function_to_worker(i, exported_functions->size() - 1, workers, exported_functions); } } return Status::OK; @@ -307,15 +320,12 @@ Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunc Status SchedulerService::ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) { auto workers = workers_.get(); + auto exported_reusable_variables = exported_reusable_variables_.get(); + // TODO(rkn): Does this do a deep copy? + exported_reusable_variables->push_back(std::unique_ptr(new ReusableVar(request->reusable_variable()))); for (size_t i = 0; i < workers->size(); ++i) { - ClientContext import_context; - ImportReusableVariableRequest import_request; - import_request.set_name(request->name()); - import_request.mutable_initializer()->set_implementation(request->initializer().implementation()); - import_request.mutable_reinitializer()->set_implementation(request->reinitializer().implementation()); if ((*workers)[i].current_task != ROOT_OPERATION) { - AckReply import_reply; - (*workers)[i].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply); + export_reusable_variable_to_worker(i, exported_reusable_variables->size() - 1, workers, exported_reusable_variables); } } return Status::OK; @@ -451,6 +461,7 @@ std::pair SchedulerService::register_worker(const std::str (*workers)[workerid].objstoreid = objstoreid; (*workers)[workerid].worker_stub = WorkerService::NewStub(channel); 
(*workers)[workerid].worker_address = worker_address; + (*workers)[workerid].initialized = false; if (is_driver) { (*workers)[workerid].current_task = ROOT_OPERATION; // We use this field to identify which workers are drivers. } else { @@ -830,6 +841,37 @@ void SchedulerService::get_equivalent_objrefs(ObjRef objref, std::vector upstream_objrefs(downstream_objref, equivalent_objrefs, reverse_target_objrefs_.get()); } + +void SchedulerService::export_function_to_worker(WorkerId workerid, int function_index, SynchronizedPtr > &workers, const SynchronizedPtr > > &exported_functions) { + RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid); + ClientContext import_context; + ImportFunctionRequest import_request; + import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get()); + ImportFunctionReply import_reply; + (*workers)[workerid].worker_stub->ImportFunction(&import_context, import_request, &import_reply); +} + +void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, SynchronizedPtr > &workers, const SynchronizedPtr > > &exported_reusable_variables) { + RAY_LOG(RAY_INFO, "exporting reusable variable with index " << reusable_variable_index << " to worker " << workerid); + ClientContext import_context; + ImportReusableVariableRequest import_request; + import_request.mutable_reusable_variable()->CopyFrom(*(*exported_reusable_variables)[reusable_variable_index].get()); + AckReply import_reply; + (*workers)[workerid].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply); +} + +void SchedulerService::export_all_functions_to_worker(WorkerId workerid, SynchronizedPtr > &workers, const SynchronizedPtr > > &exported_functions) { + for (int i = 0; i < exported_functions->size(); ++i) { + export_function_to_worker(workerid, i, workers, exported_functions); + } +} + +void 
SchedulerService::export_all_reusable_variables_to_worker(WorkerId workerid, SynchronizedPtr > &workers, const SynchronizedPtr > > &exported_reusable_variables) { + for (int i = 0; i < exported_reusable_variables->size(); ++i) { + export_reusable_variable_to_worker(workerid, i, workers, exported_reusable_variables); + } +} + // This method defines the order in which locks should be acquired. void SchedulerService::do_on_locks(bool lock) { std::mutex *mutexes[] = { @@ -847,7 +889,9 @@ void SchedulerService::do_on_locks(bool lock) { &objtable_.mutex(), &objstores_.mutex(), &target_objrefs_.mutex(), - &reverse_target_objrefs_.mutex() + &reverse_target_objrefs_.mutex(), + &exported_functions_.mutex(), + &exported_reusable_variables_.mutex(), }; size_t n = sizeof(mutexes) / sizeof(*mutexes); for (size_t i = 0; i != n; ++i) { diff --git a/src/scheduler.h b/src/scheduler.h index 7f2a11d7e..08921ed69 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -37,6 +37,10 @@ struct WorkerHandle { std::unique_ptr worker_stub; // If null, the worker has died ObjStoreId objstoreid; std::string worker_address; + // This field is initialized to false, and it is set to true after all of the + // exported functions and exported reusable variables have been shipped to + // this worker. + bool initialized; OperationId current_task; }; @@ -129,6 +133,20 @@ private: void upstream_objrefs(ObjRef objref, std::vector &objrefs, const SynchronizedPtr > > &reverse_target_objrefs); // Find all of the object references that refer to the same object as objref (as best as we can determine at the moment). The information may be incomplete because not all of the aliases may be known. void get_equivalent_objrefs(ObjRef objref, std::vector &equivalent_objrefs); + // Export a remote function to a worker. 
+ void export_function_to_worker(WorkerId workerid, int function_index, SynchronizedPtr > &workers, const SynchronizedPtr > > &exported_functions); + // Export a reusable variable to a worker. + void export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, SynchronizedPtr > &workers, const SynchronizedPtr > > &exported_reusable_variables); + // Export all remote functions to a worker. This is used when a new worker + // registers and is protected by the workers lock (which is passed in) to + // ensure that no other remote functions are exported to the worker while this + // method is being called. + void export_all_functions_to_worker(WorkerId workerid, SynchronizedPtr > &workers, const SynchronizedPtr > > &exported_functions); + // Export all reusable variables to a worker. This is used when a new worker + // registers and is protected by the workers lock (which is passed in) to + // ensure that no other reusable variables are exported to the worker while + // this method is being called. + void export_all_reusable_variables_to_worker(WorkerId workerid, SynchronizedPtr > &workers, const SynchronizedPtr > > &exported_reusable_variables); // acquires all locks, this should only be used by get_info and for fault tolerance void acquire_all_locks(); // release all locks, this should only be used by get_info and for fault tolerance @@ -187,6 +205,10 @@ private: Synchronized > reference_counts_; // contained_objrefs_[objref] is a vector of all of the objrefs contained inside the object referred to by objref Synchronized > > contained_objrefs_; + // All of the remote functions that have been exported to the workers. + Synchronized > > exported_functions_; + // All of the reusable variables that have been exported to the workers.
+ Synchronized > > exported_reusable_variables_; // the scheduling algorithm that will be used SchedulingAlgorithmType scheduling_algorithm_; }; diff --git a/src/worker.cc b/src/worker.cc index 3fa5e58e7..29081303e 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -17,7 +17,7 @@ inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address) Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) { RAY_LOG(RAY_INFO, "invoked task " << request->task().name()); std::unique_ptr message(new WorkerMessage()); - message->task = request->task(); + message->mutable_task()->CopyFrom(request->task()); { WorkerMessage* message_ptr = message.get(); RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC"); @@ -28,7 +28,7 @@ Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskR Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) { std::unique_ptr message(new WorkerMessage()); - message->function = request->function().implementation(); + message->mutable_function()->CopyFrom(request->function()); RAY_LOG(RAY_INFO, "importing function"); { WorkerMessage* message_ptr = message.get(); @@ -40,9 +40,7 @@ Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFun Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) { std::unique_ptr message(new WorkerMessage()); - message->reusable_variable.variable_name = request->name(); - message->reusable_variable.initializer = request->initializer().implementation(); - message->reusable_variable.reinitializer = request->reinitializer().implementation(); + message->mutable_reusable_variable()->CopyFrom(request->reusable_variable()); RAY_LOG(RAY_INFO, "importing reusable variable"); { WorkerMessage* message_ptr = message.get(); @@ -360,9 +358,9 @@ void 
Worker::export_reusable_variable(const std::string& name, const std::string RAY_CHECK(connected_, "Attempted to export reusable variable but failed."); ClientContext context; ExportReusableVariableRequest request; - request.set_name(name); - request.mutable_initializer()->set_implementation(initializer); - request.mutable_reinitializer()->set_implementation(reinitializer); + request.mutable_reusable_variable()->set_name(name); + request.mutable_reusable_variable()->mutable_initializer()->set_implementation(initializer); + request.mutable_reusable_variable()->mutable_reinitializer()->set_implementation(reinitializer); AckReply reply; Status status = scheduler_stub_->ExportReusableVariable(&context, request, &reply); } diff --git a/src/worker.h b/src/worker.h index 02cdd31d0..945550d87 100644 --- a/src/worker.h +++ b/src/worker.h @@ -23,18 +23,6 @@ using grpc::Channel; using grpc::ClientContext; using grpc::ClientWriter; -struct ReusableVariable { - std::string variable_name; - std::string initializer; - std::string reinitializer; -}; - -struct WorkerMessage { - Task task; - std::string function; // Used for importing remote functions. - ReusableVariable reusable_variable; // Used for importing reusable variables. -}; - class WorkerServiceImpl final : public WorkerService::Service { public: WorkerServiceImpl(const std::string& worker_address);