fix ordering when calculating ensemble mean

This commit is contained in:
Valentin Boettcher 2022-04-13 20:26:13 +02:00
parent e2e4643d03
commit 8971ce4363

View file

@ -1,5 +1,6 @@
"""Utilities for the energy flow calculation."""
from __future__ import annotations
import itertools
import multiprocessing
import numpy as np
@ -32,11 +33,11 @@ EnsembleReturn = Union[Aggregate, list[Aggregate]]
class EnsembleValue:
def __init__(self, value: Union[Aggregate, list[Aggregate]]):
self._value: list[Aggregate] = (
self._value: list[Aggregate] = ( # type:ignore
value
if (isinstance(value, list) or isinstance(value, np.ndarray))
else [value]
) # type:ignore
)
@property
def final_aggregate(self):
@ -76,8 +77,8 @@ class EnsembleValue:
for agg in self._value:
yield EnsembleValue(agg)
def __getitem__(self, index: int):
return self._value[index]
def __getitem__(self, index):
return EnsembleValue(self._value[index])
def __len__(self) -> int:
return len(self._value)
@ -103,7 +104,7 @@ class EnsembleValue:
return EnsembleValue(out)
def __add__(self, other):
def __add__(self, other: Any) -> EnsembleValue:
if type(self) == type(other):
if len(self) != len(other):
raise RuntimeError("Can only add values of equal length.")
@ -654,16 +655,15 @@ def ensemble_mean(
_grouper(
chunk_size, itertools.islice(arg_iter, None, N - 1 if N else None)
),
total=int((N - 1 if N else None) / chunk_size + 1),
total=int((N - 1) / chunk_size + 1) if N is not None else None,
desc="Loading",
)
]
progress = tqdm(total=len(handles), desc="Processing")
while len(handles):
done_id, handles = ray.wait(handles, fetch_local=True)
res_chunk = np.array(ray.get(done_id[0]))
for ref in handles:
res_chunk = np.array(ray.get(ref))
for res in res_chunk:
aggregate.update(res)
if every is not None and (aggregate.n % every) == 0 or aggregate.n == N: