Merge pull request #5 from pcmoritz/callback

add custom callbacks for serialization
This commit is contained in:
Robert Nishihara 2016-08-30 15:55:34 -07:00 committed by GitHub
commit 8c1f0f289e
3 changed files with 100 additions and 4 deletions

View file

@ -6,6 +6,9 @@
using namespace arrow;
extern PyObject* numbuf_serialize_callback;
extern PyObject* numbuf_deserialize_callback;
namespace numbuf {
PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
@ -49,6 +52,17 @@ PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
return NULL;
}
Status python_error_to_status() {
PyObject *type, *value, *traceback;
PyErr_Fetch(&type, &value, &traceback);
char *err_message = PyString_AsString(value);
std::stringstream ss;
if (err_message) {
ss << "Python error in callback: " << err_message;
}
return Status::NotImplemented(ss.str());
}
Status append(PyObject* elem, SequenceBuilder& builder,
std::vector<PyObject*>& sublists,
std::vector<PyObject*>& subtuples,
@ -99,10 +113,22 @@ Status append(PyObject* elem, SequenceBuilder& builder,
} else if (elem == Py_None) {
RETURN_NOT_OK(builder.AppendNone());
} else {
std::stringstream ss;
ss << "data type of " << PyString_AS_STRING(PyObject_Repr(elem))
<< " not recognized";
return Status::NotImplemented(ss.str());
if (!numbuf_serialize_callback) {
std::stringstream ss;
ss << "data type of " << PyString_AS_STRING(PyObject_Repr(elem))
<< " not recognized and custom serialization handler not registered";
return Status::NotImplemented(ss.str());
} else {
PyObject* arglist = Py_BuildValue("(O)", elem);
PyObject* result = PyObject_CallObject(numbuf_serialize_callback, arglist);
if (!result) {
Py_XDECREF(arglist);
return python_error_to_status();
}
builder.AppendDict(PyDict_Size(result));
subdicts.push_back(result);
Py_XDECREF(arglist);
}
}
return Status::OK();
}
@ -213,6 +239,16 @@ Status DeserializeDict(std::shared_ptr<Array> array, int32_t start_idx, int32_t
}
Py_XDECREF(keys); // PyList_GetItem(keys, ...) incremented the reference count
Py_XDECREF(vals); // PyList_GetItem(vals, ...) incremented the reference count
static PyObject* py_type = PyString_FromString("_pytype_");
if (PyDict_Contains(result, py_type) && numbuf_deserialize_callback) {
PyObject* arglist = Py_BuildValue("(O)", result);
result = PyObject_CallObject(numbuf_deserialize_callback, arglist);
if (!result) {
Py_XDECREF(arglist);
return python_error_to_status();
}
Py_XDECREF(arglist);
}
*out = result;
return Status::OK();
}

View file

@ -26,6 +26,9 @@ extern "C" {
static PyObject *NumbufError;
PyObject *numbuf_serialize_callback = NULL;
PyObject *numbuf_deserialize_callback = NULL;
int PyObjectToArrow(PyObject* object, std::shared_ptr<RowBatch> **result) {
if (PyCapsule_IsValid(object, "arrow")) {
*result = reinterpret_cast<std::shared_ptr<RowBatch>*>(PyCapsule_GetPointer(object, "arrow"));
@ -131,11 +134,37 @@ static PyObject* deserialize_list(PyObject* self, PyObject* args) {
return result;
}
static PyObject* register_callbacks(PyObject* self, PyObject* args) {
PyObject* result = NULL;
PyObject* serialize_callback;
PyObject* deserialize_callback;
if (PyArg_ParseTuple(args, "OO:register_callbacks", &serialize_callback, &deserialize_callback)) {
if (!PyCallable_Check(serialize_callback)) {
PyErr_SetString(PyExc_TypeError, "serialize_callback must be callable");
return NULL;
}
if (!PyCallable_Check(deserialize_callback)) {
PyErr_SetString(PyExc_TypeError, "deserialize_callback must be callable");
return NULL;
}
Py_XINCREF(serialize_callback); // Add a reference to new serialization callback
Py_XINCREF(deserialize_callback); // Add a reference to new deserialization callback
Py_XDECREF(numbuf_serialize_callback); // Dispose of old serialization callback
Py_XDECREF(numbuf_deserialize_callback); // Dispose of old deserialization callback
numbuf_serialize_callback = serialize_callback;
numbuf_deserialize_callback = deserialize_callback;
Py_INCREF(Py_None);
result = Py_None;
}
return result;
}
static PyMethodDef NumbufMethods[] = {
{ "serialize_list", serialize_list, METH_VARARGS, "serialize a Python list" },
{ "deserialize_list", deserialize_list, METH_VARARGS, "deserialize a Python list" },
{ "write_to_buffer", write_to_buffer, METH_VARARGS, "write serialized data to buffer"},
{ "read_from_buffer", read_from_buffer, METH_VARARGS, "read serialized data from buffer"},
{ "register_callbacks", register_callbacks, METH_VARARGS, "set serialization and deserialization callbacks"},
{ NULL, NULL, 0, NULL }
};

View file

@ -59,6 +59,37 @@ class SerializationTests(unittest.TestCase):
for obj in TEST_OBJECTS:
self.roundTripTest([obj])
def testCallback(self):
class Foo(object):
def __init__(self):
self.x = 1
class Bar(object):
def __init__(self):
self.foo = Foo()
def serialize(obj):
return dict(obj.__dict__, **{"_pytype_": type(obj).__name__})
def deserialize(obj):
if obj["_pytype_"] == "Foo":
result = Foo()
elif obj["_pytype_"] == "Bar":
result = Bar()
obj.pop("_pytype_", None)
result.__dict__ = obj
return result
bar = Bar()
bar.foo.x = 42
libnumbuf.register_callbacks(serialize, deserialize)
metadata, size, serialized = libnumbuf.serialize_list([bar])
self.assertEqual(libnumbuf.deserialize_list(serialized)[0].foo.x, 42)
def testBuffer(self):
for (i, obj) in enumerate(TEST_OBJECTS):
schema, size, batch = libnumbuf.serialize_list([obj])