Function serialization (#261)

mehrdadn 2016-07-17 22:05:07 -07:00 committed by Robert Nishihara
parent 8e0ecfa1f4
commit f2c43bec87
15 changed files with 180 additions and 43 deletions


@ -19,7 +19,7 @@ brew update
brew install git cmake automake autoconf libtool boost graphviz
sudo easy_install pip
sudo pip install ipython --user
-sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz --ignore-installed six
+sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle --ignore-installed six
```
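The only functional change in all three install recipes is the new cloudpickle dependency, which the function-export machinery below relies on. As a quick post-install sanity check, the following minimal sketch (illustrative, not part of the commit) round-trips a lambda, which is exactly the capability the standard pickle module lacks:

```
# Sanity check: cloudpickle can serialize a lambda, which plain pickle cannot.
import cloudpickle
f = cloudpickle.loads(cloudpickle.dumps(lambda x: x + 1))
assert f(1) == 2
```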
### Build


@ -15,7 +15,7 @@ First install the dependencies. We currently do not support Python 3.
```
sudo apt-get update
sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
-sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz
+sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle
```
### Build


@ -31,11 +31,11 @@ if [[ $platform == "linux" ]]; then
# These commands must be kept in sync with the installation instructions.
sudo apt-get update
sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
-  sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz
+  sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle
elif [[ $platform == "macosx" ]]; then
# These commands must be kept in sync with the installation instructions.
brew install git cmake automake autoconf libtool boost graphviz
sudo easy_install pip
sudo pip install ipython --user
-  sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0-alpha-2 colorama graphviz --ignore-installed six
+  sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0-alpha-2 colorama graphviz cloudpickle --ignore-installed six
fi


@ -0,0 +1,7 @@
+import cloudpickle
+
+def dumps(func, arg_types, return_types):
+  return cloudpickle.dumps((func, arg_types, return_types))
+
+def loads(function):
+  return cloudpickle.loads(function)
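The new pickling module is a thin wrapper around cloudpickle that ships a function together with its type signature as one payload. A minimal usage sketch (illustrative; it assumes pickling.py is importable, e.g. from within the ray Python library directory):

```
import pickling  # the new module shown above

def add_one(x):
  return x + 1

blob = pickling.dumps(add_one, [int], [int])   # bytes, safe to send over RPC
func, arg_types, return_types = pickling.loads(blob)
assert func(1) == 2 and arg_types == [int] and return_types == [int]
```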


@ -67,6 +67,6 @@ def serialize_task(worker_capsule, func_name, args):
  return ray.lib.serialize_task(worker_capsule, func_name, primitive_args)

def deserialize_task(worker_capsule, task):
-  func_name, primitive_args, return_objrefs = ray.lib.deserialize_task(worker_capsule, task)
+  func_name, primitive_args, return_objrefs = task
  args = [(arg if isinstance(arg, ray.lib.ObjRef) else from_primitive(arg)) for arg in primitive_args]
  return func_name, args, return_objrefs
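Since wait_for_next_message now performs the protobuf decoding in C++ (see raylib.cc below), deserialize_task on the Python side simply unpacks an already-deserialized tuple. A toy rendering of the new contract, with hypothetical stand-ins for the internal ray.lib.ObjRef and from_primitive:

```
# Toy stand-ins, not Ray's actual classes.
class ObjRef(object):
  def __init__(self, id):
    self.id = id

def from_primitive(primitive):
  return primitive  # identity stand-in for the real deserializer

def deserialize_task(task):
  func_name, primitive_args, return_objrefs = task  # task is a plain tuple now
  args = [arg if isinstance(arg, ObjRef) else from_primitive(arg) for arg in primitive_args]
  return func_name, args, return_objrefs

print(deserialize_task(("add", [ObjRef(0), 5], [1])))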


@ -10,6 +10,7 @@ import numpy as np
import colorama
import ray
+import pickling
import serialization
import ray.internal.graph_pb2
import ray.graph
@ -120,6 +121,7 @@ class Worker(object):
"""Initialize a Worker object."""
self.functions = {}
self.handle = None
self.mode = None
def set_mode(self, mode):
"""Set the mode of the worker.
@ -510,13 +512,26 @@ def main_loop(worker=global_worker):
    store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store
    ray.lib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
  while True:
-    task = ray.lib.wait_for_next_task(worker.handle)
-    if task is None:
-      # We use this as a mechanism to allow the scheduler to kill workers. When
-      # the scheduler wants to kill a worker, it gives the worker a null task,
-      # causing the worker program to exit the main loop here.
-      break
-    process_task(task)
+    (task, function) = ray.lib.wait_for_next_message(worker.handle)
+    try:
+      # Currently the scheduler does not ask the worker to execute a task and
+      # import a function at the same time.
+      assert task is None or function is None
+      if task is None and function is None:
+        # We use this as a mechanism to allow the scheduler to kill workers. When
+        # the scheduler wants to kill a worker, it gives the worker a null task,
+        # causing the worker program to exit the main loop here.
+        break
+      if function is not None:
+        (function, arg_types, return_types) = pickling.loads(function)
+        if function.__module__ is None: function.__module__ = "__main__"
+        worker.register_function(remote(arg_types, return_types, worker)(function))
+      if task is not None:
+        process_task(task)
+    finally:
+      # Release the task and function variables BEFORE we block waiting for the
+      # next message, so that their objects can be reclaimed promptly.
+      del task
+      del function
def remote(arg_types, return_types, worker=global_worker):
"""This decorator is used to create remote functions.
@ -526,6 +541,7 @@ def remote(arg_types, return_types, worker=global_worker):
return_types (List[type]): List of Python types of the return values.
"""
  def remote_decorator(func):
+    to_export = pickling.dumps(func, arg_types, return_types) if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE] else None
    def func_executor(arguments):
      """This gets run when the remote function is executed."""
      logging.info("Calling function {}".format(func.__name__))
@ -560,6 +576,8 @@ def remote(arg_types, return_types, worker=global_worker):
    func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params])
    func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params])
    check_signature_supported(func_call)
+    if to_export is not None:
+      ray.lib.export_function(worker.handle, to_export)
    return func_call
  return remote_decorator
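Taken together, the decorator and the worker main loop implement the export path: in SHELL_MODE or SCRIPT_MODE the decorator pickles the function with its signature at definition time and hands the bytes to ray.lib.export_function; each worker later receives those bytes, unpickles them, and re-registers the function. A minimal sketch that emulates both halves with cloudpickle directly (export and import_function are illustrative stand-ins for the RPC plumbing):

```
import cloudpickle

# Driver side: what the decorator does at definition time.
def export(func, arg_types, return_types):
  return cloudpickle.dumps((func, arg_types, return_types))

# Worker side: what main_loop does when a function message arrives.
def import_function(blob, registry):
  func, arg_types, return_types = cloudpickle.loads(blob)
  if func.__module__ is None: func.__module__ = "__main__"  # same fixup as above
  registry[func.__name__] = (func, arg_types, return_types)

registry = {}

def add_one(x):
  return x + 1

import_function(export(add_one, [int], [int]), registry)
assert registry["add_one"][0](1) == 2
```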


@ -3,7 +3,7 @@ syntax = "proto3";
import "types.proto";
message Task {
-  string name = 1; // Name of the function call
+  string name = 1; // Name of the function call. Must not be empty.
  repeated Value arg = 2; // List of arguments, can be either object references or protobuf descriptions of object passed by value
  repeated uint64 result = 3; // Object references for result
}
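The tightened comment encodes a real protocol constraint: wait_for_next_message (below) uses an empty task name to mean "this message carries no task", so a genuine Task must always be named. A hedged sketch with the generated Python bindings (assuming the Task message is generated into ray.internal.graph_pb2, the module worker.py imports above):

```
# Illustrative; the module path is an assumption.
from ray.internal.graph_pb2 import Task

task = Task()
task.name = "add_one"   # must not be empty; an empty name now means "no task here"
task.result.append(42)  # objref for the single return value
blob = task.SerializeToString()

roundtrip = Task()
roundtrip.ParseFromString(blob)
assert roundtrip.name == "add_one"
```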


@ -52,6 +52,8 @@ service Scheduler {
  rpc TaskInfo(TaskInfoRequest) returns (TaskInfoReply);
  // Kills the workers
  rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply);
+  // Exports a function to the workers
+  rpc ExportFunction(ExportFunctionRequest) returns (ExportFunctionReply);
}
message AckReply {
@ -228,6 +230,13 @@ message KillWorkersReply {
bool success = 1; // Currently, the only reason to fail is if there are workers still executing tasks
}
+message ExportFunctionRequest {
+  Function function = 1;
+}
+
+message ExportFunctionReply {
+}
// These messages are for getting information about the object store state
message ObjStoreInfoRequest {
@ -243,6 +252,7 @@ message ObjStoreInfoReply {
service WorkerService {
  rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskReply); // Scheduler calls a function from the worker
+  rpc ImportFunction(ImportFunctionRequest) returns (ImportFunctionReply); // Scheduler imports a function into the worker
  rpc Die(DieRequest) returns (DieReply); // Kills this worker
}
@ -253,6 +263,13 @@ message ExecuteTaskRequest {
message ExecuteTaskReply {
}
+message ImportFunctionRequest {
+  Function function = 1;
+}
+
+message ImportFunctionReply {
+}
message DieRequest {
}
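The two new RPCs form a relay: a driver calls ExportFunction on the scheduler, and the scheduler fans the payload out by calling ImportFunction on every worker. A hedged sketch of the driver-to-scheduler leg using today's grpcio code generation (this commit talks to these services through the C++ API; the module names and address below are illustrative assumptions):

```
import grpc
import cloudpickle
import scheduler_pb2       # assumed output of protoc on scheduler.proto
import scheduler_pb2_grpc

channel = grpc.insecure_channel("localhost:10001")  # hypothetical scheduler address
stub = scheduler_pb2_grpc.SchedulerStub(channel)
request = scheduler_pb2.ExportFunctionRequest()
request.function.implementation = cloudpickle.dumps((lambda x: x + 1, [int], [int]))
stub.ExportFunction(request)  # the scheduler relays it via ImportFunction to each worker
```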


@ -32,6 +32,11 @@ message PyObj {
bytes data = 1;
}
+// Used for shipping remote functions to workers
+message Function {
+  bytes implementation = 1;
+}
// Union of possible object types
message Obj {
String string_data = 1;

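The Function message is deliberately minimal: a single bytes field carrying the cloudpickle payload produced by pickling.dumps, so the scheduler can relay it without understanding it. A hedged round-trip sketch (assuming the generated module is named types_pb2):

```
import cloudpickle
import types_pb2  # assumed output of protoc on types.proto

f = types_pb2.Function()
f.implementation = cloudpickle.dumps((lambda x: 2 * x, [int], [int]))
func, arg_types, return_types = cloudpickle.loads(f.implementation)
assert func(3) == 6
```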

@ -592,12 +592,7 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) {
return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor);
}
-static PyObject* deserialize_task(PyObject* self, PyObject* args) {
-  PyObject* worker_capsule;
-  Task* task;
-  if (!PyArg_ParseTuple(args, "OO&", &worker_capsule, &PyObjectToTask, &task)) {
-    return NULL;
-  }
+static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) {
  std::vector<ObjRef> objrefs; // This is a vector of all the objrefs that were serialized in this task, including objrefs that are contained in Python objects that are passed by value.
  PyObject* string = PyString_FromStringAndSize(task->name().c_str(), task->name().size());
  int argsize = task->arg_size();
@ -667,19 +662,36 @@ static PyObject* connected(PyObject* self, PyObject* args) {
Py_RETURN_FALSE;
}
-static PyObject* wait_for_next_task(PyObject* self, PyObject* args) {
-  Worker* worker;
-  if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
+static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
+  PyObject* worker_capsule;
+  if (!PyArg_ParseTuple(args, "O", &worker_capsule)) {
    return NULL;
  }
-  if (std::unique_ptr<Task> task = worker->receive_next_task()) {
-    PyObject* pyobj = PyCapsule_New(task.get(), "task", TaskCapsule_Destructor);
-    task.release(); // Now that the wrapper object was constructed successfully, release ownership
-    return pyobj;
+  Worker* worker;
+  PyObjectToWorker(worker_capsule, &worker);
+  if (std::unique_ptr<WorkerMessage> message = worker->receive_next_message()) {
+    PyObject* t = PyTuple_New(2); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
+    if (message->task.name().empty()) { Py_INCREF(Py_None); } // PyTuple_SetItem steals a reference, so we must own any Py_None we hand it.
+    PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_None : deserialize_task(worker_capsule, &message->task));
+    if (message->function.empty()) { Py_INCREF(Py_None); }
+    PyTuple_SetItem(t, 1, message->function.empty() ? Py_None : PyString_FromStringAndSize(message->function.data(), static_cast<ssize_t>(message->function.size())));
+    return t;
  }
  Py_RETURN_NONE;
}

+static PyObject* export_function(PyObject* self, PyObject* args) {
+  Worker* worker;
+  const char* function;
+  int function_size;
+  if (!PyArg_ParseTuple(args, "O&s#", &PyObjectToWorker, &worker, &function, &function_size)) {
+    return NULL;
+  }
+  if (worker->export_function(std::string(function, static_cast<size_t>(function_size)))) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+}
static PyObject* submit_task(PyObject* self, PyObject* args) {
PyObject* worker_capsule;
Task* task;
@ -918,7 +930,6 @@ static PyMethodDef RayLibMethods[] = {
{ "is_arrow", is_arrow, METH_VARARGS, "is the object in the local object store an arrow object?"},
{ "unmap_object", unmap_object, METH_VARARGS, "unmap the object from the client's shared memory pool"},
{ "serialize_task", serialize_task, METH_VARARGS, "serialize a task to protocol buffers" },
{ "deserialize_task", deserialize_task, METH_VARARGS, "deserialize a task from protocol buffers" },
{ "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" },
{ "disconnect", disconnect, METH_VARARGS, "disconnect the worker from the scheduler and the object store" },
{ "connected", connected, METH_VARARGS, "check if the worker is connected to the scheduler and the object store" },
@ -928,12 +939,13 @@ static PyMethodDef RayLibMethods[] = {
{ "get_objref", get_objref, METH_VARARGS, "register a new object reference with the scheduler" },
{ "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" },
{ "alias_objrefs", alias_objrefs, METH_VARARGS, "make two objrefs refer to the same object" },
{ "wait_for_next_task", wait_for_next_task, METH_VARARGS, "get next task from scheduler (blocking)" },
{ "wait_for_next_message", wait_for_next_message, METH_VARARGS, "get next message from scheduler (blocking)" },
{ "submit_task", submit_task, METH_VARARGS, "call a remote function" },
{ "notify_task_completed", notify_task_completed, METH_VARARGS, "notify the scheduler that a task has been completed" },
{ "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" },
{ "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" },
{ "task_info", task_info, METH_VARARGS, "get task statuses" },
{ "export_function", export_function, METH_VARARGS, "export function to workers" },
{ "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" },
{ "set_log_config", set_log_config, METH_VARARGS, "set filename for raylib logging" },
{ "kill_workers", kill_workers, METH_VARARGS, "kills all of the workers" },


@ -291,6 +291,20 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
return Status::OK;
}
+Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) {
+  auto workers = workers_.get();
+  for (size_t i = 0; i < workers->size(); ++i) {
+    ClientContext import_context;
+    ImportFunctionRequest import_request;
+    import_request.mutable_function()->set_implementation(request->function().implementation());
+    if ((*workers)[i].current_task != ROOT_OPERATION) { // Skip drivers: a driver's current task is the root operation, and it already has the function defined locally.
+      ImportFunctionReply import_reply;
+      (*workers)[i].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
+    }
+  }
+  return Status::OK;
+}
void SchedulerService::deliver_object_async_if_necessary(ObjRef canonical_objref, ObjStoreId from, ObjStoreId to) {
bool object_present_or_in_transit;
{


@ -70,6 +70,7 @@ public:
  Status SchedulerInfo(ServerContext* context, const SchedulerInfoRequest* request, SchedulerInfoReply* reply) override;
  Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override;
  Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override;
+  Status ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) override;
// This will ask an object store to send an object to another object store if
// the object is not already present in that object store and is not already


@ -17,16 +17,32 @@ inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address)
}
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) {
-  task_ = std::unique_ptr<Task>(new Task(request->task())); // Copy task
  RAY_LOG(RAY_INFO, "invoked task " << request->task().name());
-  WorkerMessage message = { &task_ };
-  RAY_CHECK(send_queue_.send(&message), "error sending over IPC");
+  std::unique_ptr<WorkerMessage> message(new WorkerMessage());
+  message->task = request->task();
+  {
+    WorkerMessage* message_ptr = message.get();
+    RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
+  }
+  message.release(); // Ownership was transferred through the queue; the receiver reconstructs the unique_ptr.
  return Status::OK;
}

+Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) {
+  std::unique_ptr<WorkerMessage> message(new WorkerMessage());
+  message->function = request->function().implementation();
+  RAY_LOG(RAY_INFO, "importing function");
+  {
+    WorkerMessage* message_ptr = message.get();
+    RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
+  }
+  message.release();
+  return Status::OK;
+}

Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, DieReply* reply) {
-  WorkerMessage message = { NULL };
-  RAY_CHECK(send_queue_.send(&message), "error sending over IPC");
+  WorkerMessage* message_ptr = NULL;
+  RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
return Status::OK;
}
@ -285,10 +301,10 @@ void Worker::register_function(const std::string& name, size_t num_return_vals)
scheduler_stub_->RegisterFunction(&context, request, &reply);
}
-std::unique_ptr<Task> Worker::receive_next_task() {
-  WorkerMessage message;
-  RAY_CHECK(receive_queue_.receive(&message), "error receiving over IPC");
-  return message.task ? std::move(*message.task) : std::unique_ptr<Task>();
+std::unique_ptr<WorkerMessage> Worker::receive_next_message() {
+  WorkerMessage* message_ptr;
+  RAY_CHECK(receive_queue_.receive(&message_ptr), "error receiving over IPC");
+  return std::unique_ptr<WorkerMessage>(message_ptr); // Take back ownership of the heap-allocated message.
}
void Worker::notify_task_completed(bool task_succeeded, std::string error_message) {
@ -322,6 +338,16 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf
scheduler_stub_->TaskInfo(&context, request, &reply);
}
+bool Worker::export_function(const std::string& function) {
+  RAY_CHECK(connected_, "Attempted to export a function, but the worker is not connected.");
+  ClientContext context;
+  ExportFunctionRequest request;
+  request.mutable_function()->set_implementation(function);
+  ExportFunctionReply reply;
+  Status status = scheduler_stub_->ExportFunction(&context, request, &reply);
+  return status.ok(); // Report whether the RPC to the scheduler succeeded.
+}
// Communication between the WorkerServer and the Worker happens via a message
// queue. This is because the Python interpreter needs to be single threaded
// (in our case running in the main thread), whereas the WorkerService will

View file

@ -24,19 +24,19 @@ using grpc::ClientContext;
using grpc::ClientWriter;
struct WorkerMessage {
-  std::unique_ptr<Task>* task;
+  Task task;
+  std::string function;
};

-static_assert(std::is_pod<WorkerMessage>::value, "WorkerMessage must be memcpy-able");

class WorkerServiceImpl final : public WorkerService::Service {
public:
  WorkerServiceImpl(const std::string& worker_address);
  Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) override;
+  Status ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) override;
  Status Die(ServerContext* context, const DieRequest* request, DieReply* reply) override;
private:
  std::string worker_address_;
-  std::unique_ptr<Task> task_; // copy of the current task
-  MessageQueue<WorkerMessage> send_queue_;
+  MessageQueue<WorkerMessage*> send_queue_; // WorkerMessage is no longer POD (hence the dropped static_assert); messages are heap-allocated and their pointers are passed through the queue.
};
class Worker {
@ -79,7 +79,7 @@ class Worker {
// it in the message queue, which is read by the Python interpreter
void start_worker_service();
-  // wait for next task from the RPC system. If null, it means there are no more tasks and the worker should shut down.
-  std::unique_ptr<Task> receive_next_task();
+  // wait for the next message from the RPC system. If null, it means there are no more messages and the worker should shut down.
+  std::unique_ptr<WorkerMessage> receive_next_message();
// tell the scheduler that we are done with the current task and request the
// next one, if task_succeeded is false, this tells the scheduler that the
// task threw an exception
@ -92,13 +92,15 @@ class Worker {
void scheduler_info(ClientContext &context, SchedulerInfoRequest &request, SchedulerInfoReply &reply);
// get task statuses from scheduler
void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
+  // export function to workers
+  bool export_function(const std::string& function);
private:
bool connected_;
const size_t CHUNK_SIZE = 8 * 1024;
std::unique_ptr<Scheduler::Stub> scheduler_stub_;
std::thread worker_server_thread_;
-  MessageQueue<WorkerMessage> receive_queue_;
+  MessageQueue<WorkerMessage*> receive_queue_;
bip::managed_shared_memory segment_;
WorkerId workerid_;
ObjStoreId objstoreid_;


@ -272,6 +272,41 @@ class APITest(unittest.TestCase):
ray.services.cleanup()
+  def testDefiningRemoteFunctions(self):
+    worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
+    ray.services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.SCRIPT_MODE)
+
+    # Test that we can define a remote function in the shell.
+    @ray.remote([int], [int])
+    def f(x):
+      return x + 1
+    self.assertEqual(ray.get(f(0)), 1)
+
+    # Test that we can redefine the remote function.
+    @ray.remote([int], [int])
+    def f(x):
+      return x + 10
+    self.assertEqual(ray.get(f(0)), 10)
+
+    # Test that we can close over plain old data.
+    data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}]
+    @ray.remote([], [list])
+    def g():
+      return data
+    ray.get(g())
+
+    # Test that we can close over modules.
+    @ray.remote([], [np.ndarray])
+    def h():
+      return np.zeros([3, 5])
+    self.assertTrue(np.alltrue(ray.get(h()) == np.zeros([3, 5])))
+    @ray.remote([], [float])
+    def j():
+      return time.time()
+    ray.get(j())
+
+    ray.services.cleanup()
class TaskStatusTest(unittest.TestCase):
def testFailedTask(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")