diff --git a/doc/remote-functions.md b/doc/remote-functions.md index bfae67341..e8bdc3a02 100644 --- a/doc/remote-functions.md +++ b/doc/remote-functions.md @@ -68,11 +68,11 @@ class ExampleClass(object): # This example assumes that field1 and field2 are serializable types. self.field1 = field1 self.field2 = field2 - - def deserialize(self, primitives): + + @staticmethod + def deserialize(primitives): (field1, field2) = primitives - self.field1 = field1 - self.field2 = field2 + return ExampleClass(field1, field2) def serialize(self): return (self.field1, self.field2) diff --git a/lib/python/ray/array/distributed/core.py b/lib/python/ray/array/distributed/core.py index 756464c20..d9fa158dc 100644 --- a/lib/python/ray/array/distributed/core.py +++ b/lib/python/ray/array/distributed/core.py @@ -9,7 +9,7 @@ __all__ = ["BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones", "copy", BLOCK_SIZE = 10 class DistArray(object): - def construct(self, shape, objectids=None): + def __init__(self, shape, objectids=None): self.shape = shape self.ndim = len(shape) self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape] @@ -17,17 +17,14 @@ class DistArray(object): if self.num_blocks != list(self.objectids.shape): raise Exception("The fields `num_blocks` and `objectids` are inconsistent, `num_blocks` is {} and `objectids` has shape {}".format(self.num_blocks, list(self.objectids.shape))) - def deserialize(self, primitives): + @staticmethod + def deserialize(primitives): (shape, objectids) = primitives - self.construct(shape, objectids) + return DistArray(shape, objectids) def serialize(self): return (self.shape, self.objectids) - def __init__(self, shape=None): - if shape is not None: - self.construct(shape) - @staticmethod def compute_block_lower(index, shape): if len(index) != len(shape): diff --git a/lib/python/ray/array/distributed/linalg.py b/lib/python/ray/array/distributed/linalg.py index efbe206a9..11c8702ef 100644 --- a/lib/python/ray/array/distributed/linalg.py +++ b/lib/python/ray/array/distributed/linalg.py @@ -46,8 +46,6 @@ def tsqr(a): current_rs = new_rs assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs)) - q_result = DistArray() - # handle the special case in which the whole DistArray "a" fits in one block # and has fewer rows than columns, this is a bit ugly so think about how to # remove it @@ -56,9 +54,8 @@ def tsqr(a): else: q_shape = [a.shape[0], a.shape[0]] q_num_blocks = DistArray.compute_num_blocks(q_shape) - q_result = DistArray() q_objectids = np.empty(q_num_blocks, dtype=object) - q_result.construct(q_shape, q_objectids) + q_result = DistArray(q_shape, q_objectids) # reconstruct output for i in range(num_blocks): @@ -145,8 +142,7 @@ def qr(a): k = min(m, n) # we will store our scratch work in a_work - a_work = DistArray() - a_work.construct(a.shape, np.copy(a.objectids)) + a_work = DistArray(a.shape, np.copy(a.objectids)) result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it. diff --git a/lib/python/ray/array/distributed/random.py b/lib/python/ray/array/distributed/random.py index f966162be..951ade7df 100644 --- a/lib/python/ray/array/distributed/random.py +++ b/lib/python/ray/array/distributed/random.py @@ -12,6 +12,5 @@ def normal(shape): objectids = np.empty(num_blocks, dtype=object) for index in np.ndindex(*num_blocks): objectids[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape)) - result = DistArray() - result.construct(shape, objectids) + result = DistArray(shape, objectids) return result diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py index b62124c88..4746412f2 100644 --- a/lib/python/ray/serialization.py +++ b/lib/python/ray/serialization.py @@ -46,8 +46,7 @@ def from_primitive(primitive_obj): # This code assumes that the type module.__dict__[type_name] knows how to deserialize itself type_module, type_name = primitive_obj[0] module = importlib.import_module(type_module) - obj = module.__dict__[type_name]() - obj.deserialize(primitive_obj[1]) + obj = module.__dict__[type_name].deserialize(primitive_obj[1]) return obj def is_arrow_serializable(value): diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index a9b1d10ed..536048226 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -30,7 +30,7 @@ class RayFailedObject(object): error_message (str): The error message raised by the task that failed. """ - def __init__(self, error_message=None): + def __init__(self, error_message): """Initialize a RayFailedObject. Args: @@ -39,7 +39,8 @@ class RayFailedObject(object): """ self.error_message = error_message - def deserialize(self, primitives): + @staticmethod + def deserialize(primitives): """Create a RayFailedObject from a primitive object. This initializes a RayFailedObject from a primitive object created by the @@ -52,7 +53,7 @@ class RayFailedObject(object): Args: primitives (str): The object's error message. """ - self.error_message = primitives + return RayFailedObject(primitives) def serialize(self): """Turn a RayFailedObject into a primitive object. diff --git a/test/array_test.py b/test/array_test.py index ba5f65fd8..442a31707 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -49,8 +49,7 @@ class DistributedArrayTest(unittest.TestCase): reload(module) ray.init(start_ray_local=True, num_workers=0) - x = da.DistArray() - x.construct([2, 3, 4], np.array([[[ray.put(0)]]])) + x = da.DistArray([2, 3, 4], np.array([[[ray.put(0)]]])) capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, x) y = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule) self.assertEqual(x.shape, y.shape) @@ -65,8 +64,7 @@ class DistributedArrayTest(unittest.TestCase): a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) - x = da.DistArray() - x.construct([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]])) + x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]])) self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])]))) ray.services.cleanup() diff --git a/test/runtest.py b/test/runtest.py index af9008257..3c46d5f4d 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -23,7 +23,8 @@ class UserDefinedType(object): def __init__(self): pass - def deserialize(self, primitives): + @staticmethod + def deserialize(primitives): return "user defined type" def serialize(self):