diff --git a/protos/ray.proto b/protos/ray.proto index 4e1416977..455703821 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -82,6 +82,7 @@ message SubmitTaskRequest { message SubmitTaskReply { repeated uint64 result = 1; // Object references of the function return values + bool function_registered = 2; // True if the function was registered; false otherwise } message RequestObjRequest { diff --git a/src/raylib.cc b/src/raylib.cc index b57d3b8c6..e6155429c 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -648,6 +648,10 @@ PyObject* submit_task(PyObject* self, PyObject* args) { SubmitTaskRequest request; request.set_allocated_task(task); SubmitTaskReply reply = worker->submit_task(&request); + if (!reply.function_registered()) { + PyErr_SetString(RayError, "task: function not registered"); + return NULL; + } request.release_task(); // TODO: Make sure that task is not moved, otherwise capsule pointer needs to be updated int size = reply.result_size(); PyObject* list = PyList_New(size); diff --git a/src/scheduler.cc b/src/scheduler.cc index df5965221..09f889d34 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -10,40 +10,46 @@ SchedulerService::SchedulerService(SchedulingAlgorithmType scheduling_algorithm) Status SchedulerService::SubmitTask(ServerContext* context, const SubmitTaskRequest* request, SubmitTaskReply* reply) { std::unique_ptr task(new Task(request->task())); // need to copy, because request is const - fntable_lock_.lock(); - - // TODO(rkn): In the future, this should probably not be fatal. Instead, propagate the error back to the worker. - RAY_CHECK_NEQ(fntable_.find(task->name()), fntable_.end(), "The function " << task->name() << " has not been registered by any worker."); - - size_t num_return_vals = fntable_[task->name()].num_return_vals(); - fntable_lock_.unlock(); - - std::vector result_objrefs; - for (size_t i = 0; i < num_return_vals; ++i) { - ObjRef result = register_new_object(); - reply->add_result(result); - task->add_result(result); - result_objrefs.push_back(result); - } + size_t num_return_vals; { - std::lock_guard reference_counts_lock(reference_counts_lock_); // we grab this lock because increment_ref_count assumes it has been acquired - increment_ref_count(result_objrefs); // We increment once so the objrefs don't go out of scope before we reply to the worker that called SubmitTask. The corresponding decrement will happen in submit_task in raylib. - increment_ref_count(result_objrefs); // We increment once so the objrefs don't go out of scope before the task is scheduled on the worker. The corresponding decrement will happen in deserialize_task in raylib. + std::lock_guard fntable_lock(fntable_lock_); + FnTable::const_iterator fn = fntable_.find(task->name()); + if (fn == fntable_.end()) { + num_return_vals = 0; + reply->set_function_registered(false); + } else { + num_return_vals = fn->second.num_return_vals(); + reply->set_function_registered(true); + } } + if (reply->function_registered()) { + std::vector result_objrefs; + for (size_t i = 0; i < num_return_vals; ++i) { + ObjRef result = register_new_object(); + reply->add_result(result); + task->add_result(result); + result_objrefs.push_back(result); + } + { + std::lock_guard reference_counts_lock(reference_counts_lock_); // we grab this lock because increment_ref_count assumes it has been acquired + increment_ref_count(result_objrefs); // We increment once so the objrefs don't go out of scope before we reply to the worker that called SubmitTask. The corresponding decrement will happen in submit_task in raylib. + increment_ref_count(result_objrefs); // We increment once so the objrefs don't go out of scope before the task is scheduled on the worker. The corresponding decrement will happen in deserialize_task in raylib. + } - auto operation = std::unique_ptr(new Operation()); - operation->set_allocated_task(task.release()); - OperationId creator_operationid = ROOT_OPERATION; // TODO(rkn): Later, this should be the ID of the task that spawned this current task. - operation->set_creator_operationid(creator_operationid); - computation_graph_lock_.lock(); - OperationId operationid = computation_graph_.add_operation(std::move(operation)); - computation_graph_lock_.unlock(); + auto operation = std::unique_ptr(new Operation()); + operation->set_allocated_task(task.release()); + OperationId creator_operationid = ROOT_OPERATION; // TODO(rkn): Later, this should be the ID of the task that spawned this current task. + operation->set_creator_operationid(creator_operationid); + computation_graph_lock_.lock(); + OperationId operationid = computation_graph_.add_operation(std::move(operation)); + computation_graph_lock_.unlock(); - task_queue_lock_.lock(); - task_queue_.push_back(operationid); - task_queue_lock_.unlock(); + task_queue_lock_.lock(); + task_queue_.push_back(operationid); + task_queue_lock_.unlock(); - schedule(); + schedule(); + } return Status::OK; } diff --git a/src/worker.cc b/src/worker.cc index c3770167c..7654e647a 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -1,5 +1,8 @@ #include "worker.h" +#include +#include + #include "utils.h" #include @@ -24,11 +27,19 @@ Worker::Worker(const std::string& worker_address, std::shared_ptr sched connected_ = true; } -SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request) { +SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request, int max_retries, int retry_wait_milliseconds) { RAY_CHECK(connected_, "Attempted to perform submit_task but failed."); SubmitTaskReply reply; - ClientContext context; - Status status = scheduler_stub_->SubmitTask(&context, *request, &reply); + Status status; + for (int i = 0; i < 1 + max_retries; ++i) { + ClientContext context; + status = scheduler_stub_->SubmitTask(&context, *request, &reply); + if (reply.function_registered()) { + break; + } + RAY_LOG(RAY_INFO, "The function " << request->task().name() << " was not registered, so attempting to resubmit the task."); + std::this_thread::sleep_for(std::chrono::milliseconds(retry_wait_milliseconds)); + } return reply; } diff --git a/src/worker.h b/src/worker.h index 746ad290d..a6658a187 100644 --- a/src/worker.h +++ b/src/worker.h @@ -40,8 +40,10 @@ class Worker { public: Worker(const std::string& worker_address, std::shared_ptr scheduler_channel, std::shared_ptr objstore_channel); - // submit a remote task to the scheduler - SubmitTaskReply submit_task(SubmitTaskRequest* request); + // Submit a remote task to the scheduler. If the function in the task is not + // registered with the scheduler, we will sleep for retry_wait_milliseconds + // and try to resubmit the task to the scheduler up to max_retries more times. + SubmitTaskReply submit_task(SubmitTaskRequest* request, int max_retries = 120, int retry_wait_milliseconds = 500); // send request to the scheduler to register this worker void register_worker(const std::string& worker_address, const std::string& objstore_address); // get a new object reference that is registered with the scheduler