From f2c43bec871a1fc0b827cd9081a1d7c465d08977 Mon Sep 17 00:00:00 2001
From: mehrdadn
Date: Sun, 17 Jul 2016 22:05:07 -0700
Subject: [PATCH] Function serialization (#261)

---
 doc/install-on-macosx.md        |  2 +-
 doc/install-on-ubuntu.md        |  2 +-
 install-dependencies.sh         |  4 +--
 lib/python/ray/pickling.py      |  7 ++++++
 lib/python/ray/serialization.py |  2 +-
 lib/python/ray/worker.py        | 32 ++++++++++++++++++------
 protos/graph.proto              |  2 +-
 protos/ray.proto                | 17 +++++++++++++
 protos/types.proto              |  5 ++++
 src/raylib.cc                   | 42 ++++++++++++++++++++-----------
 src/scheduler.cc                | 14 +++++++++++
 src/scheduler.h                 |  1 +
 src/worker.cc                   | 44 ++++++++++++++++++++++++++-------
 src/worker.h                    | 14 ++++++-----
 test/runtest.py                 | 35 ++++++++++++++++++++++++++
 15 files changed, 180 insertions(+), 43 deletions(-)
 create mode 100644 lib/python/ray/pickling.py

diff --git a/doc/install-on-macosx.md b/doc/install-on-macosx.md
index 132e048be..29daca152 100644
--- a/doc/install-on-macosx.md
+++ b/doc/install-on-macosx.md
@@ -19,7 +19,7 @@ brew update
 brew install git cmake automake autoconf libtool boost graphviz
 sudo easy_install pip
 sudo pip install ipython --user
-sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz --ignore-installed six
+sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle --ignore-installed six
 ```

 ### Build
diff --git a/doc/install-on-ubuntu.md b/doc/install-on-ubuntu.md
index f0d2cd00d..f14cbebda 100644
--- a/doc/install-on-ubuntu.md
+++ b/doc/install-on-ubuntu.md
@@ -15,7 +15,7 @@ First install the dependencies. We currently do not support Python 3.
 ```
 sudo apt-get update
 sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
-sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz
+sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle
 ```

 ### Build
diff --git a/install-dependencies.sh b/install-dependencies.sh
index 63a69b285..fbfb4052c 100755
--- a/install-dependencies.sh
+++ b/install-dependencies.sh
@@ -31,11 +31,11 @@ if [[ $platform == "linux" ]]; then
   # These commands must be kept in sync with the installation instructions.
   sudo apt-get update
   sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
-  sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz
+  sudo pip install ipython typing funcsigs subprocess32 protobuf==3.0.0a2 colorama graphviz cloudpickle
 elif [[ $platform == "macosx" ]]; then
   # These commands must be kept in sync with the installation instructions.
   brew install git cmake automake autoconf libtool boost graphviz
   sudo easy_install pip
   sudo pip install ipython --user
-  sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0-alpha-2 colorama graphviz --ignore-installed six
+  sudo pip install numpy typing funcsigs subprocess32 protobuf==3.0.0-alpha-2 colorama graphviz cloudpickle --ignore-installed six
 fi
diff --git a/lib/python/ray/pickling.py b/lib/python/ray/pickling.py
new file mode 100644
index 000000000..d9e34b0ff
--- /dev/null
+++ b/lib/python/ray/pickling.py
@@ -0,0 +1,7 @@
+import cloudpickle
+
+def dumps(func, arg_types, return_types):
+  return cloudpickle.dumps((func, arg_types, return_types))
+
+def loads(function):
+  return cloudpickle.loads(function)
diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py
index 782be9ac0..e67b418b9 100644
--- a/lib/python/ray/serialization.py
+++ b/lib/python/ray/serialization.py
@@ -67,6 +67,6 @@ def serialize_task(worker_capsule, func_name, args):
   return ray.lib.serialize_task(worker_capsule, func_name, primitive_args)

 def deserialize_task(worker_capsule, task):
-  func_name, primitive_args, return_objrefs = ray.lib.deserialize_task(worker_capsule, task)
+  func_name, primitive_args, return_objrefs = task
   args = [(arg if isinstance(arg, ray.lib.ObjRef) else from_primitive(arg)) for arg in primitive_args]
   return func_name, args, return_objrefs
diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py
index c9bc515c9..cb0e34dff 100644
--- a/lib/python/ray/worker.py
+++ b/lib/python/ray/worker.py
@@ -10,6 +10,7 @@ import numpy as np
 import colorama

 import ray
+import pickling
 import serialization
 import ray.internal.graph_pb2
 import ray.graph
@@ -120,6 +121,7 @@ class Worker(object):
     """Initialize a Worker object."""
     self.functions = {}
     self.handle = None
+    self.mode = None

   def set_mode(self, mode):
     """Set the mode of the worker.
@@ -510,13 +512,26 @@ def main_loop(worker=global_worker):
     store_outputs_in_objstore(return_objrefs, outputs, worker) # store output in local object store
     ray.lib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
   while True:
-    task = ray.lib.wait_for_next_task(worker.handle)
-    if task is None:
-      # We use this as a mechanism to allow the scheduler to kill workers. When
-      # the scheduler wants to kill a worker, it gives the worker a null task,
-      # causing the worker program to exit the main loop here.
-      break
-    process_task(task)
+    (task, function) = ray.lib.wait_for_next_message(worker.handle)
+    try:
+      # Currently the scheduler does not ask the worker to execute a task and
+      # import a function at the same time.
+      assert task is None or function is None
+      if task is None and function is None:
+        # We use this as a mechanism to allow the scheduler to kill workers. When
+        # the scheduler wants to kill a worker, it gives the worker a null task,
+        # causing the worker program to exit the main loop here.
+        break
+      if function is not None:
+        (function, arg_types, return_types) = pickling.loads(function)
+        if function.__module__ is None: function.__module__ = "__main__"
+        worker.register_function(remote(arg_types, return_types, worker)(function))
+      if task is not None:
+        process_task(task)
+    finally:
+      # Release the task and function variables BEFORE waiting for the next message or exiting the loop.
+      del task
+      del function

 def remote(arg_types, return_types, worker=global_worker):
   """This decorator is used to create remote functions.
@@ -526,6 +541,7 @@ def remote(arg_types, return_types, worker=global_worker):
     return_types (List[type]): List of Python types of the return values.
   """
   def remote_decorator(func):
+    to_export = pickling.dumps(func, arg_types, return_types) if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE] else None
     def func_executor(arguments):
       """This gets run when the remote function is executed."""
       logging.info("Calling function {}".format(func.__name__))
@@ -560,6 +576,8 @@ def remote(arg_types, return_types, worker=global_worker):
     func_call.has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in func_call.sig_params])
     func_call.has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in func_call.sig_params])
     check_signature_supported(func_call)
+    if to_export is not None:
+      ray.lib.export_function(worker.handle, to_export)
     return func_call

   return remote_decorator
diff --git a/protos/graph.proto b/protos/graph.proto
index 98dabeddd..01badf03f 100644
--- a/protos/graph.proto
+++ b/protos/graph.proto
@@ -3,7 +3,7 @@ syntax = "proto3";
 import "types.proto";

 message Task {
-  string name = 1; // Name of the function call
+  string name = 1; // Name of the function call. Must not be empty.
   repeated Value arg = 2; // List of arguments, can be either object references or protobuf descriptions of object passed by value
   repeated uint64 result = 3; // Object references for result
 }
diff --git a/protos/ray.proto b/protos/ray.proto
index 8bd6a5a55..ae8818af3 100644
--- a/protos/ray.proto
+++ b/protos/ray.proto
@@ -52,6 +52,8 @@ service Scheduler {
   rpc TaskInfo(TaskInfoRequest) returns (TaskInfoReply);
   // Kills the workers
   rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply);
+  // Exports function to the workers
+  rpc ExportFunction(ExportFunctionRequest) returns (ExportFunctionReply);
 }

 message AckReply {
@@ -228,6 +230,13 @@ message KillWorkersReply {
   bool success = 1; // Currently, the only reason to fail is if there are workers still executing tasks
 }

+message ExportFunctionRequest {
+  Function function = 1;
+}
+
+message ExportFunctionReply {
+}
+
 // These messages are for getting information about the object store state

 message ObjStoreInfoRequest {
@@ -243,6 +252,7 @@ message ObjStoreInfoReply {

 service WorkerService {
   rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskReply); // Scheduler calls a function from the worker
+  rpc ImportFunction(ImportFunctionRequest) returns (ImportFunctionReply); // Scheduler imports a function into the worker
   rpc Die(DieRequest) returns (DieReply); // Kills this worker
 }

@@ -253,6 +263,13 @@ message ExecuteTaskRequest {
 }

 message ExecuteTaskReply {
 }

+message ImportFunctionRequest {
+  Function function = 1;
+}
+
+message ImportFunctionReply {
+}
+
 message DieRequest {
 }
diff --git a/protos/types.proto b/protos/types.proto
index cbd87d82e..83ae413a7 100644
--- a/protos/types.proto
+++ b/protos/types.proto
@@ -32,6 +32,11 @@ message PyObj {
   bytes data = 1;
 }

+// Used for shipping remote functions to workers
+message Function {
+  bytes implementation = 1;
+}
+
 // Union of possible object types
 message Obj {
   String string_data = 1;
diff --git a/src/raylib.cc b/src/raylib.cc
index d0f9fc4cd..b84478d35 100644
--- a/src/raylib.cc
+++ b/src/raylib.cc
@@ -592,12 +592,7 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) {
   return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor);
 }

-static PyObject* deserialize_task(PyObject* self, PyObject* args) {
-  PyObject* worker_capsule;
-  Task* task;
-  if (!PyArg_ParseTuple(args, "OO&", &worker_capsule, &PyObjectToTask, &task)) {
-    return NULL;
-  }
+static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) {
   std::vector<ObjRef> objrefs; // This is a vector of all the objrefs that were serialized in this task, including objrefs that are contained in Python objects that are passed by value.
   PyObject* string = PyString_FromStringAndSize(task->name().c_str(), task->name().size());
   int argsize = task->arg_size();
@@ -667,19 +662,36 @@ static PyObject* connected(PyObject* self, PyObject* args) {
   Py_RETURN_FALSE;
 }

-static PyObject* wait_for_next_task(PyObject* self, PyObject* args) {
-  Worker* worker;
-  if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
+static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
+  PyObject* worker_capsule;
+  if (!PyArg_ParseTuple(args, "O", &worker_capsule)) {
     return NULL;
   }
-  if (std::unique_ptr<Task> task = worker->receive_next_task()) {
-    PyObject* pyobj = PyCapsule_New(task.get(), "task", TaskCapsule_Destructor);
-    task.release(); // Now that the wrapper object was constructed successfully, release ownership
-    return pyobj;
+  Worker* worker;
+  PyObjectToWorker(worker_capsule, &worker);
+  if (std::unique_ptr<WorkerMessage> message = worker->receive_next_message()) {
+    PyObject* t = PyTuple_New(2); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
+    PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_None : deserialize_task(worker_capsule, &message->task));
+    PyTuple_SetItem(t, 1, message->function.empty() ? Py_None : PyString_FromStringAndSize(message->function.data(), static_cast<ssize_t>(message->function.size())));
+    return t;
   }
   Py_RETURN_NONE;
 }

+static PyObject* export_function(PyObject* self, PyObject* args) {
+  Worker* worker;
+  const char* function;
+  int function_size;
+  if (!PyArg_ParseTuple(args, "O&s#", &PyObjectToWorker, &worker, &function, &function_size)) {
+    return NULL;
+  }
+  if (worker->export_function(std::string(function, static_cast<size_t>(function_size)))) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+}
+
 static PyObject* submit_task(PyObject* self, PyObject* args) {
   PyObject* worker_capsule;
   Task* task;
@@ -918,7 +930,6 @@ static PyMethodDef RayLibMethods[] = {
  { "is_arrow", is_arrow, METH_VARARGS, "is the object in the local object store an arrow object?"},
  { "unmap_object", unmap_object, METH_VARARGS, "unmap the object from the client's shared memory pool"},
  { "serialize_task", serialize_task, METH_VARARGS, "serialize a task to protocol buffers" },
- { "deserialize_task", deserialize_task, METH_VARARGS, "deserialize a task from protocol buffers" },
  { "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" },
  { "disconnect", disconnect, METH_VARARGS, "disconnect the worker from the scheduler and the object store" },
  { "connected", connected, METH_VARARGS, "check if the worker is connected to the scheduler and the object store" },
@@ -928,12 +939,13 @@ static PyMethodDef RayLibMethods[] = {
  { "get_objref", get_objref, METH_VARARGS, "register a new object reference with the scheduler" },
  { "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" },
  { "alias_objrefs", alias_objrefs, METH_VARARGS, "make two objrefs refer to the same object" },
- { "wait_for_next_task", wait_for_next_task, METH_VARARGS, "get next task from scheduler (blocking)" },
+ { "wait_for_next_message", wait_for_next_message, METH_VARARGS, "get next message from scheduler (blocking)" },
"submit_task", submit_task, METH_VARARGS, "call a remote function" }, { "notify_task_completed", notify_task_completed, METH_VARARGS, "notify the scheduler that a task has been completed" }, { "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" }, { "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" }, { "task_info", task_info, METH_VARARGS, "get task statuses" }, + { "export_function", export_function, METH_VARARGS, "export function to workers" }, { "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" }, { "set_log_config", set_log_config, METH_VARARGS, "set filename for raylib logging" }, { "kill_workers", kill_workers, METH_VARARGS, "kills all of the workers" }, diff --git a/src/scheduler.cc b/src/scheduler.cc index 51fe4fc33..9ff67eec7 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -291,6 +291,20 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe return Status::OK; } +Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) { + auto workers = workers_.get(); + for (size_t i = 0; i < workers->size(); ++i) { + ClientContext import_context; + ImportFunctionRequest import_request; + import_request.mutable_function()->set_implementation(request->function().implementation()); + if ((*workers)[i].current_task != ROOT_OPERATION) { + ImportFunctionReply import_reply; + (*workers)[i].worker_stub->ImportFunction(&import_context, import_request, &import_reply); + } + } + return Status::OK; +} + void SchedulerService::deliver_object_async_if_necessary(ObjRef canonical_objref, ObjStoreId from, ObjStoreId to) { bool object_present_or_in_transit; { diff --git a/src/scheduler.h b/src/scheduler.h index f0289d783..3f3ebe1d1 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -70,6 +70,7 @@ public: Status SchedulerInfo(ServerContext* context, const SchedulerInfoRequest* request, SchedulerInfoReply* reply) override; Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override; Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override; + Status ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) override; // This will ask an object store to send an object to another object store if // the object is not already present in that object store and is not already diff --git a/src/worker.cc b/src/worker.cc index 0ed8a4982..a2ecb601c 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -17,16 +17,32 @@ inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address) } Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) { - task_ = std::unique_ptr(new Task(request->task())); // Copy task RAY_LOG(RAY_INFO, "invoked task " << request->task().name()); - WorkerMessage message = { &task_ }; - RAY_CHECK(send_queue_.send(&message), "error sending over IPC"); + std::unique_ptr message(new WorkerMessage()); + message->task = request->task(); + { + WorkerMessage* message_ptr = message.get(); + RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC"); + } + message.release(); + return Status::OK; +} + +Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) { + std::unique_ptr 
+  std::unique_ptr<WorkerMessage> message(new WorkerMessage());
+  message->function = request->function().implementation();
+  RAY_LOG(RAY_INFO, "importing function");
+  {
+    WorkerMessage* message_ptr = message.get();
+    RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
+  }
+  message.release();
   return Status::OK;
 }

 Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, DieReply* reply) {
-  WorkerMessage message = { NULL };
-  RAY_CHECK(send_queue_.send(&message), "error sending over IPC");
+  WorkerMessage* message_ptr = NULL;
+  RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
   return Status::OK;
 }

@@ -285,10 +301,10 @@ void Worker::register_function(const std::string& name, size_t num_return_vals)
   scheduler_stub_->RegisterFunction(&context, request, &reply);
 }

-std::unique_ptr<Task> Worker::receive_next_task() {
-  WorkerMessage message;
-  RAY_CHECK(receive_queue_.receive(&message), "error receiving over IPC");
-  return message.task ? std::move(*message.task) : std::unique_ptr<Task>();
+std::unique_ptr<WorkerMessage> Worker::receive_next_message() {
+  WorkerMessage* message_ptr;
+  RAY_CHECK(receive_queue_.receive(&message_ptr), "error receiving over IPC");
+  return std::unique_ptr<WorkerMessage>(message_ptr);
 }

 void Worker::notify_task_completed(bool task_succeeded, std::string error_message) {
@@ -322,6 +338,16 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf
   scheduler_stub_->TaskInfo(&context, request, &reply);
 }

+bool Worker::export_function(const std::string& function) {
+  RAY_CHECK(connected_, "Attempted to export a function, but the worker is not connected.");
+  ClientContext context;
+  ExportFunctionRequest request;
+  request.mutable_function()->set_implementation(function);
+  ExportFunctionReply reply;
+  Status status = scheduler_stub_->ExportFunction(&context, request, &reply);
+  return status.ok();
+}
+
 // Communication between the WorkerServer and the Worker happens via a message
 // queue. This is because the Python interpreter needs to be single threaded
 // (in our case running in the main thread), whereas the WorkerService will
diff --git a/src/worker.h b/src/worker.h
index 313f67d88..7a26da407 100644
--- a/src/worker.h
+++ b/src/worker.h
@@ -24,19 +24,19 @@ using grpc::ClientContext;
 using grpc::ClientWriter;

 struct WorkerMessage {
-  std::unique_ptr<Task>* task;
+  Task task;
+  std::string function;
 };
-static_assert(std::is_pod<WorkerMessage>::value, "WorkerMessage must be memcpy-able");

 class WorkerServiceImpl final : public WorkerService::Service {
 public:
   WorkerServiceImpl(const std::string& worker_address);
   Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) override;
+  Status ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) override;
   Status Die(ServerContext* context, const DieRequest* request, DieReply* reply) override;
 private:
   std::string worker_address_;
-  std::unique_ptr<Task> task_; // copy of the current task
-  MessageQueue<WorkerMessage> send_queue_;
+  MessageQueue<WorkerMessage*> send_queue_;
 };

 class Worker {
@@ -79,7 +79,7 @@ class Worker {
   // it in the message queue, which is read by the Python interpreter
   void start_worker_service();
   // wait for next task from the RPC system. If null, it means there are no more tasks and the worker should shut down.
-  std::unique_ptr<Task> receive_next_task();
+  std::unique_ptr<WorkerMessage> receive_next_message();
   // tell the scheduler that we are done with the current task and request the
   // next one, if task_succeeded is false, this tells the scheduler that the
   // task threw an exception
@@ -92,13 +92,15 @@ class Worker {
   void scheduler_info(ClientContext &context, SchedulerInfoRequest &request, SchedulerInfoReply &reply);
   // get task statuses from scheduler
   void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
+  // export function to workers
+  bool export_function(const std::string& function);

 private:
   bool connected_;
   const size_t CHUNK_SIZE = 8 * 1024;
   std::unique_ptr<Scheduler::Stub> scheduler_stub_;
   std::thread worker_server_thread_;
-  MessageQueue<WorkerMessage> receive_queue_;
+  MessageQueue<WorkerMessage*> receive_queue_;
   bip::managed_shared_memory segment_;
   WorkerId workerid_;
   ObjStoreId objstoreid_;
diff --git a/test/runtest.py b/test/runtest.py
index 89612e19a..4ad01a845 100644
--- a/test/runtest.py
+++ b/test/runtest.py
@@ -272,6 +272,41 @@ class APITest(unittest.TestCase):

     ray.services.cleanup()

+  def testDefiningRemoteFunctions(self):
+    worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
+    ray.services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.SCRIPT_MODE)
+
+    # Test that we can define a remote function in the shell.
+    @ray.remote([int], [int])
+    def f(x):
+      return x + 1
+    self.assertEqual(ray.get(f(0)), 1)
+
+    # Test that we can redefine the remote function.
+    @ray.remote([int], [int])
+    def f(x):
+      return x + 10
+    self.assertEqual(ray.get(f(0)), 10)
+
+    # Test that we can close over plain old data.
+    data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}]
+    @ray.remote([], [list])
+    def g():
+      return data
+    ray.get(g())
+
+    # Test that we can close over modules.
+    @ray.remote([], [np.ndarray])
+    def h():
+      return np.zeros([3, 5])
+    self.assertTrue(np.alltrue(ray.get(h()) == np.zeros([3, 5])))
+    @ray.remote([], [float])
+    def j():
+      return time.time()
+    ray.get(j())
+
+    ray.services.cleanup()
+
 class TaskStatusTest(unittest.TestCase):
   def testFailedTask(self):
     worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
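
Editor's note (not part of the commit): the patch ships a remote function from the driver to every worker as a cloudpickle byte string. `remote()` pickles the function together with its declared argument and return types, `ray.lib.export_function` hands the bytes to the scheduler's new ExportFunction RPC, the scheduler fans them out through each worker's ImportFunction RPC, and the worker's `main_loop` unpickles and re-registers the function. The sketch below replays the two halves of that round trip in plain Python, outside of Ray; the function name `add_one` is illustrative, and only the `cloudpickle` package is assumed.

import cloudpickle

def add_one(x):
  return x + 1

# Driver side (dumps in lib/python/ray/pickling.py): pack the function
# together with its declared argument and return types.
payload = cloudpickle.dumps((add_one, [int], [int]))
# `payload` is what travels in the Function.implementation bytes field
# of the ExportFunction/ImportFunction RPCs defined in protos/ray.proto.

# Worker side (main_loop in lib/python/ray/worker.py): unpack the bytes
# and patch up __module__ so the function can be registered and called.
func, arg_types, return_types = cloudpickle.loads(payload)
if func.__module__ is None:
  func.__module__ = "__main__"

assert func(41) == 42

On the C++ side, WorkerMessage stops being a POD precisely because it now carries a protobuf Task and a std::string; since the boost message queue can only memcpy fixed-size data, the patch sends a raw WorkerMessage* through the queue and transfers ownership with unique_ptr::release() on the sending side and a fresh unique_ptr on the receiving side.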