mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
python dict serialization
This commit is contained in:
parent
0a7e606279
commit
16d91af7b8
3 changed files with 36 additions and 0 deletions
|
@ -23,6 +23,7 @@ message Obj {
|
|||
Double double_data = 3;
|
||||
Tuple tuple_data = 7;
|
||||
List list_data = 4;
|
||||
Dict dict_data = 8;
|
||||
Array array_data = 5;
|
||||
PyObj pyobj_data = 6;
|
||||
}
|
||||
|
@ -35,6 +36,15 @@ message Tuple {
|
|||
repeated Obj elem = 1;
|
||||
}
|
||||
|
||||
message DictEntry {
|
||||
Obj key = 1;
|
||||
Obj value = 2;
|
||||
}
|
||||
|
||||
message Dict {
|
||||
repeated DictEntry elem = 1;
|
||||
}
|
||||
|
||||
message Value {
|
||||
uint64 ref = 1; // for pass by reference
|
||||
Obj obj = 2; // for pass by value
|
||||
|
|
|
@ -179,6 +179,21 @@ int serialize(PyObject* val, Obj* obj) {
|
|||
return -1;
|
||||
}
|
||||
}
|
||||
} else if (PyDict_Check(val)) {
|
||||
PyObject *pykey, *pyvalue;
|
||||
Py_ssize_t pos = 0;
|
||||
Dict* data = obj->mutable_dict_data();
|
||||
while (PyDict_Next(val, &pos, &pykey, &pyvalue)) {
|
||||
DictEntry* elem = data->add_elem();
|
||||
Obj* key = elem->mutable_key();
|
||||
if (serialize(pykey, key) != 0) {
|
||||
return -1;
|
||||
}
|
||||
Obj* value = elem->mutable_value();
|
||||
if (serialize(pyvalue, value) != 0) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
} else if (PyString_Check(val)) {
|
||||
char* buffer;
|
||||
Py_ssize_t length;
|
||||
|
@ -271,6 +286,14 @@ PyObject* deserialize(const Obj& obj) {
|
|||
PyList_SetItem(list, i, deserialize(data.elem(i)));
|
||||
}
|
||||
return list;
|
||||
} else if (obj.has_dict_data()) {
|
||||
const Dict& data = obj.dict_data();
|
||||
PyObject* dict = PyDict_New();
|
||||
size_t size = data.elem_size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
PyDict_SetItem(dict, deserialize(data.elem(i).key()), deserialize(data.elem(i).value()));
|
||||
}
|
||||
return dict;
|
||||
} else if (obj.has_string_data()) {
|
||||
const char* buffer = obj.string_data().data().data();
|
||||
Py_ssize_t length = obj.string_data().data().size();
|
||||
|
|
|
@ -68,6 +68,9 @@ class SerializationTest(unittest.TestCase):
|
|||
self.roundTripTest(42.0)
|
||||
self.roundTripTest((1.0, "hi"))
|
||||
|
||||
self.roundTripTest({"hello" : "world", 1: 42, 1.0: 45})
|
||||
self.roundTripTest({})
|
||||
|
||||
a = np.zeros((100, 100))
|
||||
res = serialization.serialize(a)
|
||||
b = serialization.deserialize(res)
|
||||
|
|
Loading…
Add table
Reference in a new issue