mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Fix bug in serializing arguments of tasks that are more complex objects (#72)
* Give more informative error message when we do not know how to serialize a class. * Check that passing arguments to remote functions and getting them does not change their values. * fix serialization bug * fix tests for common module * Formatting. * Bug fix in init_pickle_module signature. * Use pickle with HIGHEST_PROTOCOL.
This commit is contained in:
parent
1499834be1
commit
58e8bbcb34
6 changed files with 66 additions and 10 deletions
|
@ -91,7 +91,7 @@ def serialize(obj):
|
||||||
"""
|
"""
|
||||||
class_id = class_identifier(type(obj))
|
class_id = class_identifier(type(obj))
|
||||||
if class_id not in whitelisted_classes:
|
if class_id not in whitelisted_classes:
|
||||||
raise Exception("Ray does not know how to serialize the object {}. To fix this, call 'ray.register_class' on the class of the object.".format(obj))
|
raise Exception("Ray does not know how to serialize objects of type {}. To fix this, call 'ray.register_class' with this class.".format(type(obj)))
|
||||||
if class_id in classes_to_pickle:
|
if class_id in classes_to_pickle:
|
||||||
serialized_obj = {"data": pickling.dumps(obj)}
|
serialized_obj = {"data": pickling.dumps(obj)}
|
||||||
elif class_id in custom_serializers.keys():
|
elif class_id in custom_serializers.keys():
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
#include "node.h"
|
#include "node.h"
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
#include "common_extension.h"
|
#include "common_extension.h"
|
||||||
#include "task.h"
|
#include "task.h"
|
||||||
#include "utarray.h"
|
#include "utarray.h"
|
||||||
|
@ -8,7 +9,21 @@
|
||||||
|
|
||||||
PyObject *CommonError;
|
PyObject *CommonError;
|
||||||
|
|
||||||
#define MARSHAL_VERSION 2
|
/* Initialize pickle module. */
|
||||||
|
|
||||||
|
PyObject *pickle_module = NULL;
|
||||||
|
PyObject *pickle_loads = NULL;
|
||||||
|
PyObject *pickle_dumps = NULL;
|
||||||
|
PyObject *pickle_protocol = NULL;
|
||||||
|
|
||||||
|
void init_pickle_module(void) {
|
||||||
|
/* For Python 3 this needs to be "_pickle" instead of "cPickle". */
|
||||||
|
pickle_module = PyImport_ImportModuleNoBlock("cPickle");
|
||||||
|
pickle_loads = PyString_FromString("loads");
|
||||||
|
pickle_dumps = PyString_FromString("dumps");
|
||||||
|
pickle_protocol = PyObject_GetAttrString(pickle_module, "HIGHEST_PROTOCOL");
|
||||||
|
CHECK(pickle_module != NULL);
|
||||||
|
}
|
||||||
|
|
||||||
/* Define the PyObjectID class. */
|
/* Define the PyObjectID class. */
|
||||||
|
|
||||||
|
@ -194,7 +209,10 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) {
|
||||||
for (size_t i = 0; i < size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
PyObject *arg = PyList_GetItem(arguments, i);
|
PyObject *arg = PyList_GetItem(arguments, i);
|
||||||
if (!PyObject_IsInstance(arg, (PyObject *) &PyObjectIDType)) {
|
if (!PyObject_IsInstance(arg, (PyObject *) &PyObjectIDType)) {
|
||||||
PyObject *data = PyMarshal_WriteObjectToString(arg, MARSHAL_VERSION);
|
CHECK(pickle_module != NULL);
|
||||||
|
CHECK(pickle_dumps != NULL);
|
||||||
|
PyObject *data = PyObject_CallMethodObjArgs(pickle_module, pickle_dumps,
|
||||||
|
arg, pickle_protocol, NULL);
|
||||||
value_data_bytes += PyString_Size(data);
|
value_data_bytes += PyString_Size(data);
|
||||||
utarray_push_back(val_repr_ptrs, &data);
|
utarray_push_back(val_repr_ptrs, &data);
|
||||||
}
|
}
|
||||||
|
@ -248,10 +266,15 @@ static PyObject *PyTask_arguments(PyObject *self) {
|
||||||
object_id object_id = task_arg_id(task, i);
|
object_id object_id = task_arg_id(task, i);
|
||||||
PyList_SetItem(arg_list, i, PyObjectID_make(object_id));
|
PyList_SetItem(arg_list, i, PyObjectID_make(object_id));
|
||||||
} else {
|
} else {
|
||||||
PyObject *s =
|
CHECK(pickle_module != NULL);
|
||||||
PyMarshal_ReadObjectFromString((char *) task_arg_val(task, i),
|
CHECK(pickle_loads != NULL);
|
||||||
(Py_ssize_t) task_arg_length(task, i));
|
PyObject *str =
|
||||||
PyList_SetItem(arg_list, i, s);
|
PyString_FromStringAndSize((char *) task_arg_val(task, i),
|
||||||
|
(Py_ssize_t) task_arg_length(task, i));
|
||||||
|
PyObject *val =
|
||||||
|
PyObject_CallMethodObjArgs(pickle_module, pickle_loads, str, NULL);
|
||||||
|
Py_XDECREF(str);
|
||||||
|
PyList_SetItem(arg_list, i, val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return arg_list;
|
return arg_list;
|
||||||
|
|
|
@ -26,6 +26,13 @@ extern PyTypeObject PyObjectIDType;
|
||||||
|
|
||||||
extern PyTypeObject PyTaskType;
|
extern PyTypeObject PyTaskType;
|
||||||
|
|
||||||
|
/* Python module for pickling. */
|
||||||
|
extern PyObject *pickle_module;
|
||||||
|
extern PyObject *pickle_dumps;
|
||||||
|
extern PyObject *pickle_loads;
|
||||||
|
|
||||||
|
void init_pickle_module(void);
|
||||||
|
|
||||||
int PyObjectToUniqueID(PyObject *object, object_id *objectid);
|
int PyObjectToUniqueID(PyObject *object, object_id *objectid);
|
||||||
|
|
||||||
PyObject *PyObjectID_make(object_id object_id);
|
PyObject *PyObjectID_make(object_id object_id);
|
||||||
|
|
|
@ -24,6 +24,8 @@ PyMODINIT_FUNC initcommon(void) {
|
||||||
m = Py_InitModule3("common", common_methods,
|
m = Py_InitModule3("common", common_methods,
|
||||||
"A module for common types. This is used for testing.");
|
"A module for common types. This is used for testing.");
|
||||||
|
|
||||||
|
init_pickle_module();
|
||||||
|
|
||||||
Py_INCREF(&PyTaskType);
|
Py_INCREF(&PyTaskType);
|
||||||
PyModule_AddObject(m, "Task", (PyObject *) &PyTaskType);
|
PyModule_AddObject(m, "Task", (PyObject *) &PyTaskType);
|
||||||
|
|
||||||
|
|
|
@ -125,6 +125,8 @@ PyMODINIT_FUNC initlibphoton(void) {
|
||||||
m = Py_InitModule3("libphoton", photon_methods,
|
m = Py_InitModule3("libphoton", photon_methods,
|
||||||
"A module for the local scheduler.");
|
"A module for the local scheduler.");
|
||||||
|
|
||||||
|
init_pickle_module();
|
||||||
|
|
||||||
Py_INCREF(&PyTaskType);
|
Py_INCREF(&PyTaskType);
|
||||||
PyModule_AddObject(m, "Task", (PyObject *) &PyTaskType);
|
PyModule_AddObject(m, "Task", (PyObject *) &PyTaskType);
|
||||||
|
|
||||||
|
|
|
@ -50,11 +50,11 @@ PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 0L, 1L << 62, "a", string.printable, "\u262F",
|
||||||
np.array(["hi", 3], dtype=object),
|
np.array(["hi", 3], dtype=object),
|
||||||
np.array([["hi", u"hi"], [1.3, 1L]])]
|
np.array([["hi", u"hi"], [1.3, 1L]])]
|
||||||
|
|
||||||
COMPLEX_OBJECTS = [#[[[[[[[[[[[[]]]]]]]]]]]],
|
COMPLEX_OBJECTS = [[[[[[[[[[[[[]]]]]]]]]]]],
|
||||||
{"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
|
{"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
|
||||||
#{(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {}}}}}}}}}}}}},
|
#{(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {}}}}}}}}}}}}},
|
||||||
#((((((((((),),),),),),),),),),
|
((((((((((),),),),),),),),),),
|
||||||
#{"a": {"b": {"c": {"d": {}}}}}
|
{"a": {"b": {"c": {"d": {}}}}}
|
||||||
]
|
]
|
||||||
|
|
||||||
class Foo(object):
|
class Foo(object):
|
||||||
|
@ -144,6 +144,28 @@ class SerializationTest(unittest.TestCase):
|
||||||
|
|
||||||
ray.worker.cleanup()
|
ray.worker.cleanup()
|
||||||
|
|
||||||
|
def testPassingArgumentsByValue(self):
|
||||||
|
ray.init(start_ray_local=True, num_workers=1)
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def f(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
ray.register_class(Exception)
|
||||||
|
ray.register_class(CustomError)
|
||||||
|
ray.register_class(Point)
|
||||||
|
ray.register_class(Foo)
|
||||||
|
ray.register_class(Bar)
|
||||||
|
ray.register_class(Baz)
|
||||||
|
ray.register_class(NamedTupleExample)
|
||||||
|
|
||||||
|
# Check that we can pass arguments by value to remote functions and that
|
||||||
|
# they are uncorrupted.
|
||||||
|
for obj in RAY_TEST_OBJECTS:
|
||||||
|
assert_equal(obj, ray.get(f.remote(obj)))
|
||||||
|
|
||||||
|
ray.worker.cleanup()
|
||||||
|
|
||||||
class WorkerTest(unittest.TestCase):
|
class WorkerTest(unittest.TestCase):
|
||||||
|
|
||||||
def testPutGet(self):
|
def testPutGet(self):
|
||||||
|
|
Loading…
Add table
Reference in a new issue