mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Merge pull request #7 from pcmoritz/numpy-object
Serialize numpy arrays with custom objects
This commit is contained in:
commit
3fc9196deb
5 changed files with 44 additions and 13 deletions
|
@ -1,4 +1,5 @@
|
|||
#include "numpy.h"
|
||||
#include "python.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
|
@ -6,6 +7,11 @@
|
|||
|
||||
using namespace arrow;
|
||||
|
||||
extern "C" {
|
||||
extern PyObject *numbuf_serialize_callback;
|
||||
extern PyObject *numbuf_deserialize_callback;
|
||||
}
|
||||
|
||||
namespace numbuf {
|
||||
|
||||
#define ARROW_TYPE_TO_NUMPY_CASE(TYPE) \
|
||||
|
@ -52,7 +58,8 @@ Status DeserializeArray(std::shared_ptr<Array> array, int32_t offset, PyObject**
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) {
|
||||
Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder,
|
||||
std::vector<PyObject*>& subdicts) {
|
||||
size_t ndim = PyArray_NDIM(array);
|
||||
int dtype = PyArray_TYPE(array);
|
||||
std::vector<int64_t> dims(ndim);
|
||||
|
@ -97,9 +104,21 @@ Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder) {
|
|||
RETURN_NOT_OK(builder.AppendTensor(dims, reinterpret_cast<double*>(data)));
|
||||
break;
|
||||
default:
|
||||
std::stringstream stream;
|
||||
stream << "numpy data type not recognized: " << dtype;
|
||||
return Status::NotImplemented(stream.str());
|
||||
if (!numbuf_serialize_callback) {
|
||||
std::stringstream stream;
|
||||
stream << "numpy data type not recognized: " << dtype;
|
||||
return Status::NotImplemented(stream.str());
|
||||
} else {
|
||||
PyObject* arglist = Py_BuildValue("(O)", array);
|
||||
PyObject* result = PyObject_CallObject(numbuf_serialize_callback, arglist);
|
||||
if (!result) {
|
||||
Py_XDECREF(arglist);
|
||||
return python_error_to_status();
|
||||
}
|
||||
builder.AppendDict(PyDict_Size(result));
|
||||
subdicts.push_back(result);
|
||||
Py_XDECREF(arglist);
|
||||
}
|
||||
}
|
||||
Py_XDECREF(contiguous);
|
||||
return Status::OK();
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
namespace numbuf {
|
||||
|
||||
arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder);
|
||||
arrow::Status SerializeArray(PyArrayObject* array, SequenceBuilder& builder, std::vector<PyObject*>& subdicts);
|
||||
arrow::Status DeserializeArray(std::shared_ptr<arrow::Array> array, int32_t offset, PyObject** out);
|
||||
|
||||
}
|
||||
|
|
|
@ -109,7 +109,7 @@ Status append(PyObject* elem, SequenceBuilder& builder,
|
|||
} else if (PyArray_IsScalar(elem, Generic)) {
|
||||
RETURN_NOT_OK(AppendScalar(elem, builder));
|
||||
} else if (PyArray_Check(elem)) {
|
||||
RETURN_NOT_OK(SerializeArray((PyArrayObject*) elem, builder));
|
||||
RETURN_NOT_OK(SerializeArray((PyArrayObject*) elem, builder, subdicts));
|
||||
} else if (elem == Py_None) {
|
||||
RETURN_NOT_OK(builder.AppendNone());
|
||||
} else {
|
||||
|
|
|
@ -17,6 +17,8 @@ arrow::Status DeserializeList(std::shared_ptr<arrow::Array> array, int32_t start
|
|||
arrow::Status DeserializeTuple(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
|
||||
arrow::Status DeserializeDict(std::shared_ptr<arrow::Array> array, int32_t start_idx, int32_t stop_idx, PyObject** out);
|
||||
|
||||
arrow::Status python_error_to_status();
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@ -48,13 +48,6 @@ class SerializationTests(unittest.TestCase):
|
|||
for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32", "float64"]:
|
||||
self.numpyTest(t)
|
||||
|
||||
def testNumpyObject(self):
|
||||
a = np.array([np.zeros((2,2))], dtype=object)
|
||||
try:
|
||||
x = self.roundTripTest([a])
|
||||
except:
|
||||
pass
|
||||
|
||||
def testRay(self):
|
||||
for obj in TEST_OBJECTS:
|
||||
self.roundTripTest([obj])
|
||||
|
@ -90,6 +83,23 @@ class SerializationTests(unittest.TestCase):
|
|||
metadata, size, serialized = libnumbuf.serialize_list([bar])
|
||||
self.assertEqual(libnumbuf.deserialize_list(serialized)[0].foo.x, 42)
|
||||
|
||||
def testObjectArray(self):
|
||||
x = np.array([1, 2, "hello"], dtype=object)
|
||||
y = np.array([[1, 2], [3, 4]], dtype=object)
|
||||
|
||||
def myserialize(obj):
|
||||
return {"_pytype_": "numpy.array", "data": obj.tolist()}
|
||||
|
||||
def mydeserialize(obj):
|
||||
if obj["_pytype_"] == "numpy.array":
|
||||
return np.array(obj["data"], dtype=object)
|
||||
|
||||
libnumbuf.register_callbacks(myserialize, mydeserialize)
|
||||
|
||||
metadata, size, serialized = libnumbuf.serialize_list([x, y])
|
||||
|
||||
assert_equal(libnumbuf.deserialize_list(serialized), [x, y])
|
||||
|
||||
def testBuffer(self):
|
||||
for (i, obj) in enumerate(TEST_OBJECTS):
|
||||
schema, size, batch = libnumbuf.serialize_list([obj])
|
||||
|
|
Loading…
Add table
Reference in a new issue