Changed how ray treats deserialization of custom classes (#333)

This commit is contained in:
Wapaul1 2016-08-01 15:38:05 -07:00 committed by Philipp Moritz
parent 98a508d6ca
commit 97b923a750
8 changed files with 20 additions and 29 deletions

View file

@ -68,11 +68,11 @@ class ExampleClass(object):
# This example assumes that field1 and field2 are serializable types.
self.field1 = field1
self.field2 = field2
def deserialize(self, primitives):
@staticmethod
def deserialize(primitives):
(field1, field2) = primitives
self.field1 = field1
self.field2 = field2
return ExampleClass(field1, field2)
def serialize(self):
return (self.field1, self.field2)

View file

@ -9,7 +9,7 @@ __all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy",
BLOCK_SIZE = 10
class DistArray(object):
def construct(self, shape, objectids=None):
def __init__(self, shape, objectids=None):
self.shape = shape
self.ndim = len(shape)
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
@ -17,17 +17,14 @@ class DistArray(object):
if self.num_blocks != list(self.objectids.shape):
raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape)))
def deserialize(self, primitives):
@staticmethod
def deserialize(primitives):
(shape, objectids) = primitives
self.construct(shape, objectids)
return DistArray(shape, objectids)
def serialize(self):
return (self.shape, self.objectids)
def __init__(self, shape=None):
if shape is not None:
self.construct(shape)
@staticmethod
def compute_block_lower(index, shape):
if len(index) != len(shape):

View file

@ -46,8 +46,6 @@ def tsqr(a):
current_rs = new_rs
assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs))
q_result = DistArray()
# handle the special case in which the whole DistArray "a" fits in one block
# and has fewer rows than columns, this is a bit ugly so think about how to
# remove it
@ -56,9 +54,8 @@ def tsqr(a):
else:
q_shape = [a.shape[0], a.shape[0]]
q_num_blocks = DistArray.compute_num_blocks(q_shape)
q_result = DistArray()
q_objectids = np.empty(q_num_blocks, dtype=object)
q_result.construct(q_shape, q_objectids)
q_result = DistArray(q_shape, q_objectids)
# reconstruct output
for i in range(num_blocks):
@ -145,8 +142,7 @@ def qr(a):
k = min(m, n)
# we will store our scratch work in a_work
a_work = DistArray()
a_work.construct(a.shape, np.copy(a.objectids))
a_work = DistArray(a.shape, np.copy(a.objectids))
result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.

View file

@ -12,6 +12,5 @@ def normal(shape):
objectids = np.empty(num_blocks, dtype=object)
for index in np.ndindex(*num_blocks):
objectids[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape))
result = DistArray()
result.construct(shape, objectids)
result = DistArray(shape, objectids)
return result

View file

@ -46,8 +46,7 @@ def from_primitive(primitive_obj):
# This code assumes that the type module.__dict__[type_name] knows how to deserialize itself
type_module, type_name = primitive_obj[0]
module = importlib.import_module(type_module)
obj = module.__dict__[type_name]()
obj.deserialize(primitive_obj[1])
obj = module.__dict__[type_name].deserialize(primitive_obj[1])
return obj
def is_arrow_serializable(value):

View file

@ -30,7 +30,7 @@ class RayFailedObject(object):
error_message (str): The error message raised by the task that failed.
"""
def __init__(self, error_message=None):
def __init__(self, error_message):
"""Initialize a RayFailedObject.
Args:
@ -39,7 +39,8 @@ class RayFailedObject(object):
"""
self.error_message = error_message
def deserialize(self, primitives):
@staticmethod
def deserialize(primitives):
"""Create a RayFailedObject from a primitive object.
This initializes a RayFailedObject from a primitive object created by the
@ -52,7 +53,7 @@ class RayFailedObject(object):
Args:
primitives (str): The object's error message.
"""
self.error_message = primitives
return RayFailedObject(primitives)
def serialize(self):
"""Turn a RayFailedObject into a primitive object.

View file

@ -49,8 +49,7 @@ class DistributedArrayTest(unittest.TestCase):
reload(module)
ray.init(start_ray_local=True, num_workers=0)
x = da.DistArray()
x.construct([2, 3, 4], np.array([[[ray.put(0)]]]))
x = da.DistArray([2, 3, 4], np.array([[[ray.put(0)]]]))
capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, x)
y = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule)
self.assertEqual(x.shape, y.shape)
@ -65,8 +64,7 @@ class DistributedArrayTest(unittest.TestCase):
a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
x = da.DistArray()
x.construct([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]]))
x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]]))
self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])))
ray.services.cleanup()

View file

@ -23,7 +23,8 @@ class UserDefinedType(object):
def __init__(self):
pass
def deserialize(self, primitives):
@staticmethod
def deserialize(primitives):
return "user defined type"
def serialize(self):