mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Changed how ray treats deserialization of custom classes (#333)
This commit is contained in:
parent
98a508d6ca
commit
97b923a750
8 changed files with 20 additions and 29 deletions
|
@ -68,11 +68,11 @@ class ExampleClass(object):
|
|||
# This example assumes that field1 and field2 are serializable types.
|
||||
self.field1 = field1
|
||||
self.field2 = field2
|
||||
|
||||
def deserialize(self, primitives):
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
(field1, field2) = primitives
|
||||
self.field1 = field1
|
||||
self.field2 = field2
|
||||
return ExampleClass(field1, field2)
|
||||
|
||||
def serialize(self):
|
||||
return (self.field1, self.field2)
|
||||
|
|
|
@ -9,7 +9,7 @@ __all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy",
|
|||
BLOCK_SIZE = 10
|
||||
|
||||
class DistArray(object):
|
||||
def construct(self, shape, objectids=None):
|
||||
def __init__(self, shape, objectids=None):
|
||||
self.shape = shape
|
||||
self.ndim = len(shape)
|
||||
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
|
||||
|
@ -17,17 +17,14 @@ class DistArray(object):
|
|||
if self.num_blocks != list(self.objectids.shape):
|
||||
raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape)))
|
||||
|
||||
def deserialize(self, primitives):
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
(shape, objectids) = primitives
|
||||
self.construct(shape, objectids)
|
||||
return DistArray(shape, objectids)
|
||||
|
||||
def serialize(self):
|
||||
return (self.shape, self.objectids)
|
||||
|
||||
def __init__(self, shape=None):
|
||||
if shape is not None:
|
||||
self.construct(shape)
|
||||
|
||||
@staticmethod
|
||||
def compute_block_lower(index, shape):
|
||||
if len(index) != len(shape):
|
||||
|
|
|
@ -46,8 +46,6 @@ def tsqr(a):
|
|||
current_rs = new_rs
|
||||
assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs))
|
||||
|
||||
q_result = DistArray()
|
||||
|
||||
# handle the special case in which the whole DistArray "a" fits in one block
|
||||
# and has fewer rows than columns, this is a bit ugly so think about how to
|
||||
# remove it
|
||||
|
@ -56,9 +54,8 @@ def tsqr(a):
|
|||
else:
|
||||
q_shape = [a.shape[0], a.shape[0]]
|
||||
q_num_blocks = DistArray.compute_num_blocks(q_shape)
|
||||
q_result = DistArray()
|
||||
q_objectids = np.empty(q_num_blocks, dtype=object)
|
||||
q_result.construct(q_shape, q_objectids)
|
||||
q_result = DistArray(q_shape, q_objectids)
|
||||
|
||||
# reconstruct output
|
||||
for i in range(num_blocks):
|
||||
|
@ -145,8 +142,7 @@ def qr(a):
|
|||
k = min(m, n)
|
||||
|
||||
# we will store our scratch work in a_work
|
||||
a_work = DistArray()
|
||||
a_work.construct(a.shape, np.copy(a.objectids))
|
||||
a_work = DistArray(a.shape, np.copy(a.objectids))
|
||||
|
||||
result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
|
||||
r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
|
||||
|
|
|
@ -12,6 +12,5 @@ def normal(shape):
|
|||
objectids = np.empty(num_blocks, dtype=object)
|
||||
for index in np.ndindex(*num_blocks):
|
||||
objectids[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape))
|
||||
result = DistArray()
|
||||
result.construct(shape, objectids)
|
||||
result = DistArray(shape, objectids)
|
||||
return result
|
||||
|
|
|
@ -46,8 +46,7 @@ def from_primitive(primitive_obj):
|
|||
# This code assumes that the type module.__dict__[type_name] knows how to deserialize itself
|
||||
type_module, type_name = primitive_obj[0]
|
||||
module = importlib.import_module(type_module)
|
||||
obj = module.__dict__[type_name]()
|
||||
obj.deserialize(primitive_obj[1])
|
||||
obj = module.__dict__[type_name].deserialize(primitive_obj[1])
|
||||
return obj
|
||||
|
||||
def is_arrow_serializable(value):
|
||||
|
|
|
@ -30,7 +30,7 @@ class RayFailedObject(object):
|
|||
error_message (str): The error message raised by the task that failed.
|
||||
"""
|
||||
|
||||
def __init__(self, error_message=None):
|
||||
def __init__(self, error_message):
|
||||
"""Initialize a RayFailedObject.
|
||||
|
||||
Args:
|
||||
|
@ -39,7 +39,8 @@ class RayFailedObject(object):
|
|||
"""
|
||||
self.error_message = error_message
|
||||
|
||||
def deserialize(self, primitives):
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayFailedObject from a primitive object.
|
||||
|
||||
This initializes a RayFailedObject from a primitive object created by the
|
||||
|
@ -52,7 +53,7 @@ class RayFailedObject(object):
|
|||
Args:
|
||||
primitives (str): The object's error message.
|
||||
"""
|
||||
self.error_message = primitives
|
||||
return RayFailedObject(primitives)
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayFailedObject into a primitive object.
|
||||
|
|
|
@ -49,8 +49,7 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
reload(module)
|
||||
ray.init(start_ray_local=True, num_workers=0)
|
||||
|
||||
x = da.DistArray()
|
||||
x.construct([2, 3, 4], np.array([[[ray.put(0)]]]))
|
||||
x = da.DistArray([2, 3, 4], np.array([[[ray.put(0)]]]))
|
||||
capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, x)
|
||||
y = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule)
|
||||
self.assertEqual(x.shape, y.shape)
|
||||
|
@ -65,8 +64,7 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
|
||||
a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
|
||||
b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
|
||||
x = da.DistArray()
|
||||
x.construct([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]]))
|
||||
x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]]))
|
||||
self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])))
|
||||
|
||||
ray.services.cleanup()
|
||||
|
|
|
@ -23,7 +23,8 @@ class UserDefinedType(object):
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
def deserialize(self, primitives):
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
return "user defined type"
|
||||
|
||||
def serialize(self):
|
||||
|
|
Loading…
Add table
Reference in a new issue