Fix serialization of numpy scalars and implement more numpy types as well as empty arrays
parent 1951252689
commit 5ecc2ab67d

4 changed files with 89 additions and 112 deletions
@@ -69,6 +69,9 @@ class Worker(object):
      result = serialization.Str(result)
    elif isinstance(result, np.ndarray):
      result = result.view(serialization.NDArray)
    elif isinstance(result, np.generic):
      return result
      # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
    elif result == None:
      return None # can't subclass None and don't need to because there is a global None
      # TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
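For context on the new np.generic branch above: numpy scalars are not ndarrays, so the existing result.view(serialization.NDArray) path cannot handle them, and converting a scalar to a 0-d array (and back) is what lets the rest of the pipeline reuse the array machinery. A minimal numpy-only sketch of that distinction (no Ray required):

    import numpy as np

    x = np.float64(3.0)               # a numpy scalar
    a = np.array([3.0])               # a one-element array

    print(isinstance(x, np.ndarray))  # False: scalars are not arrays
    print(isinstance(x, np.generic))  # True: np.generic is the scalar base class
    print(isinstance(a, np.ndarray))  # True

    # A 0-d array carries the scalar's value and dtype, which is how the
    # serializer below can box a scalar as an Array message plus an is_scalar flag.
    zero_d = np.array(x)
    print(zero_d.shape, zero_d.dtype) # () float64
    print(zero_d[()])                 # 3.0, indexed back out as a numpy scalar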
@@ -95,6 +95,7 @@ message TaskStatus {

message Array {
  repeated uint64 shape = 1;
  sint64 dtype = 2;
  bool is_scalar = 8;
  repeated double double_data = 3;
  repeated float float_data = 4;
  repeated sint64 int_data = 5;
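To make the encoding concrete, here is a rough Python-only sketch of how a numpy array or scalar maps onto these fields. to_array_fields is a hypothetical helper; the dict keys mirror the proto field names above, plus uint_data, which the C++ code below uses (add_uint_data) even though its declaration falls outside the lines shown in this hunk.

    import numpy as np

    def to_array_fields(val):
        # Encode a numpy array or scalar into a dict shaped like the Array message.
        is_scalar = isinstance(val, np.generic)
        arr = np.ascontiguousarray(val)       # a scalar becomes a 0-d array
        fields = {"shape": list(arr.shape),   # [] for scalars, [0] for an empty 1-d array
                  "dtype": arr.dtype.num,     # numpy type number, same values as the NPY_* enum
                  "is_scalar": is_scalar}
        flat = arr.ravel().tolist()
        if arr.dtype.kind == "f":
            fields["double_data" if arr.dtype == np.float64 else "float_data"] = flat
        elif arr.dtype.kind == "i":
            fields["int_data"] = flat
        elif arr.dtype.kind == "u":
            fields["uint_data"] = flat
        return fields

    print(to_array_fields(np.int8(3)))      # shape [], is_scalar True, int_data [3]
    print(to_array_fields(np.empty((0,))))  # shape [0], is_scalar False, double_data []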
src/raylib.cc (174 lines changed)
@@ -199,6 +199,15 @@ void set_dict_item_and_transfer_ownership(PyObject* dict, PyObject* key, PyObjec

// Serialization

#define RAYLIB_SERIALIZE_NPY(TYPE, npy_type, proto_type) \
  case NPY_##TYPE: { \
    npy_type* buffer = (npy_type*) PyArray_DATA(array); \
    for (npy_intp i = 0; i < size; ++i) { \
      data->add_##proto_type##_data(buffer[i]); \
    } \
  } \
  break;

// serialize will serialize the python object val into the protocol buffer
// object obj, returns 0 if successful and something else if not
// FIXME(pcm): This currently only works for contiguous arrays
@@ -263,9 +272,17 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
    Ref* data = obj->mutable_objref_data();
    data->set_data(objref);
    objrefs.push_back(objref);
  } else if (PyArray_Check(val)) {
    PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
  } else if (PyArray_Check(val) || PyArray_CheckScalar(val)) { // Python int and float already handled
    Array* data = obj->mutable_array_data();
    PyArrayObject* array; // will be deallocated at the end
    if (PyArray_IsScalar(val, Generic)) {
      data->set_is_scalar(true);
      PyArray_Descr* descr = PyArray_DescrFromScalar(val); // new reference
      array = (PyArrayObject*) PyArray_FromScalar(val, descr); // steals the new reference
    } else { // val is a numpy array
      array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
    }

    npy_intp size = PyArray_SIZE(array);
    for (int i = 0; i < PyArray_NDIM(array); ++i) {
      data->add_shape(PyArray_DIM(array, i));
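A note on PyArray_GETCONTIGUOUS in the non-scalar branch (and the FIXME about contiguous arrays earlier in this file): it yields a C-contiguous array, copying only when the input is not already contiguous. Roughly the Python-level equivalent, for illustration:

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    t = a.T                                    # a transposed view; not C-contiguous
    print(t.flags["C_CONTIGUOUS"])             # False
    c = np.ascontiguousarray(t)                # copies into C order when needed
    print(c.flags["C_CONTIGUOUS"], np.array_equal(c, t))  # True True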
@@ -273,48 +290,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
    int typ = PyArray_TYPE(array);
    data->set_dtype(typ);
    switch (typ) {
      case NPY_FLOAT: {
          npy_float* buffer = (npy_float*) PyArray_DATA(array);
          for (npy_intp i = 0; i < size; ++i) {
            data->add_float_data(buffer[i]);
          }
        }
        break;
      case NPY_DOUBLE: {
          npy_double* buffer = (npy_double*) PyArray_DATA(array);
          for (npy_intp i = 0; i < size; ++i) {
            data->add_double_data(buffer[i]);
          }
        }
        break;
      case NPY_INT8: {
          npy_int8* buffer = (npy_int8*) PyArray_DATA(array);
          for (npy_intp i = 0; i < size; ++i) {
            data->add_int_data(buffer[i]);
          }
        }
        break;
      case NPY_INT64: {
          npy_int64* buffer = (npy_int64*) PyArray_DATA(array);
          for (npy_intp i = 0; i < size; ++i) {
            data->add_int_data(buffer[i]);
          }
        }
        break;
      case NPY_UINT8: {
          npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array);
          for (npy_intp i = 0; i < size; ++i) {
            data->add_uint_data(buffer[i]);
          }
        }
        break;
      case NPY_UINT64: {
          npy_uint64* buffer = (npy_uint64*) PyArray_DATA(array);
          for (npy_intp i = 0; i < size; ++i) {
            data->add_uint_data(buffer[i]);
          }
        }
        break;
      RAYLIB_SERIALIZE_NPY(FLOAT, npy_float, float)
      RAYLIB_SERIALIZE_NPY(DOUBLE, npy_double, double)
      RAYLIB_SERIALIZE_NPY(INT8, npy_int8, int)
      RAYLIB_SERIALIZE_NPY(INT16, npy_int16, int)
      RAYLIB_SERIALIZE_NPY(INT32, npy_int32, int)
      RAYLIB_SERIALIZE_NPY(INT64, npy_int64, int)
      RAYLIB_SERIALIZE_NPY(UINT8, npy_uint8, uint)
      RAYLIB_SERIALIZE_NPY(UINT16, npy_uint16, uint)
      RAYLIB_SERIALIZE_NPY(UINT32, npy_uint32, uint)
      RAYLIB_SERIALIZE_NPY(UINT64, npy_uint64, uint)
      case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs
        PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array);
        while (PyArray_ITER_NOTDONE(iter)) {
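As an aside on the NPY_OBJECT case: object-dtype arrays store pointers to arbitrary Python objects, which is why that branch walks the array element by element with an iterator instead of copying a flat numeric buffer; per the FIXME, only ObjRef elements are actually supported. A small illustration in plain numpy (the elements here are stand-in objects, not real ObjRefs):

    import numpy as np

    refs = np.array([object(), object()], dtype=object)  # stand-ins for ObjRef instances
    print(refs.dtype)        # object
    print(refs.dtype.num)    # numpy's type number for this dtype, i.e. the NPY_OBJECT case above
    for elem in refs.flat:   # element-wise traversal, like PyArray_IterNew on the C side
        print(type(elem))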
@@ -345,6 +330,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
  return 0;
}

#define RAYLIB_DESERIALIZE_NPY(TYPE, npy_type, proto_type) \
  case NPY_##TYPE: { \
    npy_intp size = array.proto_type##_data_size(); \
    npy_type* buffer = (npy_type*) PyArray_DATA(pyarray); \
    for (npy_intp i = 0; i < size; ++i) { \
      buffer[i] = array.proto_type##_data(i); \
    } \
  } \
  break;

// This method will push all of the object references contained in `obj` to the `objrefs` vector.
PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjRef> &objrefs) {
  if (obj.has_int_data()) {
@@ -399,72 +394,35 @@ PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjR
      dims.push_back(array.shape(i));
    }
    PyArrayObject* pyarray = (PyArrayObject*) PyArray_SimpleNew(array.shape_size(), &dims[0], array.dtype());
    if (array.double_data_size() > 0) { // TODO: handle empty array
      npy_intp size = array.double_data_size();
      npy_double* buffer = (npy_double*) PyArray_DATA(pyarray);
      for (npy_intp i = 0; i < size; ++i) {
        buffer[i] = array.double_data(i);
      }
    } else if (array.float_data_size() > 0) {
      npy_intp size = array.float_data_size();
      npy_float* buffer = (npy_float*) PyArray_DATA(pyarray);
      for (npy_intp i = 0; i < size; ++i) {
        buffer[i] = array.float_data(i);
      }
    } else if (array.int_data_size() > 0) {
      npy_intp size = array.int_data_size();
      switch (array.dtype()) {
        case NPY_INT8: {
            npy_int8* buffer = (npy_int8*) PyArray_DATA(pyarray);
            for (npy_intp i = 0; i < size; ++i) {
              buffer[i] = array.int_data(i);
            }
    switch (array.dtype()) {
      RAYLIB_DESERIALIZE_NPY(FLOAT, npy_float, float)
      RAYLIB_DESERIALIZE_NPY(DOUBLE, npy_double, double)
      RAYLIB_DESERIALIZE_NPY(INT8, npy_int8, int)
      RAYLIB_DESERIALIZE_NPY(INT16, npy_int16, int)
      RAYLIB_DESERIALIZE_NPY(INT32, npy_int32, int)
      RAYLIB_DESERIALIZE_NPY(INT64, npy_int64, int)
      RAYLIB_DESERIALIZE_NPY(UINT8, npy_uint8, uint)
      RAYLIB_DESERIALIZE_NPY(UINT16, npy_uint16, uint)
      RAYLIB_DESERIALIZE_NPY(UINT32, npy_uint32, uint)
      RAYLIB_DESERIALIZE_NPY(UINT64, npy_uint64, uint)
      case NPY_OBJECT: {
        npy_intp size = array.objref_data_size();
        PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
        for (npy_intp i = 0; i < size; ++i) {
          buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i));
          objrefs.push_back(array.objref_data(i));
        }
        break;
          case NPY_INT64: {
            npy_int64* buffer = (npy_int64*) PyArray_DATA(pyarray);
            for (npy_intp i = 0; i < size; ++i) {
              buffer[i] = array.int_data(i);
            }
          }
          break;
        default:
          PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
          return NULL;
      }
    } else if (array.uint_data_size() > 0) {
      npy_intp size = array.uint_data_size();
      switch (array.dtype()) {
        case NPY_UINT8: {
            npy_uint8* buffer = (npy_uint8*) PyArray_DATA(pyarray);
            for (npy_intp i = 0; i < size; ++i) {
              buffer[i] = array.uint_data(i);
            }
          }
          break;
        case NPY_UINT64: {
            npy_uint64* buffer = (npy_uint64*) PyArray_DATA(pyarray);
            for (npy_intp i = 0; i < size; ++i) {
              buffer[i] = array.uint_data(i);
            }
          }
          break;
        default:
          PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
          return NULL;
      }
    } else if (array.objref_data_size() > 0) {
      npy_intp size = array.objref_data_size();
      PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
      for (npy_intp i = 0; i < size; ++i) {
        buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i));
        objrefs.push_back(array.objref_data(i));
      }
    } else {
      PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
      return NULL;
    }
      break;
      default:
        PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
        return NULL;
    }
    if (array.is_scalar()) {
      return PyArray_ScalarFromObject((PyObject*) pyarray);
    } else {
      return (PyObject*) pyarray;
    }
    return (PyObject*) pyarray;
  } else {
    PyErr_SetString(RayError, "deserialization: internal error (type not implemented)");
    return NULL;
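For the deserialization direction, including the two cases this commit adds (empty arrays and numpy scalars), here is a short numpy-only sketch; from_fields and its argument layout are illustrative stand-ins, not Ray APIs:

    import numpy as np

    def from_fields(shape, dtype, flat, is_scalar=False):
        # Rebuild an array from shape, dtype and flattened data; an empty `flat`
        # with a shape like [0] yields an empty array instead of an error.
        arr = np.asarray(flat, dtype=dtype).reshape(shape)
        # Mirrors the is_scalar branch above: unwrap a 0-d array back into a numpy scalar.
        return arr[()] if is_scalar else arr

    print(from_fields([0], np.float64, []))             # array([], dtype=float64)
    print(from_fields([], np.int64, [5], True))         # 5, as a numpy int64 scalar
    print(from_fields([2, 2], np.uint8, [1, 2, 3, 4]))  # a 2x2 uint8 array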
@@ -16,7 +16,10 @@ RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, "hello world", 42.0,
                    (1.0, "hi"), None, (None, None), ("hello", None),
                    True, False, (True, False),
                    {True: "hello", False: "world"},
                    {"hello" : "world", 1: 42, 1.0: 45}, {}]
                    {"hello" : "world", 1: 42, 1.0: 45}, {},
                    np.int8(3), np.int32(4), np.int64(5),
                    np.uint8(3), np.uint32(4), np.uint64(5),
                    np.float32(1.0), np.float64(1.0)]

class UserDefinedType(object):
  def __init__(self):
@@ -41,6 +44,16 @@ class SerializationTest(unittest.TestCase):
    c = serialization.deserialize(worker.handle, b)
    self.assertTrue((a == c).all())

    a = np.array(0).astype(typ)
    b, _ = serialization.serialize(worker.handle, a)
    c = serialization.deserialize(worker.handle, b)
    self.assertTrue((a == c).all())

    a = np.empty((0,)).astype(typ)
    b, _ = serialization.serialize(worker.handle, a)
    c = serialization.deserialize(worker.handle, b)
    self.assertTrue(a.dtype == c.dtype)

  def testSerialize(self):
    [w] = services.start_singlenode_cluster(return_drivers=True)
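Note that the empty-array case asserts on dtype rather than on (a == c).all(): an element-wise comparison of empty arrays is vacuously true, so it would pass even if the dtype were lost in the round trip. For example:

    import numpy as np

    a = np.empty((0,), dtype=np.int16)
    c = np.empty((0,), dtype=np.float64)
    print((a == c).all())       # True: "all" over zero elements is vacuously true
    print(a.dtype == c.dtype)   # False: only the dtype check catches the mismatch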
@@ -54,8 +67,10 @@ class SerializationTest(unittest.TestCase):

    self.numpyTypeTest(w, 'int8')
    self.numpyTypeTest(w, 'uint8')
    # self.numpyTypeTest('int16') # TODO(pcm): implement this
    # self.numpyTypeTest('int32') # TODO(pcm): implement this
    self.numpyTypeTest(w, 'int16')
    self.numpyTypeTest(w, 'uint16')
    self.numpyTypeTest(w, 'int32')
    self.numpyTypeTest(w, 'uint32')
    self.numpyTypeTest(w, 'float32')
    self.numpyTypeTest(w, 'float64')
@@ -311,7 +326,7 @@ class ReferenceCountingTest(unittest.TestCase):
      objref_val = check_get_deallocated(val)
      self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], -1)

      if not isinstance(val, bool) and val is not None:
      if not isinstance(val, bool) and not isinstance(val, np.generic) and val is not None:
        x, objref_val = check_get_not_deallocated(val)
        self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1)