Fix serialization of NumPy scalars, and add support for more NumPy dtypes and for empty arrays

This commit is contained in:
Philipp Moritz 2016-06-23 18:39:02 -07:00
parent 1951252689
commit 5ecc2ab67d
4 changed files with 89 additions and 112 deletions

View file

@ -69,6 +69,9 @@ class Worker(object):
result = serialization.Str(result)
elif isinstance(result, np.ndarray):
result = result.view(serialization.NDArray)
elif isinstance(result, np.generic):
return result
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
elif result == None:
return None # can't subclass None and don't need to because there is a global None
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)

View file

@ -95,6 +95,7 @@ message TaskStatus {
message Array {
repeated uint64 shape = 1;
sint64 dtype = 2;
bool is_scalar = 8;
repeated double double_data = 3;
repeated float float_data = 4;
repeated sint64 int_data = 5;

View file

@ -199,6 +199,15 @@ void set_dict_item_and_transfer_ownership(PyObject* dict, PyObject* key, PyObjec
// Serialization
// RAYLIB_SERIALIZE_NPY(TYPE, npy_type, proto_type) expands to one complete
// `case NPY_<TYPE>: { ... } break;` arm for use inside a switch statement.
//
// The arm treats PyArray_DATA(array) as a flat run of `size` elements of
// `npy_type` and appends each element, in order, to the message through
// data->add_<proto_type>_data(...).
//
// The expansion site must have these names in scope:
//   array - the array whose buffer is read
//   size  - number of elements to copy (npy_intp)
//   data  - pointer to the message that receives the elements
#define RAYLIB_SERIALIZE_NPY(TYPE, npy_type, proto_type)                \
  case NPY_##TYPE: {                                                    \
    const npy_type* cursor = (const npy_type*) PyArray_DATA(array);     \
    const npy_type* const stop = cursor + size;                         \
    while (cursor < stop) {                                             \
      data->add_##proto_type##_data(*cursor++);                         \
    }                                                                   \
  } break;
// serialize will serialize the python object val into the protocol buffer
// object obj, returns 0 if successful and something else if not
// FIXME(pcm): This currently only works for contiguous arrays
@ -263,9 +272,17 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
Ref* data = obj->mutable_objref_data();
data->set_data(objref);
objrefs.push_back(objref);
} else if (PyArray_Check(val)) {
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
} else if (PyArray_Check(val) || PyArray_CheckScalar(val)) { // Python int and float already handled
Array* data = obj->mutable_array_data();
PyArrayObject* array; // will be deallocated at the end
if (PyArray_IsScalar(val, Generic)) {
data->set_is_scalar(true);
PyArray_Descr* descr = PyArray_DescrFromScalar(val); // new reference
array = (PyArrayObject*) PyArray_FromScalar(val, descr); // steals the new reference
} else { // val is a numpy array
array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
}
npy_intp size = PyArray_SIZE(array);
for (int i = 0; i < PyArray_NDIM(array); ++i) {
data->add_shape(PyArray_DIM(array, i));
@ -273,48 +290,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
int typ = PyArray_TYPE(array);
data->set_dtype(typ);
switch (typ) {
case NPY_FLOAT: {
npy_float* buffer = (npy_float*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_float_data(buffer[i]);
}
}
break;
case NPY_DOUBLE: {
npy_double* buffer = (npy_double*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_double_data(buffer[i]);
}
}
break;
case NPY_INT8: {
npy_int8* buffer = (npy_int8*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_int_data(buffer[i]);
}
}
break;
case NPY_INT64: {
npy_int64* buffer = (npy_int64*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_int_data(buffer[i]);
}
}
break;
case NPY_UINT8: {
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_uint_data(buffer[i]);
}
}
break;
case NPY_UINT64: {
npy_uint64* buffer = (npy_uint64*) PyArray_DATA(array);
for (npy_intp i = 0; i < size; ++i) {
data->add_uint_data(buffer[i]);
}
}
break;
RAYLIB_SERIALIZE_NPY(FLOAT, npy_float, float)
RAYLIB_SERIALIZE_NPY(DOUBLE, npy_double, double)
RAYLIB_SERIALIZE_NPY(INT8, npy_int8, int)
RAYLIB_SERIALIZE_NPY(INT16, npy_int16, int)
RAYLIB_SERIALIZE_NPY(INT32, npy_int32, int)
RAYLIB_SERIALIZE_NPY(INT64, npy_int64, int)
RAYLIB_SERIALIZE_NPY(UINT8, npy_uint8, uint)
RAYLIB_SERIALIZE_NPY(UINT16, npy_uint16, uint)
RAYLIB_SERIALIZE_NPY(UINT32, npy_uint32, uint)
RAYLIB_SERIALIZE_NPY(UINT64, npy_uint64, uint)
case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs
PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array);
while (PyArray_ITER_NOTDONE(iter)) {
@ -345,6 +330,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
return 0;
}
// RAYLIB_DESERIALIZE_NPY(TYPE, npy_type, proto_type) expands to one complete
// `case NPY_<TYPE>: { ... } break;` arm for use inside a switch statement in
// a function that returns a pointer (the error path does `return NULL`).
//
// The arm copies the repeated <proto_type>_data field of the message `array`
// into the buffer of `pyarray`, converting each element to `npy_type`.
//
// The destination buffer is sized from the shape stored in the message, not
// from the number of data elements, so the two are checked against each other
// before copying. On a mismatch the arm releases `pyarray`, sets a RayError
// and returns NULL: too many elements would write past the end of the buffer,
// and too few would leave part of the result uninitialised.
//
// The expansion site must have these names in scope:
//   array   - the message holding the serialized elements
//   pyarray - the freshly allocated destination array (owned reference)
#define RAYLIB_DESERIALIZE_NPY(TYPE, npy_type, proto_type)                    \
  case NPY_##TYPE: {                                                          \
    npy_intp size = array.proto_type##_data_size();                           \
    if (size != PyArray_SIZE(pyarray)) {                                      \
      /* element count disagrees with the shape used for the allocation */   \
      Py_DECREF(pyarray);                                                     \
      PyErr_SetString(RayError,                                               \
          "deserialization: internal error (array data size does not match shape)"); \
      return NULL;                                                            \
    }                                                                         \
    npy_type* buffer = (npy_type*) PyArray_DATA(pyarray);                     \
    for (npy_intp i = 0; i < size; ++i) {                                     \
      buffer[i] = array.proto_type##_data(i);                                 \
    }                                                                         \
  } break;
// This method will push all of the object references contained in `obj` to the `objrefs` vector.
PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjRef> &objrefs) {
if (obj.has_int_data()) {
@ -399,72 +394,35 @@ PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjR
dims.push_back(array.shape(i));
}
PyArrayObject* pyarray = (PyArrayObject*) PyArray_SimpleNew(array.shape_size(), &dims[0], array.dtype());
if (array.double_data_size() > 0) { // TODO: handle empty array
npy_intp size = array.double_data_size();
npy_double* buffer = (npy_double*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.double_data(i);
}
} else if (array.float_data_size() > 0) {
npy_intp size = array.float_data_size();
npy_float* buffer = (npy_float*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.float_data(i);
}
} else if (array.int_data_size() > 0) {
npy_intp size = array.int_data_size();
switch (array.dtype()) {
case NPY_INT8: {
npy_int8* buffer = (npy_int8*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.int_data(i);
}
switch (array.dtype()) {
RAYLIB_DESERIALIZE_NPY(FLOAT, npy_float, float)
RAYLIB_DESERIALIZE_NPY(DOUBLE, npy_double, double)
RAYLIB_DESERIALIZE_NPY(INT8, npy_int8, int)
RAYLIB_DESERIALIZE_NPY(INT16, npy_int16, int)
RAYLIB_DESERIALIZE_NPY(INT32, npy_int32, int)
RAYLIB_DESERIALIZE_NPY(INT64, npy_int64, int)
RAYLIB_DESERIALIZE_NPY(UINT8, npy_uint8, uint)
RAYLIB_DESERIALIZE_NPY(UINT16, npy_uint16, uint)
RAYLIB_DESERIALIZE_NPY(UINT32, npy_uint32, uint)
RAYLIB_DESERIALIZE_NPY(UINT64, npy_uint64, uint)
case NPY_OBJECT: {
npy_intp size = array.objref_data_size();
PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i));
objrefs.push_back(array.objref_data(i));
}
break;
case NPY_INT64: {
npy_int64* buffer = (npy_int64*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.int_data(i);
}
}
break;
default:
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
return NULL;
}
} else if (array.uint_data_size() > 0) {
npy_intp size = array.uint_data_size();
switch (array.dtype()) {
case NPY_UINT8: {
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.uint_data(i);
}
}
break;
case NPY_UINT64: {
npy_uint64* buffer = (npy_uint64*) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = array.uint_data(i);
}
}
break;
default:
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
return NULL;
}
} else if (array.objref_data_size() > 0) {
npy_intp size = array.objref_data_size();
PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
for (npy_intp i = 0; i < size; ++i) {
buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i));
objrefs.push_back(array.objref_data(i));
}
} else {
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
return NULL;
}
break;
default:
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
return NULL;
}
if (array.is_scalar()) {
return PyArray_ScalarFromObject((PyObject*) pyarray);
} else {
return (PyObject*) pyarray;
}
return (PyObject*) pyarray;
} else {
PyErr_SetString(RayError, "deserialization: internal error (type not implemented)");
return NULL;

View file

@ -16,7 +16,10 @@ RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, "hello world", 42.0,
(1.0, "hi"), None, (None, None), ("hello", None),
True, False, (True, False),
{True: "hello", False: "world"},
{"hello" : "world", 1: 42, 1.0: 45}, {}]
{"hello" : "world", 1: 42, 1.0: 45}, {},
np.int8(3), np.int32(4), np.int64(5),
np.uint8(3), np.uint32(4), np.uint64(5),
np.float32(1.0), np.float64(1.0)]
class UserDefinedType(object):
def __init__(self):
@ -41,6 +44,16 @@ class SerializationTest(unittest.TestCase):
c = serialization.deserialize(worker.handle, b)
self.assertTrue((a == c).all())
a = np.array(0).astype(typ)
b, _ = serialization.serialize(worker.handle, a)
c = serialization.deserialize(worker.handle, b)
self.assertTrue((a == c).all())
a = np.empty((0,)).astype(typ)
b, _ = serialization.serialize(worker.handle, a)
c = serialization.deserialize(worker.handle, b)
self.assertTrue(a.dtype == c.dtype)
def testSerialize(self):
[w] = services.start_singlenode_cluster(return_drivers=True)
@ -54,8 +67,10 @@ class SerializationTest(unittest.TestCase):
self.numpyTypeTest(w, 'int8')
self.numpyTypeTest(w, 'uint8')
# self.numpyTypeTest('int16') # TODO(pcm): implement this
# self.numpyTypeTest('int32') # TODO(pcm): implement this
self.numpyTypeTest(w, 'int16')
self.numpyTypeTest(w, 'uint16')
self.numpyTypeTest(w, 'int32')
self.numpyTypeTest(w, 'uint32')
self.numpyTypeTest(w, 'float32')
self.numpyTypeTest(w, 'float64')
@ -311,7 +326,7 @@ class ReferenceCountingTest(unittest.TestCase):
objref_val = check_get_deallocated(val)
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], -1)
if not isinstance(val, bool) and val is not None:
if not isinstance(val, bool) and not isinstance(val, np.generic) and val is not None:
x, objref_val = check_get_not_deallocated(val)
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1)