From ab3448a9b41f2d013059848f2baa2aa1240eda9c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 9 Jan 2017 20:15:54 -0800 Subject: [PATCH] Plasma Optimizations (#190) * bypass python when storing objects into the object store * clang-format * Bug fixes. * fix include paths * Fixes. * fix bug * clang-format * fix * fix release after disconnect --- .travis.yml | 8 + lib/python/ray/worker.py | 34 +--- numbuf/CMakeLists.txt | 23 ++- numbuf/build.sh | 2 +- numbuf/python/src/pynumbuf/numbuf.cc | 260 +++++++++++++++++++++++---- numbuf/python/test/runtest.py | 2 +- src/plasma/Makefile | 2 +- src/plasma/plasma_extension.c | 30 +--- src/plasma/plasma_extension.h | 25 +++ 9 files changed, 289 insertions(+), 97 deletions(-) create mode 100644 src/plasma/plasma_extension.h diff --git a/.travis.yml b/.travis.yml index 7ac70c534..4e214a81d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -76,3 +76,11 @@ script: - python test/failure_test.py - python test/microbenchmarks.py - python test/stress_tests.py + +after_script: + # Make sure that we can build numbuf as a standalone library. + - pip uninstall -y numbuf + - cd numbuf/build + - rm -rf * + - cmake -DHAS_PLASMA=OFF -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_FLAGS="-g" -DCMAKE_CXX_FLAGS="-g" .. + - make install diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index e04d7aa78..bb8a1bf95 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -333,18 +333,6 @@ class RayReusables(object): """ raise Exception("Attempted deletion of attribute {}. Attributes of a RayReusable object may not be deleted.".format(name)) -class ObjectFixture(object): - """This is used to handle releasing objects backed by the object store. - - This keeps a PlasmaBuffer in scope as long as an object that is backed by that - PlasmaBuffer is in scope. This prevents memory in the object store from getting - released while it is still being used to back a Python object. - """ - - def __init__(self, plasma_buffer): - """Initialize an ObjectFixture object.""" - self.plasma_buffer = plasma_buffer - class Worker(object): """A class used to define the control flow of a worker process. @@ -422,10 +410,8 @@ class Worker(object): value (serializable object): The value to put in the object store. """ # Serialize and put the object in the object store. - schema, size, serialized = numbuf_serialize(value) - size = size + 4096 * 4 + 8 # The last 8 bytes are for the metadata offset. This is temporary. try: - buff = self.plasma_client.create(objectid.id(), size, bytearray(schema)) + numbuf.store_list(objectid.id(), self.plasma_client.conn, [value]) except plasma.plasma_object_exists_error as e: # The object already exists in the object store, so there is no need to # add it again. TODO(rkn): We need to compare the hashes and make sure @@ -433,11 +419,6 @@ class Worker(object): # code to the caller instead of printing a message. print("This object already exists in the object store.") return - data = np.frombuffer(buff.buffer, dtype="byte")[8:] - metadata_offset = numbuf.write_to_buffer(serialized, memoryview(data)) - np.frombuffer(buff.buffer, dtype="int64", count=1)[0] = metadata_offset - self.plasma_client.seal(objectid.id()) - global contained_objectids # Optionally do something with the contained_objectids here. contained_objectids = [] @@ -452,18 +433,7 @@ class Worker(object): objectid (object_id.ObjectID): The object ID of the value to retrieve. """ self.plasma_client.fetch([objectid.id()]) - buff = self.plasma_client.get(objectid.id()) - metadata_buff = self.plasma_client.get_metadata(objectid.id()) - metadata_size = len(metadata_buff) - data = np.frombuffer(buff.buffer, dtype="byte")[8:] - metadata = np.frombuffer(metadata_buff.buffer, dtype="byte") - metadata_offset = int(np.frombuffer(buff.buffer, dtype="int64", count=1)[0]) - serialized = numbuf.read_from_buffer(memoryview(data), memoryview(metadata), metadata_offset) - # Create an ObjectFixture. If the object we are getting is backed by the - # PlasmaBuffer, this ObjectFixture will keep the PlasmaBuffer in scope as - # long as the object is in scope. - object_fixture = ObjectFixture(buff) - deserialized = numbuf.deserialize_list(serialized, object_fixture) + deserialized = numbuf.retrieve_list(objectid.id(), self.plasma_client.conn) # Unwrap the object from the list (it was wrapped put_object). assert len(deserialized) == 1 return deserialized[0] diff --git a/numbuf/CMakeLists.txt b/numbuf/CMakeLists.txt index 0ee519c18..67ebe4ec1 100644 --- a/numbuf/CMakeLists.txt +++ b/numbuf/CMakeLists.txt @@ -4,9 +4,13 @@ project(numbuf) list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) -# Make libnumbuf.so look for shared libraries in the folder libnumbuf.so is in. -set(CMAKE_INSTALL_RPATH "$ORIGIN/") -set(CMAKE_MACOSX_RPATH 1) +option(HAS_PLASMA + "Are we linking with the plasma object store? Recommended if numbuf is used as part of ray." + OFF) + +if(HAS_PLASMA) + add_definitions(-DHAS_PLASMA) +endif() message(STATUS "Trying custom approach for finding Python.") # Start off by figuring out which Python executable to use. @@ -90,6 +94,13 @@ include_directories("${ARROW_DIR}/cpp/src/") include_directories("cpp/src/") include_directories("python/src/") +if(HAS_PLASMA) + include_directories("../src/plasma") + include_directories("../src/common") + include_directories("../src/common/thirdparty") + include_directories("../src/common/build/flatcc-prefix/src/flatcc/include") +endif() + add_definitions(-fPIC) add_library(numbuf SHARED @@ -112,4 +123,10 @@ else() target_link_libraries(numbuf -Wl,--whole-archive ${ARROW_LIB} -Wl,--no-whole-archive ${ARROW_IO_LIB} ${ARROW_IPC_LIB} ${PYTHON_LIBRARIES}) endif() +if(HAS_PLASMA) + target_link_libraries(numbuf ${ARROW_LIB} ${ARROW_IO_LIB} ${ARROW_IPC_LIB} ${PYTHON_LIBRARIES} "${CMAKE_SOURCE_DIR}/../src/plasma/build/libplasma_client.a" "${CMAKE_SOURCE_DIR}/../src/common/build/libcommon.a" "${CMAKE_SOURCE_DIR}/../src/common/build/flatcc-prefix/src/flatcc/lib/libflatcc.a") +else() + target_link_libraries(numbuf ${ARROW_LIB} ${ARROW_IO_LIB} ${ARROW_IPC_LIB} ${PYTHON_LIBRARIES}) +endif() + install(TARGETS numbuf DESTINATION ${CMAKE_SOURCE_DIR}/numbuf/) diff --git a/numbuf/build.sh b/numbuf/build.sh index 5f3401801..0a47172d2 100755 --- a/numbuf/build.sh +++ b/numbuf/build.sh @@ -18,6 +18,6 @@ fi mkdir -p "$ROOT_DIR/build" pushd "$ROOT_DIR/build" - cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_FLAGS="-g" -DCMAKE_CXX_FLAGS="-g" .. + cmake -DHAS_PLASMA=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_C_FLAGS="-g" -DCMAKE_CXX_FLAGS="-g" .. make install -j$PARALLEL popd diff --git a/numbuf/python/src/pynumbuf/numbuf.cc b/numbuf/python/src/pynumbuf/numbuf.cc index c1d509ba3..fc6aaf5b6 100644 --- a/numbuf/python/src/pynumbuf/numbuf.cc +++ b/numbuf/python/src/pynumbuf/numbuf.cc @@ -14,17 +14,59 @@ #include "adapters/python.h" #include "memory.h" +#ifdef HAS_PLASMA +extern "C" { +#include "format/plasma_reader.h" +#include "plasma_client.h" +} + +PyObject* NumbufPlasmaOutOfMemoryError; +PyObject* NumbufPlasmaObjectExistsError; +#endif + using namespace arrow; using namespace numbuf; -std::shared_ptr make_row_batch(std::shared_ptr data) { +int64_t make_schema_and_batch(std::shared_ptr data, + std::shared_ptr* metadata_out, std::shared_ptr* batch_out) { auto field = std::make_shared("list", data->type()); std::shared_ptr schema(new Schema({field})); - return std::shared_ptr(new RecordBatch(schema, data->length(), {data})); + *batch_out = + std::shared_ptr(new RecordBatch(schema, data->length(), {data})); + int64_t size = 0; + ARROW_CHECK_OK(ipc::GetRecordBatchSize(batch_out->get(), &size)); + ARROW_CHECK_OK(ipc::WriteSchema((*batch_out)->schema().get(), metadata_out)); + return size; +} + +Status read_batch(std::shared_ptr schema_buffer, int64_t header_end_offset, + uint8_t* data, int64_t size, std::shared_ptr* batch_out) { + std::shared_ptr message; + RETURN_NOT_OK(ipc::Message::Open(schema_buffer, &message)); + DCHECK_EQ(ipc::Message::SCHEMA, message->type()); + std::shared_ptr schema_msg = message->GetSchema(); + std::shared_ptr schema; + RETURN_NOT_OK(schema_msg->GetSchema(&schema)); + auto source = std::make_shared(data, size); + std::shared_ptr reader; + RETURN_NOT_OK(ipc::RecordBatchReader::Open(source.get(), header_end_offset, &reader)); + RETURN_NOT_OK(reader->GetRecordBatch(schema, batch_out)); + return Status::OK(); } extern "C" { +#define CHECK_SERIALIZATION_ERROR(STATUS) \ + do { \ + Status _s = (STATUS); \ + if (!_s.ok()) { \ + /* If this condition is true, there was an error in the callback that \ + * needs to be passed through */ \ + if (!PyErr_Occurred()) { PyErr_SetString(NumbufError, _s.ToString().c_str()); } \ + return NULL; \ + } \ + } while (0) + static PyObject* NumbufError; PyObject* numbuf_serialize_callback = NULL; @@ -55,25 +97,15 @@ static PyObject* serialize_list(PyObject* self, PyObject* args) { int32_t recursion_depth = 0; Status s = SerializeSequences(std::vector({value}), recursion_depth, &array); - if (!s.ok()) { - // If this condition is true, there was an error in the callback that - // needs to be passed through - if (!PyErr_Occurred()) { PyErr_SetString(NumbufError, s.ToString().c_str()); } - return NULL; - } + CHECK_SERIALIZATION_ERROR(s); auto batch = new std::shared_ptr(); - *batch = make_row_batch(array); - - int64_t size = 0; - ARROW_CHECK_OK(arrow::ipc::GetRecordBatchSize(batch->get(), &size)); - - std::shared_ptr buffer; - ARROW_CHECK_OK(ipc::WriteSchema((*batch)->schema().get(), &buffer)); - auto ptr = reinterpret_cast(buffer->data()); + std::shared_ptr metadata; + int64_t size = make_schema_and_batch(array, &metadata, batch); + auto ptr = reinterpret_cast(metadata->data()); PyObject* r = PyTuple_New(3); - PyTuple_SetItem(r, 0, PyByteArray_FromStringAndSize(ptr, buffer->size())); + PyTuple_SetItem(r, 0, PyByteArray_FromStringAndSize(ptr, metadata->size())); PyTuple_SetItem(r, 1, PyLong_FromLong(size)); PyTuple_SetItem(r, 2, PyCapsule_New(reinterpret_cast(batch), "arrow", &ArrowCapsule_Destructor)); @@ -104,30 +136,20 @@ static PyObject* write_to_buffer(PyObject* self, PyObject* args) { static PyObject* read_from_buffer(PyObject* self, PyObject* args) { PyObject* data_memoryview; PyObject* metadata_memoryview; - int64_t metadata_offset; + int64_t header_end_offset; if (!PyArg_ParseTuple( - args, "OOL", &data_memoryview, &metadata_memoryview, &metadata_offset)) { + args, "OOL", &data_memoryview, &metadata_memoryview, &header_end_offset)) { return NULL; } Py_buffer* metadata_buffer = PyMemoryView_GET_BUFFER(metadata_memoryview); + Py_buffer* data_buffer = PyMemoryView_GET_BUFFER(data_memoryview); auto ptr = reinterpret_cast(metadata_buffer->buf); auto schema_buffer = std::make_shared(ptr, metadata_buffer->len); - std::shared_ptr message; - ARROW_CHECK_OK(ipc::Message::Open(schema_buffer, &message)); - DCHECK_EQ(ipc::Message::SCHEMA, message->type()); - std::shared_ptr schema_msg = message->GetSchema(); - std::shared_ptr schema; - ARROW_CHECK_OK(schema_msg->GetSchema(&schema)); - Py_buffer* buffer = PyMemoryView_GET_BUFFER(data_memoryview); - auto source = std::make_shared( - reinterpret_cast(buffer->buf), buffer->len); - std::shared_ptr reader; - ARROW_CHECK_OK( - arrow::ipc::RecordBatchReader::Open(source.get(), metadata_offset, &reader)); auto batch = new std::shared_ptr(); - ARROW_CHECK_OK(reader->GetRecordBatch(schema, batch)); + ARROW_CHECK_OK(read_batch(schema_buffer, header_end_offset, + reinterpret_cast(data_buffer->buf), data_buffer->len, batch)); return PyCapsule_New(reinterpret_cast(batch), "arrow", &ArrowCapsule_Destructor); } @@ -139,12 +161,7 @@ static PyObject* deserialize_list(PyObject* self, PyObject* args) { if (!PyArg_ParseTuple(args, "O&|O", &PyObjectToArrow, &data, &base)) { return NULL; } PyObject* result; Status s = DeserializeList((*data)->column(0), 0, (*data)->num_rows(), base, &result); - if (!s.ok()) { - // If this condition is true, there was an error in the callback that - // needs to be passed through - if (!PyErr_Occurred()) { PyErr_SetString(NumbufError, s.ToString().c_str()); } - return NULL; - } + CHECK_SERIALIZATION_ERROR(s); return result; } @@ -174,6 +191,152 @@ static PyObject* register_callbacks(PyObject* self, PyObject* args) { return result; } +#ifdef HAS_PLASMA + +#include "plasma_extension.h" + +/** + * Release the object when its associated PyCapsule goes out of scope. + * + * The PyCapsule is used as the base object for the Python object that + * is stored with store_list and retrieved with retrieve_list. The base + * object ensures that the reference count of the capsule is non-zero + * during the lifetime of the Python object returned by retrieve_list. + * + * @param capsule The capsule that went out of scope. + * @return Void. + */ +static void BufferCapsule_Destructor(PyObject* capsule) { + object_id* id = reinterpret_cast(PyCapsule_GetPointer(capsule, "buffer")); + auto context = reinterpret_cast(PyCapsule_GetContext(capsule)); + /* We use the context of the connection capsule to indicate if the connection + * is still active (if the context is NULL) or if it is closed (if the context + * is (void*) 0x1). This is neccessary because the primary pointer of the + * capsule cannot be NULL. */ + if (PyCapsule_GetContext(context) == NULL) { + plasma_connection* conn; + CHECK(PyObjectToPlasmaConnection(context, &conn)); + plasma_release(conn, *id); + } + Py_XDECREF(context); + delete id; +} + +/** + * Store a PyList in the plasma store. + * + * This function converts the PyList into an arrow RecordBatch, constructs the + * metadata (schema) of the PyList, creates a new plasma object, puts the data + * into the plasma buffer and the schema into the plasma metadata. This raises + * + * + * @param args Contains the object ID the list is stored under, the + * connection to the plasma store and the PyList we want to store. + * @return None. + */ +static PyObject* store_list(PyObject* self, PyObject* args) { + object_id obj_id; + plasma_connection* conn; + PyObject* value; + if (!PyArg_ParseTuple(args, "O&O&O", PyStringToUniqueID, &obj_id, + PyObjectToPlasmaConnection, &conn, &value)) { + return NULL; + } + if (!PyList_Check(value)) { return NULL; } + + std::shared_ptr array; + int32_t recursion_depth = 0; + Status s = SerializeSequences(std::vector({value}), recursion_depth, &array); + CHECK_SERIALIZATION_ERROR(s); + + std::shared_ptr batch; + std::shared_ptr metadata; + int64_t size = make_schema_and_batch(array, &metadata, &batch); + + uint8_t* data; + /* The arrow schema is stored as the metadata of the plasma object and + * both the arrow data and the header end offset are + * stored in the plasma data buffer. The header end offset is stored in + * the first sizeof(int64_t) bytes of the data buffer. The RecordBatch + * data is stored after that. */ + int error_code = plasma_create(conn, obj_id, sizeof(size) + size, + (uint8_t*)metadata->data(), metadata->size(), &data); + if (error_code == PlasmaError_ObjectExists) { + PyErr_SetString(NumbufPlasmaObjectExistsError, + "An object with this ID already exists in the plasma " + "store."); + return NULL; + } + if (error_code == PlasmaError_OutOfMemory) { + PyErr_SetString(NumbufPlasmaOutOfMemoryError, + "The plasma store ran out of memory and could not create " + "this object."); + return NULL; + } + CHECK(error_code == PlasmaError_OK); + + auto target = std::make_shared(sizeof(size) + data, size); + int64_t body_end_offset; + int64_t header_end_offset; + ARROW_CHECK_OK(ipc::WriteRecordBatch(batch->columns(), batch->num_rows(), target.get(), + &body_end_offset, &header_end_offset)); + + /* Save the header end offset at the beginning of the plasma data buffer. */ + *((int64_t*)data) = header_end_offset; + /* Do the plasma_release corresponding to the call to plasma_create. */ + plasma_release(conn, obj_id); + /* Seal the object. */ + plasma_seal(conn, obj_id); + Py_RETURN_NONE; +} + +/** + * Retrieve a PyList from the plasma store. + * + * This reads the arrow schema from the plasma metadata, constructs + * Python objects from the plasma data according to the schema and + * returns the object. + * + * @param args Object ID of the PyList to be retrieved and connection to the + * plasma store. + * @return The PyList. + */ +static PyObject* retrieve_list(PyObject* self, PyObject* args) { + object_id obj_id; + PyObject* plasma_conn; + if (!PyArg_ParseTuple(args, "O&O", PyStringToUniqueID, &obj_id, &plasma_conn)) { + return NULL; + } + plasma_connection* conn; + if (!PyObjectToPlasmaConnection(plasma_conn, &conn)) { return NULL; } + object_id* buffer_obj_id = new object_id(obj_id); + /* This keeps a Plasma buffer in scope as long as an object that is backed by that + * buffer is in scope. This prevents memory in the object store from getting + * released while it is still being used to back a Python object. */ + PyObject* base = PyCapsule_New(buffer_obj_id, "buffer", BufferCapsule_Destructor); + PyCapsule_SetContext(base, plasma_conn); + Py_XINCREF(plasma_conn); + + int64_t size, metadata_size; + uint8_t *data, *metadata; + plasma_get(conn, obj_id, &size, &data, &metadata_size, &metadata); + + /* Remember: The metadata offset was written at the beginning of the plasma buffer. */ + int64_t header_end_offset = *((int64_t*)data); + auto schema_buffer = std::make_shared(metadata, metadata_size); + auto batch = std::shared_ptr(); + ARROW_CHECK_OK(read_batch(schema_buffer, header_end_offset, data + sizeof(size), + size - sizeof(size), &batch)); + + PyObject* result; + Status s = DeserializeList(batch->column(0), 0, batch->num_rows(), base, &result); + CHECK_SERIALIZATION_ERROR(s); + Py_XDECREF(base); + return result; +} + +#endif // HAS_PLASMA + static PyMethodDef NumbufMethods[] = { {"serialize_list", serialize_list, METH_VARARGS, "serialize a Python list"}, {"deserialize_list", deserialize_list, METH_VARARGS, "deserialize a Python list"}, @@ -182,6 +345,10 @@ static PyMethodDef NumbufMethods[] = { "read serialized data from buffer"}, {"register_callbacks", register_callbacks, METH_VARARGS, "set serialization and deserialization callbacks"}, +#ifdef HAS_PLASMA + {"store_list", store_list, METH_VARARGS, "store a Python list in plasma"}, + {"retrieve_list", retrieve_list, METH_VARARGS, "retrieve a Python list from plasma"}, +#endif {NULL, NULL, 0, NULL}}; // clang-format off @@ -224,6 +391,23 @@ MOD_INIT(libnumbuf) { Py_InitModule3("libnumbuf", NumbufMethods, "Python C Extension for Numbuf"); #endif +#if HAS_PLASMA + /* Create a custom exception for when an object ID is reused. */ + char numbuf_plasma_object_exists_error[] = "numbuf_plasma_object_exists.error"; + NumbufPlasmaObjectExistsError = + PyErr_NewException(numbuf_plasma_object_exists_error, NULL, NULL); + Py_INCREF(NumbufPlasmaObjectExistsError); + PyModule_AddObject( + m, "pnumbuf_lasma_object_exists_error", NumbufPlasmaObjectExistsError); + /* Create a custom exception for when the plasma store is out of memory. */ + char numbuf_plasma_out_of_memory_error[] = "numbuf_plasma_out_of_memory.error"; + NumbufPlasmaOutOfMemoryError = + PyErr_NewException(numbuf_plasma_out_of_memory_error, NULL, NULL); + Py_INCREF(NumbufPlasmaOutOfMemoryError); + PyModule_AddObject( + m, "numbuf_plasma_out_of_memory_error", NumbufPlasmaOutOfMemoryError); +#endif + char numbuf_error[] = "numbuf.error"; NumbufError = PyErr_NewException(numbuf_error, NULL, NULL); Py_INCREF(NumbufError); diff --git a/numbuf/python/test/runtest.py b/numbuf/python/test/runtest.py index 3ffd29247..82cacb999 100644 --- a/numbuf/python/test/runtest.py +++ b/numbuf/python/test/runtest.py @@ -111,7 +111,7 @@ class SerializationTests(unittest.TestCase): def testBuffer(self): for (i, obj) in enumerate(TEST_OBJECTS): schema, size, batch = numbuf.serialize_list([obj]) - size = size + 4096 # INITIAL_METADATA_SIZE in arrow + size = size + 4096 # INITIAL_METADATA_SIZE in arrow. buff = np.zeros(size, dtype="uint8") metadata_offset = numbuf.write_to_buffer(batch, memoryview(buff)) array = numbuf.read_from_buffer(memoryview(buff), memoryview(schema), metadata_offset) diff --git a/src/plasma/Makefile b/src/plasma/Makefile index 725df02a0..e01742a9d 100644 --- a/src/plasma/Makefile +++ b/src/plasma/Makefile @@ -1,6 +1,6 @@ CC = gcc # The -rdynamic is here so we always get function names in backtraces. -CFLAGS = -g -Wall -rdynamic -Wextra -Werror=implicit-function-declaration -Wno-sign-compare -Wno-unused-parameter -Wno-type-limits -Wno-missing-field-initializers --std=c99 -D_XOPEN_SOURCE=500 -D_POSIX_C_SOURCE=200809L -I. -Ithirdparty -I../common -I../common/thirdparty -I../common/build/flatcc-prefix/src/flatcc/include +CFLAGS = -g -Wall -rdynamic -Wextra -Werror=implicit-function-declaration -Wno-sign-compare -Wno-unused-parameter -Wno-type-limits -Wno-missing-field-initializers --std=c99 -D_XOPEN_SOURCE=500 -D_POSIX_C_SOURCE=200809L -I. -Ithirdparty -I../common -I../common/thirdparty -I../common/build/flatcc-prefix/src/flatcc/include -fPIC TEST_CFLAGS = -DPLASMA_TEST=1 -I. BUILD = build diff --git a/src/plasma/plasma_extension.c b/src/plasma/plasma_extension.c index 74800671f..bff7e6d9e 100644 --- a/src/plasma/plasma_extension.c +++ b/src/plasma/plasma_extension.c @@ -10,26 +10,7 @@ PyObject *PlasmaOutOfMemoryError; PyObject *PlasmaObjectExistsError; -static int PyObjectToPlasmaConnection(PyObject *object, - plasma_connection **conn) { - if (PyCapsule_IsValid(object, "plasma")) { - *conn = (plasma_connection *) PyCapsule_GetPointer(object, "plasma"); - return 1; - } else { - PyErr_SetString(PyExc_TypeError, "must be a 'plasma' capsule"); - return 0; - } -} - -static int PyStringToUniqueID(PyObject *object, object_id *object_id) { - if (PyBytes_Check(object)) { - memcpy(&object_id->id[0], PyBytes_AsString(object), UNIQUE_ID_SIZE); - return 1; - } else { - PyErr_SetString(PyExc_TypeError, "must be a 20 character string"); - return 0; - } -} +#include "plasma_extension.h" PyObject *PyPlasma_connect(PyObject *self, PyObject *args) { const char *store_socket_name; @@ -50,11 +31,18 @@ PyObject *PyPlasma_connect(PyObject *self, PyObject *args) { } PyObject *PyPlasma_disconnect(PyObject *self, PyObject *args) { + PyObject *conn_capsule; plasma_connection *conn; - if (!PyArg_ParseTuple(args, "O&", PyObjectToPlasmaConnection, &conn)) { + if (!PyArg_ParseTuple(args, "O", &conn_capsule)) { return NULL; } + CHECK(PyObjectToPlasmaConnection(conn_capsule, &conn)); plasma_disconnect(conn); + /* We use the context of the connection capsule to indicate if the connection + * is still active (if the context is NULL) or if it is closed (if the context + * is (void*) 0x1). This is neccessary because the primary pointer of the + * capsule cannot be NULL. */ + PyCapsule_SetContext(conn_capsule, (void *) 0x1); Py_RETURN_NONE; } diff --git a/src/plasma/plasma_extension.h b/src/plasma/plasma_extension.h new file mode 100644 index 000000000..a5adc9785 --- /dev/null +++ b/src/plasma/plasma_extension.h @@ -0,0 +1,25 @@ +#ifndef PLASMA_EXTENSION_H +#define PLASMA_EXTENSION_H + +static int PyObjectToPlasmaConnection(PyObject *object, + plasma_connection **conn) { + if (PyCapsule_IsValid(object, "plasma")) { + *conn = (plasma_connection *) PyCapsule_GetPointer(object, "plasma"); + return 1; + } else { + PyErr_SetString(PyExc_TypeError, "must be a 'plasma' capsule"); + return 0; + } +} + +static int PyStringToUniqueID(PyObject *object, object_id *object_id) { + if (PyBytes_Check(object)) { + memcpy(&object_id->id[0], PyBytes_AsString(object), UNIQUE_ID_SIZE); + return 1; + } else { + PyErr_SetString(PyExc_TypeError, "must be a 20 character string"); + return 0; + } +} + +#endif /* PLASMA_EXTENSION_H */