mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Merge pull request #108 from amplab/unregistered_driver_fix
Unregistered function handling
This commit is contained in:
commit
67ce2d9837
5 changed files with 58 additions and 34 deletions
|
@ -82,6 +82,7 @@ message SubmitTaskRequest {
|
|||
|
||||
message SubmitTaskReply {
|
||||
repeated uint64 result = 1; // Object references of the function return values
|
||||
bool function_registered = 2; // True if the function was registered; false otherwise
|
||||
}
|
||||
|
||||
message RequestObjRequest {
|
||||
|
|
|
@ -648,6 +648,10 @@ PyObject* submit_task(PyObject* self, PyObject* args) {
|
|||
SubmitTaskRequest request;
|
||||
request.set_allocated_task(task);
|
||||
SubmitTaskReply reply = worker->submit_task(&request);
|
||||
if (!reply.function_registered()) {
|
||||
PyErr_SetString(RayError, "task: function not registered");
|
||||
return NULL;
|
||||
}
|
||||
request.release_task(); // TODO: Make sure that task is not moved, otherwise capsule pointer needs to be updated
|
||||
int size = reply.result_size();
|
||||
PyObject* list = PyList_New(size);
|
||||
|
|
|
@ -10,40 +10,46 @@ SchedulerService::SchedulerService(SchedulingAlgorithmType scheduling_algorithm)
|
|||
|
||||
Status SchedulerService::SubmitTask(ServerContext* context, const SubmitTaskRequest* request, SubmitTaskReply* reply) {
|
||||
std::unique_ptr<Task> task(new Task(request->task())); // need to copy, because request is const
|
||||
fntable_lock_.lock();
|
||||
|
||||
// TODO(rkn): In the future, this should probably not be fatal. Instead, propagate the error back to the worker.
|
||||
RAY_CHECK_NEQ(fntable_.find(task->name()), fntable_.end(), "The function " << task->name() << " has not been registered by any worker.");
|
||||
|
||||
size_t num_return_vals = fntable_[task->name()].num_return_vals();
|
||||
fntable_lock_.unlock();
|
||||
|
||||
std::vector<ObjRef> result_objrefs;
|
||||
for (size_t i = 0; i < num_return_vals; ++i) {
|
||||
ObjRef result = register_new_object();
|
||||
reply->add_result(result);
|
||||
task->add_result(result);
|
||||
result_objrefs.push_back(result);
|
||||
}
|
||||
size_t num_return_vals;
|
||||
{
|
||||
std::lock_guard<std::mutex> reference_counts_lock(reference_counts_lock_); // we grab this lock because increment_ref_count assumes it has been acquired
|
||||
increment_ref_count(result_objrefs); // We increment once so the objrefs don't go out of scope before we reply to the worker that called SubmitTask. The corresponding decrement will happen in submit_task in raylib.
|
||||
increment_ref_count(result_objrefs); // We increment once so the objrefs don't go out of scope before the task is scheduled on the worker. The corresponding decrement will happen in deserialize_task in raylib.
|
||||
std::lock_guard<std::mutex> fntable_lock(fntable_lock_);
|
||||
FnTable::const_iterator fn = fntable_.find(task->name());
|
||||
if (fn == fntable_.end()) {
|
||||
num_return_vals = 0;
|
||||
reply->set_function_registered(false);
|
||||
} else {
|
||||
num_return_vals = fn->second.num_return_vals();
|
||||
reply->set_function_registered(true);
|
||||
}
|
||||
}
|
||||
if (reply->function_registered()) {
|
||||
std::vector<ObjRef> result_objrefs;
|
||||
for (size_t i = 0; i < num_return_vals; ++i) {
|
||||
ObjRef result = register_new_object();
|
||||
reply->add_result(result);
|
||||
task->add_result(result);
|
||||
result_objrefs.push_back(result);
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> reference_counts_lock(reference_counts_lock_); // we grab this lock because increment_ref_count assumes it has been acquired
|
||||
increment_ref_count(result_objrefs); // We increment once so the objrefs don't go out of scope before we reply to the worker that called SubmitTask. The corresponding decrement will happen in submit_task in raylib.
|
||||
increment_ref_count(result_objrefs); // We increment once so the objrefs don't go out of scope before the task is scheduled on the worker. The corresponding decrement will happen in deserialize_task in raylib.
|
||||
}
|
||||
|
||||
auto operation = std::unique_ptr<Operation>(new Operation());
|
||||
operation->set_allocated_task(task.release());
|
||||
OperationId creator_operationid = ROOT_OPERATION; // TODO(rkn): Later, this should be the ID of the task that spawned this current task.
|
||||
operation->set_creator_operationid(creator_operationid);
|
||||
computation_graph_lock_.lock();
|
||||
OperationId operationid = computation_graph_.add_operation(std::move(operation));
|
||||
computation_graph_lock_.unlock();
|
||||
auto operation = std::unique_ptr<Operation>(new Operation());
|
||||
operation->set_allocated_task(task.release());
|
||||
OperationId creator_operationid = ROOT_OPERATION; // TODO(rkn): Later, this should be the ID of the task that spawned this current task.
|
||||
operation->set_creator_operationid(creator_operationid);
|
||||
computation_graph_lock_.lock();
|
||||
OperationId operationid = computation_graph_.add_operation(std::move(operation));
|
||||
computation_graph_lock_.unlock();
|
||||
|
||||
task_queue_lock_.lock();
|
||||
task_queue_.push_back(operationid);
|
||||
task_queue_lock_.unlock();
|
||||
task_queue_lock_.lock();
|
||||
task_queue_.push_back(operationid);
|
||||
task_queue_lock_.unlock();
|
||||
|
||||
schedule();
|
||||
schedule();
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
#include "worker.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#include <pynumbuf/serialize.h>
|
||||
|
@ -24,11 +27,19 @@ Worker::Worker(const std::string& worker_address, std::shared_ptr<Channel> sched
|
|||
connected_ = true;
|
||||
}
|
||||
|
||||
SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request) {
|
||||
SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request, int max_retries, int retry_wait_milliseconds) {
|
||||
RAY_CHECK(connected_, "Attempted to perform submit_task but failed.");
|
||||
SubmitTaskReply reply;
|
||||
ClientContext context;
|
||||
Status status = scheduler_stub_->SubmitTask(&context, *request, &reply);
|
||||
Status status;
|
||||
for (int i = 0; i < 1 + max_retries; ++i) {
|
||||
ClientContext context;
|
||||
status = scheduler_stub_->SubmitTask(&context, *request, &reply);
|
||||
if (reply.function_registered()) {
|
||||
break;
|
||||
}
|
||||
RAY_LOG(RAY_INFO, "The function " << request->task().name() << " was not registered, so attempting to resubmit the task.");
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(retry_wait_milliseconds));
|
||||
}
|
||||
return reply;
|
||||
}
|
||||
|
||||
|
|
|
@ -40,8 +40,10 @@ class Worker {
|
|||
public:
|
||||
Worker(const std::string& worker_address, std::shared_ptr<Channel> scheduler_channel, std::shared_ptr<Channel> objstore_channel);
|
||||
|
||||
// submit a remote task to the scheduler
|
||||
SubmitTaskReply submit_task(SubmitTaskRequest* request);
|
||||
// Submit a remote task to the scheduler. If the function in the task is not
|
||||
// registered with the scheduler, we will sleep for retry_wait_milliseconds
|
||||
// and try to resubmit the task to the scheduler up to max_retries more times.
|
||||
SubmitTaskReply submit_task(SubmitTaskRequest* request, int max_retries = 120, int retry_wait_milliseconds = 500);
|
||||
// send request to the scheduler to register this worker
|
||||
void register_worker(const std::string& worker_address, const std::string& objstore_address);
|
||||
// get a new object reference that is registered with the scheduler
|
||||
|
|
Loading…
Add table
Reference in a new issue