mirror of
https://github.com/vale981/ray
synced 2025-03-10 05:16:49 -04:00
60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
from typing import List
|
|
import numpy as np
|
|
import arrays.single as single
|
|
import orchpy as op
|
|
|
|
BLOCK_SIZE = 10
|
|
|
|
class DistArray(object):
|
|
def construct(self, shape, dtype, objrefs):
|
|
self.shape = shape
|
|
self.dtype = dtype
|
|
self.objrefs = objrefs
|
|
self.ndim = len(shape)
|
|
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
|
|
if self.num_blocks != list(self.objrefs.shape):
|
|
raise Exception("The fields `num_blocks` and `objrefs` are inconsistent, `num_blocks` is {} and `objrefs` has shape {}".format(self.num_blocks, list(self.objrefs.shape)))
|
|
|
|
def deserialize(self, primitives):
|
|
(shape, dtype_name, objrefs) = primitives
|
|
self.construct(shape, np.dtype(dtype_name), objrefs)
|
|
|
|
def serialize(self):
|
|
return (self.shape, self.dtype.__name__, self.objrefs)
|
|
|
|
def __init__(self):
|
|
self.shape = None
|
|
self.dtype = None
|
|
self.objrefs = None
|
|
|
|
def compute_block_lower(self, index):
|
|
if len(index) != self.ndim:
|
|
raise Exception("The value `index` equals {}, but `ndim` is {}.".format(index, self.ndim))
|
|
return [elem * BLOCK_SIZE for elem in index]
|
|
|
|
def compute_block_upper(self, index):
|
|
if len(index) != self.ndim:
|
|
raise Exception("The value `index` equals {}, but `ndim` is {}.".format(index, self.ndim))
|
|
upper = []
|
|
for i in range(self.ndim):
|
|
upper.append(min((index[i] + 1) * BLOCK_SIZE, self.shape[i]))
|
|
return upper
|
|
|
|
def compute_block_shape(self, index):
|
|
lower = self.compute_block_lower(index)
|
|
upper = self.compute_block_upper(index)
|
|
return [u - l for (l, u) in zip(lower, upper)]
|
|
|
|
def assemble(self):
|
|
"""Assemble an array on this node from a distributed array object reference."""
|
|
result = np.zeros(self.shape)
|
|
for index in np.ndindex(*self.num_blocks):
|
|
lower = self.compute_block_lower(index)
|
|
upper = self.compute_block_upper(index)
|
|
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = op.pull(self.objrefs[index])
|
|
return result
|
|
|
|
def __getitem__(self, sliced):
|
|
# TODO(rkn): fix this, this is just a placeholder that should work but is inefficient
|
|
a = self.assemble()
|
|
return a[sliced]
|