ray/examples/imagenet/functions.py
2016-06-10 17:25:55 -07:00

18 lines
676 B
Python

import numpy as np
from typing import List
import ray
import ray.arrays.remote as ra
@ray.remote([List[ray.ObjRef]], [int])
def num_images(batches):
    """Count the images held across a list of remote batches.

    Args:
        batches: Object references to the batches. The size of each
            batch's first axis is taken as its number of images.

    Returns:
        The sum of the first shape dimension over all batches; 0 when
        ``batches`` is empty.
    """
    # Submit every shape task up front, then collect the results in order.
    pending = [ra.shape(b) for b in batches]
    total = 0
    for ref in pending:
        total += ray.pull(ref)[0]
    return total
@ray.remote([List[ray.ObjRef]], [np.ndarray])
def compute_mean_image(batches):
    """Return the element-wise mean image over a list of remote batches.

    Each batch is reduced along axis 0 by a remote ``ra.sum`` task. The
    partial sums are added together and divided by the total image count
    reported by the remote ``num_images`` task.

    Args:
        batches: Object references to the image batches. Presumably each
            batch stacks one image per index along axis 0 -- confirm
            against callers.

    Returns:
        A float64 ``np.ndarray`` holding the mean image.

    Raises:
        ValueError: If ``batches`` is empty.
    """
    if not batches:
        raise ValueError("No images were passed into `compute_mean_image`.")
    # Submit all remote work (per-batch sums and the image count) before
    # blocking on any result, so the count task can run concurrently with
    # the sums instead of only after all of them have been pulled.
    sum_image_refs = [ra.sum(batch, axis=0) for batch in batches]
    n_images_ref = num_images(batches)
    sum_images = [ray.pull(ref) for ref in sum_image_refs]
    # Accumulate in float64 directly: avoids overflow in narrow integer
    # dtypes and precision loss in float32 before the final division.
    total = np.sum(sum_images, axis=0, dtype=np.float64)
    return total / ray.pull(n_images_ref)