mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Merge pull request #5 from pcmoritz/callback
add custom callbacks for serialization
This commit is contained in:
commit
8c1f0f289e
3 changed files with 100 additions and 4 deletions
|
@ -6,6 +6,9 @@
|
|||
|
||||
using namespace arrow;
|
||||
|
||||
extern PyObject* numbuf_serialize_callback;
|
||||
extern PyObject* numbuf_deserialize_callback;
|
||||
|
||||
namespace numbuf {
|
||||
|
||||
PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
|
||||
|
@ -49,6 +52,17 @@ PyObject* get_value(ArrayPtr arr, int32_t index, int32_t type) {
|
|||
return NULL;
|
||||
}
|
||||
|
||||
Status python_error_to_status() {
|
||||
PyObject *type, *value, *traceback;
|
||||
PyErr_Fetch(&type, &value, &traceback);
|
||||
char *err_message = PyString_AsString(value);
|
||||
std::stringstream ss;
|
||||
if (err_message) {
|
||||
ss << "Python error in callback: " << err_message;
|
||||
}
|
||||
return Status::NotImplemented(ss.str());
|
||||
}
|
||||
|
||||
Status append(PyObject* elem, SequenceBuilder& builder,
|
||||
std::vector<PyObject*>& sublists,
|
||||
std::vector<PyObject*>& subtuples,
|
||||
|
@ -99,10 +113,22 @@ Status append(PyObject* elem, SequenceBuilder& builder,
|
|||
} else if (elem == Py_None) {
|
||||
RETURN_NOT_OK(builder.AppendNone());
|
||||
} else {
|
||||
std::stringstream ss;
|
||||
ss << "data type of " << PyString_AS_STRING(PyObject_Repr(elem))
|
||||
<< " not recognized";
|
||||
return Status::NotImplemented(ss.str());
|
||||
if (!numbuf_serialize_callback) {
|
||||
std::stringstream ss;
|
||||
ss << "data type of " << PyString_AS_STRING(PyObject_Repr(elem))
|
||||
<< " not recognized and custom serialization handler not registered";
|
||||
return Status::NotImplemented(ss.str());
|
||||
} else {
|
||||
PyObject* arglist = Py_BuildValue("(O)", elem);
|
||||
PyObject* result = PyObject_CallObject(numbuf_serialize_callback, arglist);
|
||||
if (!result) {
|
||||
Py_XDECREF(arglist);
|
||||
return python_error_to_status();
|
||||
}
|
||||
builder.AppendDict(PyDict_Size(result));
|
||||
subdicts.push_back(result);
|
||||
Py_XDECREF(arglist);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -213,6 +239,16 @@ Status DeserializeDict(std::shared_ptr<Array> array, int32_t start_idx, int32_t
|
|||
}
|
||||
Py_XDECREF(keys); // PyList_GetItem(keys, ...) incremented the reference count
|
||||
Py_XDECREF(vals); // PyList_GetItem(vals, ...) incremented the reference count
|
||||
static PyObject* py_type = PyString_FromString("_pytype_");
|
||||
if (PyDict_Contains(result, py_type) && numbuf_deserialize_callback) {
|
||||
PyObject* arglist = Py_BuildValue("(O)", result);
|
||||
result = PyObject_CallObject(numbuf_deserialize_callback, arglist);
|
||||
if (!result) {
|
||||
Py_XDECREF(arglist);
|
||||
return python_error_to_status();
|
||||
}
|
||||
Py_XDECREF(arglist);
|
||||
}
|
||||
*out = result;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -26,6 +26,9 @@ extern "C" {
|
|||
|
||||
static PyObject *NumbufError;
|
||||
|
||||
PyObject *numbuf_serialize_callback = NULL;
|
||||
PyObject *numbuf_deserialize_callback = NULL;
|
||||
|
||||
int PyObjectToArrow(PyObject* object, std::shared_ptr<RowBatch> **result) {
|
||||
if (PyCapsule_IsValid(object, "arrow")) {
|
||||
*result = reinterpret_cast<std::shared_ptr<RowBatch>*>(PyCapsule_GetPointer(object, "arrow"));
|
||||
|
@ -131,11 +134,37 @@ static PyObject* deserialize_list(PyObject* self, PyObject* args) {
|
|||
return result;
|
||||
}
|
||||
|
||||
static PyObject* register_callbacks(PyObject* self, PyObject* args) {
|
||||
PyObject* result = NULL;
|
||||
PyObject* serialize_callback;
|
||||
PyObject* deserialize_callback;
|
||||
if (PyArg_ParseTuple(args, "OO:register_callbacks", &serialize_callback, &deserialize_callback)) {
|
||||
if (!PyCallable_Check(serialize_callback)) {
|
||||
PyErr_SetString(PyExc_TypeError, "serialize_callback must be callable");
|
||||
return NULL;
|
||||
}
|
||||
if (!PyCallable_Check(deserialize_callback)) {
|
||||
PyErr_SetString(PyExc_TypeError, "deserialize_callback must be callable");
|
||||
return NULL;
|
||||
}
|
||||
Py_XINCREF(serialize_callback); // Add a reference to new serialization callback
|
||||
Py_XINCREF(deserialize_callback); // Add a reference to new deserialization callback
|
||||
Py_XDECREF(numbuf_serialize_callback); // Dispose of old serialization callback
|
||||
Py_XDECREF(numbuf_deserialize_callback); // Dispose of old deserialization callback
|
||||
numbuf_serialize_callback = serialize_callback;
|
||||
numbuf_deserialize_callback = deserialize_callback;
|
||||
Py_INCREF(Py_None);
|
||||
result = Py_None;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyMethodDef NumbufMethods[] = {
|
||||
{ "serialize_list", serialize_list, METH_VARARGS, "serialize a Python list" },
|
||||
{ "deserialize_list", deserialize_list, METH_VARARGS, "deserialize a Python list" },
|
||||
{ "write_to_buffer", write_to_buffer, METH_VARARGS, "write serialized data to buffer"},
|
||||
{ "read_from_buffer", read_from_buffer, METH_VARARGS, "read serialized data from buffer"},
|
||||
{ "register_callbacks", register_callbacks, METH_VARARGS, "set serialization and deserialization callbacks"},
|
||||
{ NULL, NULL, 0, NULL }
|
||||
};
|
||||
|
||||
|
|
|
@ -59,6 +59,37 @@ class SerializationTests(unittest.TestCase):
|
|||
for obj in TEST_OBJECTS:
|
||||
self.roundTripTest([obj])
|
||||
|
||||
def testCallback(self):
|
||||
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
class Bar(object):
|
||||
def __init__(self):
|
||||
self.foo = Foo()
|
||||
|
||||
def serialize(obj):
|
||||
return dict(obj.__dict__, **{"_pytype_": type(obj).__name__})
|
||||
|
||||
def deserialize(obj):
|
||||
if obj["_pytype_"] == "Foo":
|
||||
result = Foo()
|
||||
elif obj["_pytype_"] == "Bar":
|
||||
result = Bar()
|
||||
|
||||
obj.pop("_pytype_", None)
|
||||
result.__dict__ = obj
|
||||
return result
|
||||
|
||||
bar = Bar()
|
||||
bar.foo.x = 42
|
||||
|
||||
libnumbuf.register_callbacks(serialize, deserialize)
|
||||
|
||||
metadata, size, serialized = libnumbuf.serialize_list([bar])
|
||||
self.assertEqual(libnumbuf.deserialize_list(serialized)[0].foo.x, 42)
|
||||
|
||||
def testBuffer(self):
|
||||
for (i, obj) in enumerate(TEST_OBJECTS):
|
||||
schema, size, batch = libnumbuf.serialize_list([obj])
|
||||
|
|
Loading…
Add table
Reference in a new issue