2016-03-16 18:11:43 -07:00
from typing import List
import numpy as np
import arrays . single as single
import orchpy as op
2016-03-23 18:38:42 -07:00
# Names exported by a star-import of this module.
# NOTE(review): "BLOCK_SIZE" is exported and is used by every block-layout
# computation in this file, but its assignment is not visible here —
# confirm it is defined at module level, otherwise the star-import fails.
__all__ = [
    "BLOCK_SIZE",
    "DistArray",
    "assemble",
    "zeros",
    "ones",
    "copy",
    "eye",
    "triu",
    "tril",
    "blockwise_dot",
    "dot",
    "block_column",
    "block_row",
]
2016-03-16 18:11:43 -07:00
class DistArray(object):
    """A dense n-dimensional array stored as a grid of blocks.

    Every block spans BLOCK_SIZE entries along each axis, except the last
    block on an axis, which is clipped to the array's extent.

    Attributes (filled in by `construct` / `deserialize`):
        shape: the full array shape (presumably a list of ints).
        objrefs: numpy object array with one entry per block; each entry is
            passed to `op.pull` to obtain that block's numpy array.
        ndim: len(shape).
        num_blocks: list with the number of blocks along each axis.
    """

    def construct(self, shape, objrefs):
        """Populate this array from its shape and its grid of block references.

        Raises:
            Exception: if `objrefs.shape` does not match the block grid
                implied by `shape`.
        """
        self.shape = shape
        self.objrefs = objrefs
        self.ndim = len(shape)
        # Use the shared helper so the grid computation cannot drift from
        # the one used by the array constructors.
        self.num_blocks = DistArray.compute_num_blocks(self.shape)
        if self.num_blocks != list(self.objrefs.shape):
            raise Exception("The fields `num_blocks` and `objrefs` are inconsistent, `num_blocks` is {} and `objrefs` has shape {}".format(self.num_blocks, list(self.objrefs.shape)))

    def deserialize(self, primitives):
        """Rebuild this array from the (shape, objrefs) pair produced by `serialize`."""
        (shape, objrefs) = primitives
        self.construct(shape, objrefs)

    def serialize(self):
        """Return the (shape, objrefs) pair accepted by `deserialize`."""
        return (self.shape, self.objrefs)

    def __init__(self):
        # The real fields are set later by `construct` or `deserialize`.
        self.shape = None
        self.objrefs = None

    @staticmethod
    def compute_block_lower(index, shape):
        """Return the inclusive lower corner of block `index` (one entry per axis)."""
        # TODO(rkn): Check that the entries of index are in the correct range.
        # TODO(rkn): Check that len(index) == len(shape).
        return [elem * BLOCK_SIZE for elem in index]

    @staticmethod
    def compute_block_upper(index, shape):
        """Return the exclusive upper corner of block `index`, clipped to `shape`."""
        # TODO(rkn): Check that the entries of index are in the correct range.
        # TODO(rkn): Check that len(index) == len(shape).
        return [min((index[axis] + 1) * BLOCK_SIZE, extent) for axis, extent in enumerate(shape)]

    @staticmethod
    def compute_block_shape(index, shape):
        """Return the shape of block `index` within an array of shape `shape`."""
        lower = DistArray.compute_block_lower(index, shape)
        upper = DistArray.compute_block_upper(index, shape)
        return [u - l for (l, u) in zip(lower, upper)]

    @staticmethod
    def compute_num_blocks(shape):
        """Return the number of blocks along each axis, i.e. ceil(extent / BLOCK_SIZE)."""
        # Ceil division via floor of the negation: exact for integers of any
        # size, unlike the float round-trip through np.ceil.
        return [int(-(-a // BLOCK_SIZE)) for a in shape]

    def assemble(self):
        """Assemble the full array on this node by pulling every block.

        The result's dtype is taken from the first block. An array with no
        blocks (some dimension of length 0) is returned as zeros with
        numpy's default dtype, since there is no block to take a dtype from.
        """
        if self.objrefs.size == 0:
            return np.zeros(self.shape)
        first_block = op.pull(self.objrefs[(0,) * self.ndim])
        result = np.zeros(self.shape, dtype=first_block.dtype)
        for index in np.ndindex(*self.num_blocks):
            lower = DistArray.compute_block_lower(index, self.shape)
            upper = DistArray.compute_block_upper(index, self.shape)
            # Multidimensional indexing needs a *tuple* of slices; indexing
            # with a list of slices is deprecated and rejected by newer NumPy.
            region = tuple(slice(l, u) for (l, u) in zip(lower, upper))
            result[region] = op.pull(self.objrefs[index])
        return result

    def __getitem__(self, sliced):
        """Index the array by assembling it locally and slicing the result."""
        # TODO(rkn): fix this, this is just a placeholder that should work but is inefficient
        a = self.assemble()
        return a[sliced]
2016-03-23 18:38:42 -07:00
@op.distributed([DistArray], [np.ndarray])
def assemble(a):
    """Remote task: gather all blocks of the DistArray `a` into one numpy array.

    Simply delegates to the `DistArray.assemble` method.
    """
    gathered = a.assemble()
    return gathered
@op.distributed([List[int], str], [DistArray])
def zeros(shape, dtype_name):
    """Build a DistArray of shape `shape` whose blocks are all zeros.

    Each block is produced by `single.zeros` with that block's own shape and
    `dtype_name`; the resulting references are laid out on the block grid.
    """
    grid = DistArray.compute_num_blocks(shape)
    block_refs = np.empty(grid, dtype=object)
    for block_index in np.ndindex(*grid):
        block_shape = DistArray.compute_block_shape(block_index, shape)
        block_refs[block_index] = single.zeros(block_shape, dtype_name)
    zero_array = DistArray()
    zero_array.construct(shape, block_refs)
    return zero_array
@op.distributed([List[int], str], [DistArray])
def ones(shape, dtype_name):
    """Build a DistArray of shape `shape` whose blocks are all ones.

    Each block is produced by `single.ones` with that block's own shape and
    `dtype_name`; the resulting references are laid out on the block grid.
    """
    grid = DistArray.compute_num_blocks(shape)
    block_refs = np.empty(grid, dtype=object)
    for block_index in np.ndindex(*grid):
        block_shape = DistArray.compute_block_shape(block_index, shape)
        block_refs[block_index] = single.ones(block_shape, dtype_name)
    ones_array = DistArray()
    ones_array.construct(shape, block_refs)
    return ones_array
@op.distributed([DistArray], [DistArray])
def copy(a):
    """Return a new DistArray with a's shape whose blocks are `single.copy` of a's blocks."""
    grid = DistArray.compute_num_blocks(a.shape)
    copied_refs = np.empty(grid, dtype=object)
    for block_index in np.ndindex(*grid):
        source_ref = a.objrefs[block_index]
        copied_refs[block_index] = single.copy(source_ref)
    duplicate = DistArray()
    duplicate.construct(a.shape, copied_refs)
    return duplicate
@op.distributed([int, str], [DistArray])
def eye(dim, dtype_name):
    """Build a `dim` x `dim` identity DistArray with element type `dtype_name`.

    Blocks on the block diagonal come from `single.eye` (diagonal blocks of a
    square grid are square, so only their first extent is passed); every
    other block comes from `single.zeros`.
    """
    shape = [dim, dim]
    grid = DistArray.compute_num_blocks(shape)
    block_refs = np.empty(grid, dtype=object)
    for row, col in np.ndindex(*grid):
        block_shape = DistArray.compute_block_shape([row, col], shape)
        if row != col:
            block_refs[row, col] = single.zeros(block_shape, dtype_name)
        else:
            block_refs[row, col] = single.eye(block_shape[0], dtype_name)
    identity = DistArray()
    identity.construct(shape, block_refs)
    return identity
@op.distributed([DistArray], [DistArray])
def triu(a):
    """Return the upper-triangular part of the 2-D DistArray `a`.

    Blocks strictly above the block diagonal are copied, diagonal blocks go
    through `single.triu`, and blocks below are replaced by `single.zeros_like`.

    Raises:
        Exception: if `a` is not 2-dimensional.
    """
    if a.ndim != 2:
        raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
    block_refs = np.empty(a.num_blocks, dtype=object)
    # np.ndindex walks the grid in the same row-major order as nested loops.
    for row, col in np.ndindex(*a.num_blocks):
        block = a.objrefs[row, col]
        if row == col:
            block_refs[row, col] = single.triu(block)
        elif row < col:
            block_refs[row, col] = single.copy(block)
        else:
            block_refs[row, col] = single.zeros_like(block)
    upper = DistArray()
    upper.construct(a.shape, block_refs)
    return upper
@op.distributed([DistArray], [DistArray])
def tril(a):
    """Return the lower-triangular part of the 2-D DistArray `a`.

    Blocks strictly below the block diagonal are copied, diagonal blocks go
    through `single.tril`, and blocks above are replaced by `single.zeros_like`.

    Raises:
        Exception: if `a` is not 2-dimensional.
    """
    if a.ndim != 2:
        raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
    block_refs = np.empty(a.num_blocks, dtype=object)
    # np.ndindex walks the grid in the same row-major order as nested loops.
    for row, col in np.ndindex(*a.num_blocks):
        block = a.objrefs[row, col]
        if row == col:
            block_refs[row, col] = single.tril(block)
        elif row > col:
            block_refs[row, col] = single.copy(block)
        else:
            block_refs[row, col] = single.zeros_like(block)
    lower = DistArray()
    lower.construct(a.shape, block_refs)
    return lower
@op.distributed([np.ndarray, None], [np.ndarray])
def blockwise_dot(*matrices):
    """Return sum_i np.dot(A_i, B_i) over paired blocks.

    The first half of `matrices` holds the left factors A_0..A_{k-1} and the
    second half holds the right factors B_0..B_{k-1}. `dot` calls this with
    block row i of its left operand followed by block column j of its right
    operand to form output block (i, j).

    Raises:
        Exception: if the number of matrices is odd or zero.
    """
    n = len(matrices)
    if n % 2 != 0:
        raise Exception("blockwise_dot expects an even number of arguments, but len(matrices) is {}.".format(n))
    if n == 0:
        raise Exception("blockwise_dot expects at least two arguments, but none were given.")
    # Integer division: `n / 2` is a float on Python 3 and can neither index
    # a tuple nor drive range().
    half = n // 2
    lefts, rights = matrices[:half], matrices[half:]
    shape = (lefts[0].shape[0], rights[0].shape[1])
    # Accumulate in at least float64 (the previous fixed dtype), promoting
    # further when the inputs require it: an in-place `+=` of complex
    # products into a float64 buffer is rejected by NumPy's casting rules.
    result = np.zeros(shape, dtype=np.result_type(np.float64, *matrices))
    for left, right in zip(lefts, rights):
        result += np.dot(left, right)
    return result
@op.distributed([DistArray, DistArray], [DistArray])
def dot(a, b):
    """Distributed matrix product of two 2-D DistArrays.

    Output block (i, j) is a `blockwise_dot` task over block row i of `a`
    followed by block column j of `b`.

    Raises:
        Exception: if either operand is not 2-D or the inner dimensions differ.
    """
    if a.ndim != 2:
        raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
    if b.ndim != 2:
        raise Exception("dot expects its arguments to be 2-dimensional, but b.ndim = {}.".format(b.ndim))
    if a.shape[1] != b.shape[0]:
        raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape = {} and b.shape = {}.".format(a.shape, b.shape))
    out_shape = [a.shape[0], b.shape[1]]
    grid = DistArray.compute_num_blocks(out_shape)
    product_refs = np.empty(grid, dtype=object)
    for row, col in np.ndindex(*grid):
        left_blocks = list(a.objrefs[row, :])
        right_blocks = list(b.objrefs[:, col])
        product_refs[row, col] = blockwise_dot(*(left_blocks + right_blocks))
    product = DistArray()
    product.construct(out_shape, product_refs)
    return product
# This is not in numpy, should we expose this?
@op.distributed([DistArray, int], [DistArray])
def block_column(a, col):
    """Return block column `col` of the 2-D DistArray `a` as a new DistArray.

    The result has shape [a.shape[0], w], where w is the width of the blocks
    in column `col`. It reuses the block references of `a`; no block data is
    copied.

    Raises:
        Exception: if `a` is not 2-dimensional, or `col` is not in
            [0, a.num_blocks[1]).
    """
    if a.ndim != 2:
        raise Exception("block_column expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
    if not 0 <= col < a.num_blocks[1]:
        raise Exception("block_column expects 0 <= col < {}, but col = {}.".format(a.num_blocks[1], col))
    # compute_block_shape needs the full array shape to clip the last block.
    top_block_shape = DistArray.compute_block_shape([0, col], a.shape)
    shape = [a.shape[0], top_block_shape[1]]
    result = DistArray()
    # Slice col:col + 1 so the block grid stays 2-D (num_blocks[0] x 1);
    # a plain [:, col] is 1-D and fails the consistency check in construct.
    result.construct(shape, a.objrefs[:, col:col + 1])
    return result
# This is not in numpy, should we expose this?
@op.distributed([DistArray, int], [DistArray])
def block_row(a, row):
    """Return block row `row` of the 2-D DistArray `a` as a new DistArray.

    The result has shape [h, a.shape[1]], where h is the height of the blocks
    in row `row`. It reuses the block references of `a`; no block data is
    copied.

    Raises:
        Exception: if `a` is not 2-dimensional, or `row` is not in
            [0, a.num_blocks[0]).
    """
    if a.ndim != 2:
        raise Exception("block_row expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
    if not 0 <= row < a.num_blocks[0]:
        raise Exception("block_row expects 0 <= row < {}, but row = {}.".format(a.num_blocks[0], row))
    # compute_block_shape needs the full array shape to clip the last block.
    left_block_shape = DistArray.compute_block_shape([row, 0], a.shape)
    shape = [left_block_shape[0], a.shape[1]]
    result = DistArray()
    # Slice row:row + 1 so the block grid stays 2-D (1 x num_blocks[1]);
    # a plain [row, :] is 1-D and fails the consistency check in construct.
    result.construct(shape, a.objrefs[row:row + 1, :])
    return result