mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Fix serialization of numpy scalars and implement more numpy types as well as empty arrays
This commit is contained in:
parent
1951252689
commit
5ecc2ab67d
4 changed files with 89 additions and 112 deletions
|
@ -69,6 +69,9 @@ class Worker(object):
|
||||||
result = serialization.Str(result)
|
result = serialization.Str(result)
|
||||||
elif isinstance(result, np.ndarray):
|
elif isinstance(result, np.ndarray):
|
||||||
result = result.view(serialization.NDArray)
|
result = result.view(serialization.NDArray)
|
||||||
|
elif isinstance(result, np.generic):
|
||||||
|
return result
|
||||||
|
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
|
||||||
elif result == None:
|
elif result == None:
|
||||||
return None # can't subclass None and don't need to because there is a global None
|
return None # can't subclass None and don't need to because there is a global None
|
||||||
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
|
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
|
||||||
|
|
|
@ -95,6 +95,7 @@ message TaskStatus {
|
||||||
message Array {
|
message Array {
|
||||||
repeated uint64 shape = 1;
|
repeated uint64 shape = 1;
|
||||||
sint64 dtype = 2;
|
sint64 dtype = 2;
|
||||||
|
bool is_scalar = 8;
|
||||||
repeated double double_data = 3;
|
repeated double double_data = 3;
|
||||||
repeated float float_data = 4;
|
repeated float float_data = 4;
|
||||||
repeated sint64 int_data = 5;
|
repeated sint64 int_data = 5;
|
||||||
|
|
174
src/raylib.cc
174
src/raylib.cc
|
@ -199,6 +199,15 @@ void set_dict_item_and_transfer_ownership(PyObject* dict, PyObject* key, PyObjec
|
||||||
|
|
||||||
// Serialization
|
// Serialization
|
||||||
|
|
||||||
|
#define RAYLIB_SERIALIZE_NPY(TYPE, npy_type, proto_type) \
|
||||||
|
case NPY_##TYPE: { \
|
||||||
|
npy_type* buffer = (npy_type*) PyArray_DATA(array); \
|
||||||
|
for (npy_intp i = 0; i < size; ++i) { \
|
||||||
|
data->add_##proto_type##_data(buffer[i]); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
break;
|
||||||
|
|
||||||
// serialize will serialize the python object val into the protocol buffer
|
// serialize will serialize the python object val into the protocol buffer
|
||||||
// object obj, returns 0 if successful and something else if not
|
// object obj, returns 0 if successful and something else if not
|
||||||
// FIXME(pcm): This currently only works for contiguous arrays
|
// FIXME(pcm): This currently only works for contiguous arrays
|
||||||
|
@ -263,9 +272,17 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
||||||
Ref* data = obj->mutable_objref_data();
|
Ref* data = obj->mutable_objref_data();
|
||||||
data->set_data(objref);
|
data->set_data(objref);
|
||||||
objrefs.push_back(objref);
|
objrefs.push_back(objref);
|
||||||
} else if (PyArray_Check(val)) {
|
} else if (PyArray_Check(val) || PyArray_CheckScalar(val)) { // Python int and float already handled
|
||||||
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
|
|
||||||
Array* data = obj->mutable_array_data();
|
Array* data = obj->mutable_array_data();
|
||||||
|
PyArrayObject* array; // will be deallocated at the end
|
||||||
|
if (PyArray_IsScalar(val, Generic)) {
|
||||||
|
data->set_is_scalar(true);
|
||||||
|
PyArray_Descr* descr = PyArray_DescrFromScalar(val); // new reference
|
||||||
|
array = (PyArrayObject*) PyArray_FromScalar(val, descr); // steals the new reference
|
||||||
|
} else { // val is a numpy array
|
||||||
|
array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
|
||||||
|
}
|
||||||
|
|
||||||
npy_intp size = PyArray_SIZE(array);
|
npy_intp size = PyArray_SIZE(array);
|
||||||
for (int i = 0; i < PyArray_NDIM(array); ++i) {
|
for (int i = 0; i < PyArray_NDIM(array); ++i) {
|
||||||
data->add_shape(PyArray_DIM(array, i));
|
data->add_shape(PyArray_DIM(array, i));
|
||||||
|
@ -273,48 +290,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
||||||
int typ = PyArray_TYPE(array);
|
int typ = PyArray_TYPE(array);
|
||||||
data->set_dtype(typ);
|
data->set_dtype(typ);
|
||||||
switch (typ) {
|
switch (typ) {
|
||||||
case NPY_FLOAT: {
|
RAYLIB_SERIALIZE_NPY(FLOAT, npy_float, float)
|
||||||
npy_float* buffer = (npy_float*) PyArray_DATA(array);
|
RAYLIB_SERIALIZE_NPY(DOUBLE, npy_double, double)
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
RAYLIB_SERIALIZE_NPY(INT8, npy_int8, int)
|
||||||
data->add_float_data(buffer[i]);
|
RAYLIB_SERIALIZE_NPY(INT16, npy_int16, int)
|
||||||
}
|
RAYLIB_SERIALIZE_NPY(INT32, npy_int32, int)
|
||||||
}
|
RAYLIB_SERIALIZE_NPY(INT64, npy_int64, int)
|
||||||
break;
|
RAYLIB_SERIALIZE_NPY(UINT8, npy_uint8, uint)
|
||||||
case NPY_DOUBLE: {
|
RAYLIB_SERIALIZE_NPY(UINT16, npy_uint16, uint)
|
||||||
npy_double* buffer = (npy_double*) PyArray_DATA(array);
|
RAYLIB_SERIALIZE_NPY(UINT32, npy_uint32, uint)
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
RAYLIB_SERIALIZE_NPY(UINT64, npy_uint64, uint)
|
||||||
data->add_double_data(buffer[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case NPY_INT8: {
|
|
||||||
npy_int8* buffer = (npy_int8*) PyArray_DATA(array);
|
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
|
||||||
data->add_int_data(buffer[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case NPY_INT64: {
|
|
||||||
npy_int64* buffer = (npy_int64*) PyArray_DATA(array);
|
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
|
||||||
data->add_int_data(buffer[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case NPY_UINT8: {
|
|
||||||
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(array);
|
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
|
||||||
data->add_uint_data(buffer[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case NPY_UINT64: {
|
|
||||||
npy_uint64* buffer = (npy_uint64*) PyArray_DATA(array);
|
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
|
||||||
data->add_uint_data(buffer[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs
|
case NPY_OBJECT: { // FIXME(pcm): Support arbitrary python objects, not only objrefs
|
||||||
PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array);
|
PyArrayIterObject* iter = (PyArrayIterObject*) PyArray_IterNew((PyObject*)array);
|
||||||
while (PyArray_ITER_NOTDONE(iter)) {
|
while (PyArray_ITER_NOTDONE(iter)) {
|
||||||
|
@ -345,6 +330,16 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define RAYLIB_DESERIALIZE_NPY(TYPE, npy_type, proto_type) \
|
||||||
|
case NPY_##TYPE: { \
|
||||||
|
npy_intp size = array.proto_type##_data_size(); \
|
||||||
|
npy_type* buffer = (npy_type*) PyArray_DATA(pyarray); \
|
||||||
|
for (npy_intp i = 0; i < size; ++i) { \
|
||||||
|
buffer[i] = array.proto_type##_data(i); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
break;
|
||||||
|
|
||||||
// This method will push all of the object references contained in `obj` to the `objrefs` vector.
|
// This method will push all of the object references contained in `obj` to the `objrefs` vector.
|
||||||
PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjRef> &objrefs) {
|
PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjRef> &objrefs) {
|
||||||
if (obj.has_int_data()) {
|
if (obj.has_int_data()) {
|
||||||
|
@ -399,72 +394,35 @@ PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjR
|
||||||
dims.push_back(array.shape(i));
|
dims.push_back(array.shape(i));
|
||||||
}
|
}
|
||||||
PyArrayObject* pyarray = (PyArrayObject*) PyArray_SimpleNew(array.shape_size(), &dims[0], array.dtype());
|
PyArrayObject* pyarray = (PyArrayObject*) PyArray_SimpleNew(array.shape_size(), &dims[0], array.dtype());
|
||||||
if (array.double_data_size() > 0) { // TODO: handle empty array
|
switch (array.dtype()) {
|
||||||
npy_intp size = array.double_data_size();
|
RAYLIB_DESERIALIZE_NPY(FLOAT, npy_float, float)
|
||||||
npy_double* buffer = (npy_double*) PyArray_DATA(pyarray);
|
RAYLIB_DESERIALIZE_NPY(DOUBLE, npy_double, double)
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
RAYLIB_DESERIALIZE_NPY(INT8, npy_int8, int)
|
||||||
buffer[i] = array.double_data(i);
|
RAYLIB_DESERIALIZE_NPY(INT16, npy_int16, int)
|
||||||
}
|
RAYLIB_DESERIALIZE_NPY(INT32, npy_int32, int)
|
||||||
} else if (array.float_data_size() > 0) {
|
RAYLIB_DESERIALIZE_NPY(INT64, npy_int64, int)
|
||||||
npy_intp size = array.float_data_size();
|
RAYLIB_DESERIALIZE_NPY(UINT8, npy_uint8, uint)
|
||||||
npy_float* buffer = (npy_float*) PyArray_DATA(pyarray);
|
RAYLIB_DESERIALIZE_NPY(UINT16, npy_uint16, uint)
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
RAYLIB_DESERIALIZE_NPY(UINT32, npy_uint32, uint)
|
||||||
buffer[i] = array.float_data(i);
|
RAYLIB_DESERIALIZE_NPY(UINT64, npy_uint64, uint)
|
||||||
}
|
case NPY_OBJECT: {
|
||||||
} else if (array.int_data_size() > 0) {
|
npy_intp size = array.objref_data_size();
|
||||||
npy_intp size = array.int_data_size();
|
PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
|
||||||
switch (array.dtype()) {
|
for (npy_intp i = 0; i < size; ++i) {
|
||||||
case NPY_INT8: {
|
buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i));
|
||||||
npy_int8* buffer = (npy_int8*) PyArray_DATA(pyarray);
|
objrefs.push_back(array.objref_data(i));
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
|
||||||
buffer[i] = array.int_data(i);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
}
|
||||||
case NPY_INT64: {
|
break;
|
||||||
npy_int64* buffer = (npy_int64*) PyArray_DATA(pyarray);
|
default:
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
|
||||||
buffer[i] = array.int_data(i);
|
return NULL;
|
||||||
}
|
}
|
||||||
}
|
if (array.is_scalar()) {
|
||||||
break;
|
return PyArray_ScalarFromObject((PyObject*) pyarray);
|
||||||
default:
|
} else {
|
||||||
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
|
return (PyObject*) pyarray;
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
} else if (array.uint_data_size() > 0) {
|
|
||||||
npy_intp size = array.uint_data_size();
|
|
||||||
switch (array.dtype()) {
|
|
||||||
case NPY_UINT8: {
|
|
||||||
npy_uint8* buffer = (npy_uint8*) PyArray_DATA(pyarray);
|
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
|
||||||
buffer[i] = array.uint_data(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case NPY_UINT64: {
|
|
||||||
npy_uint64* buffer = (npy_uint64*) PyArray_DATA(pyarray);
|
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
|
||||||
buffer[i] = array.uint_data(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
} else if (array.objref_data_size() > 0) {
|
|
||||||
npy_intp size = array.objref_data_size();
|
|
||||||
PyObject** buffer = (PyObject**) PyArray_DATA(pyarray);
|
|
||||||
for (npy_intp i = 0; i < size; ++i) {
|
|
||||||
buffer[i] = make_pyobjref(worker_capsule, array.objref_data(i));
|
|
||||||
objrefs.push_back(array.objref_data(i));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
PyErr_SetString(RayError, "deserialization: internal error (array type not implemented)");
|
|
||||||
return NULL;
|
|
||||||
}
|
}
|
||||||
return (PyObject*) pyarray;
|
|
||||||
} else {
|
} else {
|
||||||
PyErr_SetString(RayError, "deserialization: internal error (type not implemented)");
|
PyErr_SetString(RayError, "deserialization: internal error (type not implemented)");
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
|
@ -16,7 +16,10 @@ RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, "hello world", 42.0,
|
||||||
(1.0, "hi"), None, (None, None), ("hello", None),
|
(1.0, "hi"), None, (None, None), ("hello", None),
|
||||||
True, False, (True, False),
|
True, False, (True, False),
|
||||||
{True: "hello", False: "world"},
|
{True: "hello", False: "world"},
|
||||||
{"hello" : "world", 1: 42, 1.0: 45}, {}]
|
{"hello" : "world", 1: 42, 1.0: 45}, {},
|
||||||
|
np.int8(3), np.int32(4), np.int64(5),
|
||||||
|
np.uint8(3), np.uint32(4), np.uint64(5),
|
||||||
|
np.float32(1.0), np.float64(1.0)]
|
||||||
|
|
||||||
class UserDefinedType(object):
|
class UserDefinedType(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -41,6 +44,16 @@ class SerializationTest(unittest.TestCase):
|
||||||
c = serialization.deserialize(worker.handle, b)
|
c = serialization.deserialize(worker.handle, b)
|
||||||
self.assertTrue((a == c).all())
|
self.assertTrue((a == c).all())
|
||||||
|
|
||||||
|
a = np.array(0).astype(typ)
|
||||||
|
b, _ = serialization.serialize(worker.handle, a)
|
||||||
|
c = serialization.deserialize(worker.handle, b)
|
||||||
|
self.assertTrue((a == c).all())
|
||||||
|
|
||||||
|
a = np.empty((0,)).astype(typ)
|
||||||
|
b, _ = serialization.serialize(worker.handle, a)
|
||||||
|
c = serialization.deserialize(worker.handle, b)
|
||||||
|
self.assertTrue(a.dtype == c.dtype)
|
||||||
|
|
||||||
def testSerialize(self):
|
def testSerialize(self):
|
||||||
[w] = services.start_singlenode_cluster(return_drivers=True)
|
[w] = services.start_singlenode_cluster(return_drivers=True)
|
||||||
|
|
||||||
|
@ -54,8 +67,10 @@ class SerializationTest(unittest.TestCase):
|
||||||
|
|
||||||
self.numpyTypeTest(w, 'int8')
|
self.numpyTypeTest(w, 'int8')
|
||||||
self.numpyTypeTest(w, 'uint8')
|
self.numpyTypeTest(w, 'uint8')
|
||||||
# self.numpyTypeTest('int16') # TODO(pcm): implement this
|
self.numpyTypeTest(w, 'int16')
|
||||||
# self.numpyTypeTest('int32') # TODO(pcm): implement this
|
self.numpyTypeTest(w, 'uint16')
|
||||||
|
self.numpyTypeTest(w, 'int32')
|
||||||
|
self.numpyTypeTest(w, 'uint32')
|
||||||
self.numpyTypeTest(w, 'float32')
|
self.numpyTypeTest(w, 'float32')
|
||||||
self.numpyTypeTest(w, 'float64')
|
self.numpyTypeTest(w, 'float64')
|
||||||
|
|
||||||
|
@ -311,7 +326,7 @@ class ReferenceCountingTest(unittest.TestCase):
|
||||||
objref_val = check_get_deallocated(val)
|
objref_val = check_get_deallocated(val)
|
||||||
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], -1)
|
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], -1)
|
||||||
|
|
||||||
if not isinstance(val, bool) and val is not None:
|
if not isinstance(val, bool) and not isinstance(val, np.generic) and val is not None:
|
||||||
x, objref_val = check_get_not_deallocated(val)
|
x, objref_val = check_get_not_deallocated(val)
|
||||||
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1)
|
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue