diff --git a/protos/types.proto b/protos/types.proto index e1a67aa5e..da383377a 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -20,6 +20,10 @@ message Bool { bool data = 1; } +message Ref { + uint64 data = 1; +} + message PyObj { bytes data = 1; } @@ -35,6 +39,7 @@ message Obj { Dict dict_data = 8; Array array_data = 5; Empty empty_data = 9; + Ref objref_data = 11; PyObj pyobj_data = 6; } diff --git a/src/raylib.cc b/src/raylib.cc index 39aaee13d..8128f871b 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -241,6 +241,11 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vectormutable_string_data()->set_data(buffer, length); } else if (val == Py_None) { obj->mutable_empty_data(); // allocate an Empty object, this is a None + } else if (PyObject_IsInstance(val, (PyObject*) &PyObjRefType)) { + ObjRef objref = ((PyObjRef*) val)->val; + Ref* data = obj->mutable_objref_data(); + data->set_data(objref); + objrefs.push_back(objref); } else if (PyArray_Check(val)) { PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val); Array* data = obj->mutable_array_data(); @@ -365,6 +370,9 @@ PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector dims; diff --git a/test/runtest.py b/test/runtest.py index 4be6a907a..498a1a0d7 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -60,11 +60,17 @@ class SerializationTest(unittest.TestCase): ref1 = ray.push(0, w) ref2 = ray.push(0, w) ref3 = ray.push(0, w) + a = np.array([[ref0, ref1], [ref2, ref3]]) capsule, _ = serialization.serialize(w.handle, a) result = serialization.deserialize(w.handle, capsule) self.assertTrue((a == result).all()) + self.roundTripTest(w, ref0) + self.roundTripTest(w, [ref0, ref1, ref2, ref3]) + self.roundTripTest(w, {'0': ref0, '1': ref1, '2': ref2, '3': ref3}) + self.roundTripTest(w, (ref0, 1)) + services.cleanup() class ObjStoreTest(unittest.TestCase):