// Patch summary: route numpy arrays with an unrecognized dtype through a
// user-registered Python serialization callback instead of always failing.
//
// Files touched, as shown by the hunks below:
//   - python/src/pynumbuf/adapters/numpy.cc
//       * adds an include of "python.h" and extern "C" declarations for two
//         PyObject* globals: numbuf_serialize_callback and
//         numbuf_deserialize_callback.
//       * SerializeArray gains a third parameter, `subdicts` (a std::vector
//         reference).
//       * the switch `default:` case now checks numbuf_serialize_callback:
//         when it is null, the previous Status::NotImplemented("numpy data
//         type not recognized: ...") is returned; otherwise the callback is
//         called with the array packed by Py_BuildValue("(O)", array). A null
//         result releases `arglist` and returns python_error_to_status(); a
//         non-null result is passed to builder.AppendDict(PyDict_Size(result))
//         and pushed onto `subdicts`, then `arglist` is released.
//   - python/src/pynumbuf/adapters/numpy.h: SerializeArray declaration updated
//     to the three-parameter form.
//   - python/src/pynumbuf/adapters/python.cc: the call site in append() passes
//     `subdicts` through to SerializeArray.
//   - python/src/pynumbuf/adapters/python.h: declares python_error_to_status().
//   - python/test/runtest.py: removes testNumpyObject (which swallowed every
//     exception with a bare except/pass) and adds testObjectArray, which
//     registers myserialize/mydeserialize through libnumbuf.register_callbacks
//     and compares deserialize_list(serialize_list([x, y])) against [x, y]
//     for two dtype=object arrays.
//
// NOTE(review): this copy of the patch has its newlines collapsed and looks
// like it lost angle-bracket content (e.g. `#include @@`, `std::vector&
// subdicts`, `std::shared_ptr array`, `reinterpret_cast(data)`); it presumably
// cannot be applied as-is -- regenerate from the original commit before use.
// NOTE(review): both early returns inside `default:` happen before the
// `Py_XDECREF(contiguous)` that follows the switch. The assignment of
// `contiguous` is outside the visible hunks, so confirm whether a reference
// is leaked on those paths.
// NOTE(review): the result of builder.AppendDict(...) is discarded, while the
// neighbouring builder.AppendTensor(...) call is wrapped in RETURN_NOT_OK.
// Presumably AppendDict also returns a Status -- TODO confirm and propagate.
// NOTE(review): `result` is not checked to be a dict before PyDict_Size, and
// `arglist` is not checked for null before PyObject_CallObject -- verify the
// intended behaviour when the callback returns a non-dict or Py_BuildValue
// fails.
// NOTE(review): numbuf_deserialize_callback is declared in numpy.cc but is not
// referenced in any hunk shown here.
// NOTE(review): in testObjectArray, mydeserialize falls off the end (returns
// None) when obj["_pytype_"] is anything other than "numpy.array".
diff --git a/python/src/pynumbuf/adapters/numpy.cc b/python/src/pynumbuf/adapters/numpy.cc index fb74db28f..95e90249c 100644 --- a/python/src/pynumbuf/adapters/numpy.cc +++ b/python/src/pynumbuf/adapters/numpy.cc @@ -1,4 +1,5 @@ #include "numpy.h" +#include "python.h" #include @@ -6,6 +7,11 @@ using namespace arrow; +extern "C" { + extern PyObject *numbuf_serialize_callback; + extern PyObject *numbuf_deserialize_callback; +} + namespace numbuf { #define ARROW_TYPE_TO_NUMPY_CASE(TYPE) \ @@ -52,7 +58,8 @@ Status DeserializeArray(std::shared_ptr array, int32_t offset, PyObject** return Status::OK(); } -Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) { +Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder, + std::vector& subdicts) { size_t ndim = PyArray_NDIM(array); int dtype = PyArray_TYPE(array); std::vector dims(ndim); @@ -97,9 +104,21 @@ Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) { RETURN_NOT_OK(builder.AppendTensor(dims, reinterpret_cast(data))); break; default: - std::stringstream stream; - stream << "numpy data type not recognized: " << dtype; - return Status::NotImplemented(stream.str()); + if (!numbuf_serialize_callback) { + std::stringstream stream; + stream << "numpy data type not recognized: " << dtype; + return Status::NotImplemented(stream.str()); + } else { + PyObject* arglist = Py_BuildValue("(O)", array); + PyObject* result = PyObject_CallObject(numbuf_serialize_callback, arglist); + if (!result) { + Py_XDECREF(arglist); + return python_error_to_status(); + } + builder.AppendDict(PyDict_Size(result)); + subdicts.push_back(result); + Py_XDECREF(arglist); + } } Py_XDECREF(contiguous); return Status::OK(); diff --git a/python/src/pynumbuf/adapters/numpy.h b/python/src/pynumbuf/adapters/numpy.h index 8b8b3859a..a57e5c974 100644 --- a/python/src/pynumbuf/adapters/numpy.h +++ b/python/src/pynumbuf/adapters/numpy.h @@ -14,7 +14,7 @@ namespace numbuf { -arrow::Status 
SerializeArray(PyArrayObject* array, SequenceBuilder& builder); +arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder, std::vector& subdicts); arrow::Status DeserializeArray(std::shared_ptr array, int32_t offset, PyObject** out); } diff --git a/python/src/pynumbuf/adapters/python.cc b/python/src/pynumbuf/adapters/python.cc index 89177e425..3b9b6b2c2 100644 --- a/python/src/pynumbuf/adapters/python.cc +++ b/python/src/pynumbuf/adapters/python.cc @@ -109,7 +109,7 @@ Status append(PyObject* elem, SequenceBuilder& builder, } else if (PyArray_IsScalar(elem, Generic)) { RETURN_NOT_OK(AppendScalar(elem, builder)); } else if (PyArray_Check(elem)) { - RETURN_NOT_OK(SerializeArray((PyArrayObject*) elem, builder)); + RETURN_NOT_OK(SerializeArray((PyArrayObject*) elem, builder, subdicts)); } else if (elem == Py_None) { RETURN_NOT_OK(builder.AppendNone()); } else { diff --git a/python/src/pynumbuf/adapters/python.h b/python/src/pynumbuf/adapters/python.h index 1fa07f920..66efa3005 100644 --- a/python/src/pynumbuf/adapters/python.h +++ b/python/src/pynumbuf/adapters/python.h @@ -17,6 +17,8 @@ arrow::Status DeserializeList(std::shared_ptr array, int32_t start arrow::Status DeserializeTuple(std::shared_ptr array, int32_t start_idx, int32_t stop_idx, PyObject** out); arrow::Status DeserializeDict(std::shared_ptr array, int32_t start_idx, int32_t stop_idx, PyObject** out); +arrow::Status python_error_to_status(); + } #endif diff --git a/python/test/runtest.py b/python/test/runtest.py index f4e161647..c1b7d21e8 100644 --- a/python/test/runtest.py +++ b/python/test/runtest.py @@ -48,13 +48,6 @@ class SerializationTests(unittest.TestCase): for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32", "float64"]: self.numpyTest(t) - def testNumpyObject(self): - a = np.array([np.zeros((2,2))], dtype=object) - try: - x = self.roundTripTest([a]) - except: - pass - def testRay(self): for obj in TEST_OBJECTS: self.roundTripTest([obj]) @@ -90,6 +83,23 @@ 
class SerializationTests(unittest.TestCase): metadata, size, serialized = libnumbuf.serialize_list([bar]) self.assertEqual(libnumbuf.deserialize_list(serialized)[0].foo.x, 42) + def testObjectArray(self): + x = np.array([1, 2, "hello"], dtype=object) + y = np.array([[1, 2], [3, 4]], dtype=object) + + def myserialize(obj): + return {"_pytype_": "numpy.array", "data": obj.tolist()} + + def mydeserialize(obj): + if obj["_pytype_"] == "numpy.array": + return np.array(obj["data"], dtype=object) + + libnumbuf.register_callbacks(myserialize, mydeserialize) + + metadata, size, serialized = libnumbuf.serialize_list([x, y]) + + assert_equal(libnumbuf.deserialize_list(serialized), [x, y]) + def testBuffer(self): for (i, obj) in enumerate(TEST_OBJECTS): schema, size, batch = libnumbuf.serialize_list([obj])