Merge pull request #7 from pcmoritz/numpy-object

Serialize numpy arrays with custom objects
This commit is contained in:
Robert Nishihara 2016-08-30 20:04:47 -07:00 committed by GitHub
commit 3fc9196deb
5 changed files with 44 additions and 13 deletions

View file

@ -1,4 +1,5 @@
#include "numpy.h"
#include "python.h"
#include <sstream>
@ -6,6 +7,11 @@
using namespace arrow;
// Python callables used as user-supplied (de)serialization hooks. They are
// only declared here; the definitions live elsewhere.
extern "C" {
// Called by SerializeArray with the array as its single argument when the
// array's dtype has no built-in tensor encoding. When this pointer is null,
// SerializeArray reports the dtype as not recognized instead.
extern PyObject *numbuf_serialize_callback;
// NOTE(review): not referenced in the visible part of this file; presumably
// the inverse hook applied while deserializing -- confirm against its users.
extern PyObject *numbuf_deserialize_callback;
}
namespace numbuf {
#define ARROW_TYPE_TO_NUMPY_CASE(TYPE) \
@ -52,7 +58,8 @@ Status DeserializeArray(std::shared_ptr<Array> array, int32_t offset, PyObject**
return Status::OK();
}
Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) {
Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder,
std::vector<PyObject*>& subdicts) {
size_t ndim = PyArray_NDIM(array);
int dtype = PyArray_TYPE(array);
std::vector<int64_t> dims(ndim);
@ -97,9 +104,21 @@ Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) {
RETURN_NOT_OK(builder.AppendTensor(dims, reinterpret_cast<double*>(data)));
break;
default:
std::stringstream stream;
stream << "numpy data type not recognized: " << dtype;
return Status::NotImplemented(stream.str());
// Fallback for dtypes without a native tensor encoding: hand the array to the
// registered Python serialize callback and record the dict it returns.
// Every early return below first drops the reference held in `contiguous`,
// because the Py_XDECREF(contiguous) after the switch is skipped on return.
if (!numbuf_serialize_callback) {
  Py_XDECREF(contiguous);
  std::stringstream stream;
  stream << "numpy data type not recognized: " << dtype;
  return Status::NotImplemented(stream.str());
} else {
  // Py_BuildValue returns NULL (with a Python error set) on failure.
  PyObject* arglist = Py_BuildValue("(O)", array);
  if (!arglist) {
    Py_XDECREF(contiguous);
    return python_error_to_status();
  }
  PyObject* result = PyObject_CallObject(numbuf_serialize_callback, arglist);
  // The argument tuple is no longer needed once the call has returned.
  Py_DECREF(arglist);
  if (!result) {
    // The callback raised; surface the pending Python error.
    Py_XDECREF(contiguous);
    return python_error_to_status();
  }
  // PyDict_Size returns -1 and sets SystemError for non-dict objects, so
  // reject anything that is not a dict before asking for its size.
  if (!PyDict_Check(result)) {
    Py_DECREF(result);
    Py_XDECREF(contiguous);
    return Status::Invalid("numbuf serialize callback must return a dict");
  }
  // Propagate builder failures, as the other builder.Append* calls do.
  Status append_status = builder.AppendDict(PyDict_Size(result));
  if (!append_status.ok()) {
    Py_DECREF(result);
    Py_XDECREF(contiguous);
    return append_status;
  }
  // The reference to `result` is handed over to `subdicts`.
  // NOTE(review): presumably released by whoever consumes subdicts -- confirm.
  subdicts.push_back(result);
}
}
Py_XDECREF(contiguous);
return Status::OK();

View file

@ -14,7 +14,7 @@
namespace numbuf {
arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder);
arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder, std::vector<PyObject*>& subdicts);
arrow::Status DeserializeArray(std::shared_ptr<arrow::Array> array, int32_t offset, PyObject** out);
}

View file

@ -109,7 +109,7 @@ Status append(PyObject* elem, SequenceBuilder& builder,
} else if (PyArray_IsScalar(elem, Generic)) {
RETURN_NOT_OK(AppendScalar(elem, builder));
} else if (PyArray_Check(elem)) {
RETURN_NOT_OK(SerializeArray((PyArrayObject*) elem, builder));
RETURN_NOT_OK(SerializeArray((PyArrayObject*) elem, builder, subdicts));
} else if (elem == Py_None) {
RETURN_NOT_OK(builder.AppendNone());
} else {

View file

@ -17,6 +17,8 @@ arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
arrow::Status python_error_to_status();
}
#endif

View file

@ -48,13 +48,6 @@ class SerializationTests(unittest.TestCase):
for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32", "float64"]:
self.numpyTest(t)
def testNumpyObject(self):
a = np.array([np.zeros((2,2))], dtype=object)
try:
x = self.roundTripTest([a])
except:
pass
def testRay(self):
for obj in TEST_OBJECTS:
self.roundTripTest([obj])
@ -90,6 +83,23 @@ class SerializationTests(unittest.TestCase):
metadata, size, serialized = libnumbuf.serialize_list([bar])
self.assertEqual(libnumbuf.deserialize_list(serialized)[0].foo.x, 42)
def testObjectArray(self):
  # Round-trip two object-dtype arrays (one mixed 1-D, one 2-D) through
  # libnumbuf using user-registered serialize/deserialize callbacks.
  mixed = np.array([1, 2, "hello"], dtype=object)
  grid = np.array([[1, 2], [3, 4]], dtype=object)

  def myserialize(obj):
    # Encode the array as a tagged dict holding its nested-list form.
    payload = {"_pytype_": "numpy.array", "data": obj.tolist()}
    return payload

  def mydeserialize(obj):
    # Only dicts tagged as numpy arrays are rebuilt; anything else yields None.
    if obj["_pytype_"] != "numpy.array":
      return None
    return np.array(obj["data"], dtype=object)

  libnumbuf.register_callbacks(myserialize, mydeserialize)
  expected = [mixed, grid]
  serialized = libnumbuf.serialize_list(expected)[2]
  restored = libnumbuf.deserialize_list(serialized)
  assert_equal(restored, expected)
def testBuffer(self):
for (i, obj) in enumerate(TEST_OBJECTS):
schema, size, batch = libnumbuf.serialize_list([obj])