implement serialization of object references inside of python objects

This commit is contained in:
Philipp Moritz 2016-06-10 16:32:48 -07:00
parent 4cc024ae36
commit 19e7f0d72d
3 changed files with 19 additions and 0 deletions

View file

@ -20,6 +20,10 @@ message Bool {
bool data = 1; bool data = 1;
} }
message Ref {
uint64 data = 1;
}
message PyObj { message PyObj {
bytes data = 1; bytes data = 1;
} }
@ -35,6 +39,7 @@ message Obj {
Dict dict_data = 8; Dict dict_data = 8;
Array array_data = 5; Array array_data = 5;
Empty empty_data = 9; Empty empty_data = 9;
Ref objref_data = 11;
PyObj pyobj_data = 6; PyObj pyobj_data = 6;
} }

View file

@ -241,6 +241,11 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
obj->mutable_string_data()->set_data(buffer, length); obj->mutable_string_data()->set_data(buffer, length);
} else if (val == Py_None) { } else if (val == Py_None) {
obj->mutable_empty_data(); // allocate an Empty object, this is a None obj->mutable_empty_data(); // allocate an Empty object, this is a None
} else if (PyObject_IsInstance(val, (PyObject*) &PyObjRefType)) {
ObjRef objref = ((PyObjRef*) val)->val;
Ref* data = obj->mutable_objref_data();
data->set_data(objref);
objrefs.push_back(objref);
} else if (PyArray_Check(val)) { } else if (PyArray_Check(val)) {
PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val); PyArrayObject* array = PyArray_GETCONTIGUOUS((PyArrayObject*) val);
Array* data = obj->mutable_array_data(); Array* data = obj->mutable_array_data();
@ -365,6 +370,9 @@ PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjR
return PyString_FromStringAndSize(buffer, length); return PyString_FromStringAndSize(buffer, length);
} else if (obj.has_empty_data()) { } else if (obj.has_empty_data()) {
Py_RETURN_NONE; Py_RETURN_NONE;
} else if (obj.has_objref_data()) {
objrefs.push_back(obj.objref_data().data());
return make_pyobjref(worker_capsule, obj.objref_data().data());
} else if (obj.has_array_data()) { } else if (obj.has_array_data()) {
const Array& array = obj.array_data(); const Array& array = obj.array_data();
std::vector<npy_intp> dims; std::vector<npy_intp> dims;

View file

@ -60,11 +60,17 @@ class SerializationTest(unittest.TestCase):
ref1 = ray.push(0, w) ref1 = ray.push(0, w)
ref2 = ray.push(0, w) ref2 = ray.push(0, w)
ref3 = ray.push(0, w) ref3 = ray.push(0, w)
a = np.array([[ref0, ref1], [ref2, ref3]]) a = np.array([[ref0, ref1], [ref2, ref3]])
capsule, _ = serialization.serialize(w.handle, a) capsule, _ = serialization.serialize(w.handle, a)
result = serialization.deserialize(w.handle, capsule) result = serialization.deserialize(w.handle, capsule)
self.assertTrue((a == result).all()) self.assertTrue((a == result).all())
self.roundTripTest(w, ref0)
self.roundTripTest(w, [ref0, ref1, ref2, ref3])
self.roundTripTest(w, {'0': ref0, '1': ref1, '2': ref2, '3': ref3})
self.roundTripTest(w, (ref0, 1))
services.cleanup() services.cleanup()
class ObjStoreTest(unittest.TestCase): class ObjStoreTest(unittest.TestCase):