mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
implement bool serialization
This commit is contained in:
parent
37bd590656
commit
9a36e4208e
3 changed files with 23 additions and 1 deletions
|
@ -9,13 +9,17 @@ message String {
|
|||
}
|
||||
|
||||
message Double {
|
||||
double data = 1;
|
||||
double data = 1;
|
||||
}
|
||||
|
||||
// Empty used to represent a None object
|
||||
message Empty {
|
||||
}
|
||||
|
||||
message Bool {
|
||||
bool data = 1;
|
||||
}
|
||||
|
||||
message PyObj {
|
||||
bytes data = 1;
|
||||
}
|
||||
|
@ -25,6 +29,7 @@ message Obj {
|
|||
String string_data = 1;
|
||||
Int int_data = 2;
|
||||
Double double_data = 3;
|
||||
Bool bool_data = 10;
|
||||
Tuple tuple_data = 7;
|
||||
List list_data = 4;
|
||||
Dict dict_data = 8;
|
||||
|
|
|
@ -195,6 +195,13 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
|||
Double* data = obj->mutable_double_data();
|
||||
double d = PyFloat_AsDouble(val);
|
||||
data->set_data(d);
|
||||
} else if (PyBool_Check(val)) {
|
||||
Bool* data = obj->mutable_bool_data();
|
||||
if (val == Py_False) {
|
||||
data->set_data(false);
|
||||
} else {
|
||||
data->set_data(true);
|
||||
}
|
||||
} else if (PyTuple_Check(val)) {
|
||||
Tuple* data = obj->mutable_tuple_data();
|
||||
for (size_t i = 0, size = PyTuple_Size(val); i < size; ++i) {
|
||||
|
@ -321,6 +328,12 @@ PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjR
|
|||
return PyInt_FromLong(obj.int_data().data());
|
||||
} else if (obj.has_double_data()) {
|
||||
return PyFloat_FromDouble(obj.double_data().data());
|
||||
} else if (obj.has_bool_data()) {
|
||||
if (obj.bool_data().data()) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
} else if (obj.has_tuple_data()) {
|
||||
const Tuple& data = obj.tuple_data();
|
||||
size_t size = data.elem_size();
|
||||
|
|
|
@ -41,6 +41,10 @@ class SerializationTest(unittest.TestCase):
|
|||
self.roundTripTest(w, None)
|
||||
self.roundTripTest(w, (None, None))
|
||||
self.roundTripTest(w, ("hello", None))
|
||||
self.roundTripTest(w, True)
|
||||
self.roundTripTest(w, False)
|
||||
self.roundTripTest(w, (True, False))
|
||||
self.roundTripTest(w, {True: "hello", False: "world"})
|
||||
|
||||
self.roundTripTest(w, {"hello" : "world", 1: 42, 1.0: 45})
|
||||
self.roundTripTest(w, {})
|
||||
|
|
Loading…
Add table
Reference in a new issue