mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
18 lines
676 B
Python
18 lines
676 B
Python
import numpy as np
|
|
from typing import List
|
|
import ray
|
|
import ray.arrays.remote as ra
|
|
|
|
@ray.remote([List[ray.ObjRef]], [int])
def num_images(batches):
    """Return the summed leading dimension (``shape[0]``) of all ``batches``.

    Every ``ra.shape`` task is issued before any result is pulled. The first
    entry of each pulled shape is then accumulated into a running total.
    Presumably each batch is a remote array whose axis 0 indexes images;
    confirm against callers.

    Args:
        batches: list of object references, one per batch.

    Returns:
        Sum of ``shape[0]`` over every batch (0 when ``batches`` is empty).
    """
    # Phase 1: issue one shape request per batch before waiting on any.
    pending_shapes = list(map(ra.shape, batches))

    # Phase 2: pull each shape in order and add up its leading dimension.
    total = 0
    for shape_ref in pending_shapes:
        total += ray.pull(shape_ref)[0]
    return total
|
|
|
|
@ray.remote([List[ray.ObjRef]], [np.ndarray])
def compute_mean_image(batches):
    """Compute the element-wise mean image over every image in ``batches``.

    Each batch is reduced with ``ra.sum(batch, axis=0)``. The per-batch sums
    are added together in float64 and divided by the total count returned by
    the remote ``num_images`` task. Presumably each batch is a remote array
    whose axis 0 indexes images; confirm against callers.

    Args:
        batches: list of object references, one per batch.

    Returns:
        A float64 ``np.ndarray`` holding the mean over all images.

    Raises:
        ValueError: if ``batches`` is empty or the batches contain zero images
            in total (the mean would otherwise be 0/0, i.e. all NaN).
            ``ValueError`` subclasses ``Exception``, so callers catching the
            previously raised ``Exception`` still catch it.
    """
    if not batches:
        raise ValueError("No images were passed into `compute_mean_image`.")

    # Submit the count task before blocking on the per-batch pulls below.
    # Assuming remote calls return immediately, the count can then overlap
    # with those pulls instead of starting only after they have finished.
    n_images_ref = num_images(batches)

    sum_image_refs = [ra.sum(batch, axis=0) for batch in batches]
    sum_images = [ray.pull(ref) for ref in sum_image_refs]
    n_images = ray.pull(n_images_ref)

    if n_images == 0:
        # All batches were empty; dividing would silently yield NaNs.
        raise ValueError("No images were passed into `compute_mean_image`.")

    # Accumulate directly in float64 instead of summing in the input dtype
    # and casting afterwards: this avoids precision loss for float32 partial
    # sums and overflow to inf for float16 ones.
    total = np.sum(sum_images, axis=0, dtype=np.float64)
    return total / n_images
|