Properly import remote functions and reusable variables on workers that register late (#290)

This commit is contained in:
Robert Nishihara 2016-07-25 16:17:17 -07:00 committed by Philipp Moritz
parent 5591aa4665
commit 8e9f98c5ff
8 changed files with 145 additions and 76 deletions

View file

@ -686,31 +686,27 @@ def main_loop(worker=global_worker):
# above so that changes made to their state do not affect other tasks. # above so that changes made to their state do not affect other tasks.
ray.reusables._reinitialize() ray.reusables._reinitialize()
while True: while True:
(task, function, reusable_variable) = ray.lib.wait_for_next_message(worker.handle) command, command_args = ray.lib.wait_for_next_message(worker.handle)
try: try:
# Only one of task, function, and reusable_variable should be not None. if command == "die":
assert sum([obj is not None for obj in [task, function, reusable_variable]]) <= 1 # We use this as a mechanism to allow the scheduler to kill workers.
if task is None and function is None and reusable_variable is None:
# We use this as a mechanism to allow the scheduler to kill workers. When
# the scheduler wants to kill a worker, it gives the worker a null task,
# causing the worker program to exit the main loop here.
break break
if function is not None: elif command == "function":
(function, arg_types, return_types) = pickling.loads(function) (function, arg_types, return_types) = pickling.loads(command_args)
if function.__module__ is None: function.__module__ = "__main__" if function.__module__ is None: function.__module__ = "__main__"
worker.register_function(remote(arg_types, return_types, worker)(function)) worker.register_function(remote(arg_types, return_types, worker)(function))
if reusable_variable is not None: elif command == "reusable_variable":
name, initializer_str, reinitializer_str = reusable_variable name, initializer_str, reinitializer_str = command_args
initializer = pickling.loads(initializer_str) initializer = pickling.loads(initializer_str)
reinitializer = pickling.loads(reinitializer_str) reinitializer = pickling.loads(reinitializer_str)
reusables.__setattr__(name, Reusable(initializer, reinitializer)) reusables.__setattr__(name, Reusable(initializer, reinitializer))
if task is not None: elif command == "task":
process_task(task) process_task(command_args)
else:
assert False, "This code should be unreachable."
finally: finally:
# Allow releasing the variables BEFORE we wait for the next message or exit the block # Allow releasing the variables BEFORE we wait for the next message or exit the block
del task del command_args
del function
del reusable_variable
def _submit_task(func_name, args, worker=global_worker): def _submit_task(func_name, args, worker=global_worker):
"""This is a wrapper around worker.submit_task. """This is a wrapper around worker.submit_task.

View file

@ -240,9 +240,7 @@ message ExportFunctionReply {
} }
message ExportReusableVariableRequest { message ExportReusableVariableRequest {
string name = 1; // The name of the reusable variable. ReusableVar reusable_variable = 1; // The reusable variable to export.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
} }
// These messages are for getting information about the object store state // These messages are for getting information about the object store state
@ -280,9 +278,7 @@ message ImportFunctionReply {
} }
message ImportReusableVariableRequest { message ImportReusableVariableRequest {
string name = 1; // The name of the reusable variable. ReusableVar reusable_variable = 1; // The reusable variable to export.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
} }
message DieRequest { message DieRequest {
@ -290,3 +286,13 @@ message DieRequest {
message DieReply { message DieReply {
} }
// This message is used by the worker service to send messages to the worker
// that are processed by the worker's main loop.
message WorkerMessage {
oneof worker_item {
Task task = 1; // A task for the worker to execute.
Function function = 2; // A remote function to import on the worker.
ReusableVar reusable_variable = 3; // A reusable variable to import on the worker.
}
}

View file

@ -37,6 +37,12 @@ message Function {
bytes implementation = 1; bytes implementation = 1;
} }
message ReusableVar {
string name = 1; // The name of the reusable variable.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
}
// Union of possible object types // Union of possible object types
message Obj { message Obj {
String string_data = 1; String string_data = 1;

View file

@ -607,13 +607,13 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) {
return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor); return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor);
} }
static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) { static PyObject* deserialize_task(PyObject* worker_capsule, const Task& task) {
std::vector<ObjRef> objrefs; // This is a vector of all the objrefs that were serialized in this task, including objrefs that are contained in Python objects that are passed by value. std::vector<ObjRef> objrefs; // This is a vector of all the objrefs that were serialized in this task, including objrefs that are contained in Python objects that are passed by value.
PyObject* string = PyString_FromStringAndSize(task->name().c_str(), task->name().size()); PyObject* string = PyString_FromStringAndSize(task.name().c_str(), task.name().size());
int argsize = task->arg_size(); int argsize = task.arg_size();
PyObject* arglist = PyList_New(argsize); PyObject* arglist = PyList_New(argsize);
for (int i = 0; i < argsize; ++i) { for (int i = 0; i < argsize; ++i) {
const Value& val = task->arg(i); const Value& val = task.arg(i);
if (!val.has_obj()) { if (!val.has_obj()) {
PyList_SetItem(arglist, i, make_pyobjref(worker_capsule, val.ref())); PyList_SetItem(arglist, i, make_pyobjref(worker_capsule, val.ref()));
objrefs.push_back(val.ref()); objrefs.push_back(val.ref());
@ -624,12 +624,12 @@ static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) {
Worker* worker; Worker* worker;
PyObjectToWorker(worker_capsule, &worker); PyObjectToWorker(worker_capsule, &worker);
worker->decrement_reference_count(objrefs); worker->decrement_reference_count(objrefs);
int resultsize = task->result_size(); int resultsize = task.result_size();
std::vector<ObjRef> result_objrefs; std::vector<ObjRef> result_objrefs;
PyObject* resultlist = PyList_New(resultsize); PyObject* resultlist = PyList_New(resultsize);
for (int i = 0; i < resultsize; ++i) { for (int i = 0; i < resultsize; ++i) {
PyList_SetItem(resultlist, i, make_pyobjref(worker_capsule, task->result(i))); PyList_SetItem(resultlist, i, make_pyobjref(worker_capsule, task.result(i)));
result_objrefs.push_back(task->result(i)); result_objrefs.push_back(task.result(i));
} }
worker->decrement_reference_count(result_objrefs); // The corresponding increment is done in SubmitTask in the scheduler. worker->decrement_reference_count(result_objrefs); // The corresponding increment is done in SubmitTask in the scheduler.
PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple. PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
@ -685,23 +685,32 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
Worker* worker; Worker* worker;
PyObjectToWorker(worker_capsule, &worker); PyObjectToWorker(worker_capsule, &worker);
if (std::unique_ptr<WorkerMessage> message = worker->receive_next_message()) { if (std::unique_ptr<WorkerMessage> message = worker->receive_next_message()) {
PyObject* variable_info; bool task_present = !message->task().name().empty();
if (!message->reusable_variable.variable_name.empty()) { bool function_present = !message->function().implementation().empty();
variable_info = PyTuple_New(3); bool reusable_variable_present = !message->reusable_variable().name().empty();
PyTuple_SetItem(variable_info, 0, PyString_FromStringAndSize(message->reusable_variable.variable_name.data(), static_cast<ssize_t>(message->reusable_variable.variable_name.size()))); RAY_CHECK(task_present + function_present + reusable_variable_present <= 1, "The worker message should contain at most one item.");
PyTuple_SetItem(variable_info, 1, PyString_FromStringAndSize(message->reusable_variable.initializer.data(), static_cast<ssize_t>(message->reusable_variable.initializer.size()))); PyObject* t = PyTuple_New(2);
PyTuple_SetItem(variable_info, 2, PyString_FromStringAndSize(message->reusable_variable.reinitializer.data(), static_cast<ssize_t>(message->reusable_variable.reinitializer.size()))); if (task_present) {
PyTuple_SetItem(t, 0, PyString_FromString("task"));
PyTuple_SetItem(t, 1, deserialize_task(worker_capsule, message->task()));
} else if (function_present) {
PyTuple_SetItem(t, 0, PyString_FromString("function"));
PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast<ssize_t>(message->function().implementation().size())));
} else if (reusable_variable_present) {
PyTuple_SetItem(t, 0, PyString_FromString("reusable_variable"));
PyObject* reusable_variable = PyTuple_New(3);
PyTuple_SetItem(reusable_variable, 0, PyString_FromStringAndSize(message->reusable_variable().name().data(), static_cast<ssize_t>(message->reusable_variable().name().size())));
PyTuple_SetItem(reusable_variable, 1, PyString_FromStringAndSize(message->reusable_variable().initializer().implementation().data(), static_cast<ssize_t>(message->reusable_variable().initializer().implementation().size())));
PyTuple_SetItem(reusable_variable, 2, PyString_FromStringAndSize(message->reusable_variable().reinitializer().implementation().data(), static_cast<ssize_t>(message->reusable_variable().reinitializer().implementation().size())));
PyTuple_SetItem(t, 1, reusable_variable);
} else {
PyTuple_SetItem(t, 0, PyString_FromString("die"));
Py_INCREF(Py_None);
PyTuple_SetItem(t, 1, Py_None);
} }
// The tuple constructed below will take ownership of some None objects.
// When the tuple goes out of scope, the reference count for None will be
// decremented. Therefore, we need to increment the reference count for None
// every time we put a None in the tuple.
PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_INCREF(Py_None), Py_None : deserialize_task(worker_capsule, &message->task));
PyTuple_SetItem(t, 1, message->function.empty() ? Py_INCREF(Py_None), Py_None : PyString_FromStringAndSize(message->function.data(), static_cast<ssize_t>(message->function.size())));
PyTuple_SetItem(t, 2, message->reusable_variable.variable_name.empty() ? Py_INCREF(Py_None), Py_None : variable_info);
return t; return t;
} }
RAY_CHECK(false, "This code should be unreachable.");
Py_RETURN_NONE; Py_RETURN_NONE;
} }

View file

@ -150,6 +150,20 @@ Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForN
OperationId operationid = (*workers_.get())[workerid].current_task; OperationId operationid = (*workers_.get())[workerid].current_task;
RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task"); RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task");
RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask."); RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask.");
{
// Check if the worker has been initialized yet, and if not, then give it
// all of the exported functions and all of the exported reusable variables.
auto workers = workers_.get();
if (!(*workers)[workerid].initialized) {
// This should only happen once.
// Import all remote functions on the worker.
export_all_functions_to_worker(workerid, workers, exported_functions_.get());
// Import all reusable variables on the worker.
export_all_reusable_variables_to_worker(workerid, workers, exported_reusable_variables_.get());
// Mark the worker as initialized.
(*workers)[workerid].initialized = true;
}
}
if (request->has_previous_task_info()) { if (request->has_previous_task_info()) {
RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION."); RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION.");
std::string task_name; std::string task_name;
@ -293,13 +307,12 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) { Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) {
auto workers = workers_.get(); auto workers = workers_.get();
auto exported_functions = exported_functions_.get();
// TODO(rkn): Does this do a deep copy?
exported_functions->push_back(std::unique_ptr<Function>(new Function(request->function())));
for (size_t i = 0; i < workers->size(); ++i) { for (size_t i = 0; i < workers->size(); ++i) {
ClientContext import_context;
ImportFunctionRequest import_request;
import_request.mutable_function()->set_implementation(request->function().implementation());
if ((*workers)[i].current_task != ROOT_OPERATION) { if ((*workers)[i].current_task != ROOT_OPERATION) {
ImportFunctionReply import_reply; export_function_to_worker(i, exported_functions->size() - 1, workers, exported_functions);
(*workers)[i].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
} }
} }
return Status::OK; return Status::OK;
@ -307,15 +320,12 @@ Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunc
Status SchedulerService::ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) { Status SchedulerService::ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) {
auto workers = workers_.get(); auto workers = workers_.get();
auto exported_reusable_variables = exported_reusable_variables_.get();
// TODO(rkn): Does this do a deep copy?
exported_reusable_variables->push_back(std::unique_ptr<ReusableVar>(new ReusableVar(request->reusable_variable())));
for (size_t i = 0; i < workers->size(); ++i) { for (size_t i = 0; i < workers->size(); ++i) {
ClientContext import_context;
ImportReusableVariableRequest import_request;
import_request.set_name(request->name());
import_request.mutable_initializer()->set_implementation(request->initializer().implementation());
import_request.mutable_reinitializer()->set_implementation(request->reinitializer().implementation());
if ((*workers)[i].current_task != ROOT_OPERATION) { if ((*workers)[i].current_task != ROOT_OPERATION) {
AckReply import_reply; export_reusable_variable_to_worker(i, exported_reusable_variables->size() - 1, workers, exported_reusable_variables);
(*workers)[i].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply);
} }
} }
return Status::OK; return Status::OK;
@ -451,6 +461,7 @@ std::pair<WorkerId, ObjStoreId> SchedulerService::register_worker(const std::str
(*workers)[workerid].objstoreid = objstoreid; (*workers)[workerid].objstoreid = objstoreid;
(*workers)[workerid].worker_stub = WorkerService::NewStub(channel); (*workers)[workerid].worker_stub = WorkerService::NewStub(channel);
(*workers)[workerid].worker_address = worker_address; (*workers)[workerid].worker_address = worker_address;
(*workers)[workerid].initialized = false;
if (is_driver) { if (is_driver) {
(*workers)[workerid].current_task = ROOT_OPERATION; // We use this field to identify which workers are drivers. (*workers)[workerid].current_task = ROOT_OPERATION; // We use this field to identify which workers are drivers.
} else { } else {
@ -830,6 +841,37 @@ void SchedulerService::get_equivalent_objrefs(ObjRef objref, std::vector<ObjRef>
upstream_objrefs(downstream_objref, equivalent_objrefs, reverse_target_objrefs_.get()); upstream_objrefs(downstream_objref, equivalent_objrefs, reverse_target_objrefs_.get());
} }
void SchedulerService::export_function_to_worker(WorkerId workerid, int function_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid);
ClientContext import_context;
ImportFunctionRequest import_request;
import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get());
ImportFunctionReply import_reply;
(*workers)[workerid].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
}
void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
RAY_LOG(RAY_INFO, "exporting reusable variable with index " << reusable_variable_index << " to worker " << workerid);
ClientContext import_context;
ImportReusableVariableRequest import_request;
import_request.mutable_reusable_variable()->CopyFrom(*(*exported_reusable_variables)[reusable_variable_index].get());
AckReply import_reply;
(*workers)[workerid].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply);
}
void SchedulerService::export_all_functions_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
for (int i = 0; i < exported_functions->size(); ++i) {
export_function_to_worker(workerid, i, workers, exported_functions);
}
}
void SchedulerService::export_all_reusable_variables_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
for (int i = 0; i < exported_reusable_variables->size(); ++i) {
export_reusable_variable_to_worker(workerid, i, workers, exported_reusable_variables);
}
}
// This method defines the order in which locks should be acquired. // This method defines the order in which locks should be acquired.
void SchedulerService::do_on_locks(bool lock) { void SchedulerService::do_on_locks(bool lock) {
std::mutex *mutexes[] = { std::mutex *mutexes[] = {
@ -847,7 +889,9 @@ void SchedulerService::do_on_locks(bool lock) {
&objtable_.mutex(), &objtable_.mutex(),
&objstores_.mutex(), &objstores_.mutex(),
&target_objrefs_.mutex(), &target_objrefs_.mutex(),
&reverse_target_objrefs_.mutex() &reverse_target_objrefs_.mutex(),
&exported_functions_.mutex(),
&exported_reusable_variables_.mutex(),
}; };
size_t n = sizeof(mutexes) / sizeof(*mutexes); size_t n = sizeof(mutexes) / sizeof(*mutexes);
for (size_t i = 0; i != n; ++i) { for (size_t i = 0; i != n; ++i) {

View file

@ -37,6 +37,10 @@ struct WorkerHandle {
std::unique_ptr<WorkerService::Stub> worker_stub; // If null, the worker has died std::unique_ptr<WorkerService::Stub> worker_stub; // If null, the worker has died
ObjStoreId objstoreid; ObjStoreId objstoreid;
std::string worker_address; std::string worker_address;
// This field is initialized to false, and it is set to true after all of the
// exported functions and exported reusable variables have been shipped to
// this worker.
bool initialized;
OperationId current_task; OperationId current_task;
}; };
@ -129,6 +133,20 @@ private:
void upstream_objrefs(ObjRef objref, std::vector<ObjRef> &objrefs, const SynchronizedPtr<std::vector<std::vector<ObjRef> > > &reverse_target_objrefs); void upstream_objrefs(ObjRef objref, std::vector<ObjRef> &objrefs, const SynchronizedPtr<std::vector<std::vector<ObjRef> > > &reverse_target_objrefs);
// Find all of the object references that refer to the same object as objref (as best as we can determine at the moment). The information may be incomplete because not all of the aliases may be known. // Find all of the object references that refer to the same object as objref (as best as we can determine at the moment). The information may be incomplete because not all of the aliases may be known.
void get_equivalent_objrefs(ObjRef objref, std::vector<ObjRef> &equivalent_objrefs); void get_equivalent_objrefs(ObjRef objref, std::vector<ObjRef> &equivalent_objrefs);
// Export a remote function to a worker.
void export_function_to_worker(WorkerId workerid, int function_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions);
// Export a reusable variable to a worker
void export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables);
// Export all reusable variables to a worker. This is used when a new worker
// registers and is protected by the workers lock (which is passed in) to
// ensure that no other reusable variables are exported to the worker while
// this method is being called.
void export_all_functions_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions);
// Export all remote functions to a worker. This is used when a new worker
// registers and is protected by the workers lock (which is passed in) to
// ensure that no other remote functions are exported to the worker while this
// method is being called.
void export_all_reusable_variables_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables);
// acquires all locks, this should only be used by get_info and for fault tolerance // acquires all locks, this should only be used by get_info and for fault tolerance
void acquire_all_locks(); void acquire_all_locks();
// release all locks, this should only be used by get_info and for fault tolerance // release all locks, this should only be used by get_info and for fault tolerance
@ -187,6 +205,10 @@ private:
Synchronized<std::vector<RefCount> > reference_counts_; Synchronized<std::vector<RefCount> > reference_counts_;
// contained_objrefs_[objref] is a vector of all of the objrefs contained inside the object referred to by objref // contained_objrefs_[objref] is a vector of all of the objrefs contained inside the object referred to by objref
Synchronized<std::vector<std::vector<ObjRef> > > contained_objrefs_; Synchronized<std::vector<std::vector<ObjRef> > > contained_objrefs_;
// All of the remote functions that have been exported to the workers.
Synchronized<std::vector<std::unique_ptr<Function> > > exported_functions_;
// All of the reusable variables that have been exported to the workers.
Synchronized<std::vector<std::unique_ptr<ReusableVar> > > exported_reusable_variables_;
// the scheduling algorithm that will be used // the scheduling algorithm that will be used
SchedulingAlgorithmType scheduling_algorithm_; SchedulingAlgorithmType scheduling_algorithm_;
}; };

View file

@ -17,7 +17,7 @@ inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address)
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) { Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) {
RAY_LOG(RAY_INFO, "invoked task " << request->task().name()); RAY_LOG(RAY_INFO, "invoked task " << request->task().name());
std::unique_ptr<WorkerMessage> message(new WorkerMessage()); std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->task = request->task(); message->mutable_task()->CopyFrom(request->task());
{ {
WorkerMessage* message_ptr = message.get(); WorkerMessage* message_ptr = message.get();
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC"); RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
@ -28,7 +28,7 @@ Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskR
Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) { Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) {
std::unique_ptr<WorkerMessage> message(new WorkerMessage()); std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->function = request->function().implementation(); message->mutable_function()->CopyFrom(request->function());
RAY_LOG(RAY_INFO, "importing function"); RAY_LOG(RAY_INFO, "importing function");
{ {
WorkerMessage* message_ptr = message.get(); WorkerMessage* message_ptr = message.get();
@ -40,9 +40,7 @@ Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFun
Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) { Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) {
std::unique_ptr<WorkerMessage> message(new WorkerMessage()); std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->reusable_variable.variable_name = request->name(); message->mutable_reusable_variable()->CopyFrom(request->reusable_variable());
message->reusable_variable.initializer = request->initializer().implementation();
message->reusable_variable.reinitializer = request->reinitializer().implementation();
RAY_LOG(RAY_INFO, "importing reusable variable"); RAY_LOG(RAY_INFO, "importing reusable variable");
{ {
WorkerMessage* message_ptr = message.get(); WorkerMessage* message_ptr = message.get();
@ -360,9 +358,9 @@ void Worker::export_reusable_variable(const std::string& name, const std::string
RAY_CHECK(connected_, "Attempted to export reusable variable but failed."); RAY_CHECK(connected_, "Attempted to export reusable variable but failed.");
ClientContext context; ClientContext context;
ExportReusableVariableRequest request; ExportReusableVariableRequest request;
request.set_name(name); request.mutable_reusable_variable()->set_name(name);
request.mutable_initializer()->set_implementation(initializer); request.mutable_reusable_variable()->mutable_initializer()->set_implementation(initializer);
request.mutable_reinitializer()->set_implementation(reinitializer); request.mutable_reusable_variable()->mutable_reinitializer()->set_implementation(reinitializer);
AckReply reply; AckReply reply;
Status status = scheduler_stub_->ExportReusableVariable(&context, request, &reply); Status status = scheduler_stub_->ExportReusableVariable(&context, request, &reply);
} }

View file

@ -23,18 +23,6 @@ using grpc::Channel;
using grpc::ClientContext; using grpc::ClientContext;
using grpc::ClientWriter; using grpc::ClientWriter;
struct ReusableVariable {
std::string variable_name;
std::string initializer;
std::string reinitializer;
};
struct WorkerMessage {
Task task;
std::string function; // Used for importing remote functions.
ReusableVariable reusable_variable; // Used for importing reusable variables.
};
class WorkerServiceImpl final : public WorkerService::Service { class WorkerServiceImpl final : public WorkerService::Service {
public: public:
WorkerServiceImpl(const std::string& worker_address); WorkerServiceImpl(const std::string& worker_address);