mirror of
https://github.com/vale981/hopsflow
synced 2025-03-04 16:31:38 -05:00
fix ordering when calculating ensemble mean
This commit is contained in:
parent
e2e4643d03
commit
8971ce4363
1 changed files with 9 additions and 9 deletions
|
@ -1,5 +1,6 @@
|
|||
"""Utilities for the energy flow calculation."""
|
||||
|
||||
from __future__ import annotations
|
||||
import itertools
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
|
@ -32,11 +33,11 @@ EnsembleReturn = Union[Aggregate, list[Aggregate]]
|
|||
|
||||
class EnsembleValue:
|
||||
def __init__(self, value: Union[Aggregate, list[Aggregate]]):
|
||||
self._value: list[Aggregate] = (
|
||||
self._value: list[Aggregate] = ( # type:ignore
|
||||
value
|
||||
if (isinstance(value, list) or isinstance(value, np.ndarray))
|
||||
else [value]
|
||||
) # type:ignore
|
||||
)
|
||||
|
||||
@property
|
||||
def final_aggregate(self):
|
||||
|
@ -76,8 +77,8 @@ class EnsembleValue:
|
|||
for agg in self._value:
|
||||
yield EnsembleValue(agg)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
return self._value[index]
|
||||
def __getitem__(self, index):
|
||||
return EnsembleValue(self._value[index])
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._value)
|
||||
|
@ -103,7 +104,7 @@ class EnsembleValue:
|
|||
|
||||
return EnsembleValue(out)
|
||||
|
||||
def __add__(self, other):
|
||||
def __add__(self, other: Any) -> EnsembleValue:
|
||||
if type(self) == type(other):
|
||||
if len(self) != len(other):
|
||||
raise RuntimeError("Can only add values of equal length.")
|
||||
|
@ -654,16 +655,15 @@ def ensemble_mean(
|
|||
_grouper(
|
||||
chunk_size, itertools.islice(arg_iter, None, N - 1 if N else None)
|
||||
),
|
||||
total=int((N - 1 if N else None) / chunk_size + 1),
|
||||
total=int((N - 1) / chunk_size + 1) if N is not None else None,
|
||||
desc="Loading",
|
||||
)
|
||||
]
|
||||
|
||||
progress = tqdm(total=len(handles), desc="Processing")
|
||||
|
||||
while len(handles):
|
||||
done_id, handles = ray.wait(handles, fetch_local=True)
|
||||
res_chunk = np.array(ray.get(done_id[0]))
|
||||
for ref in handles:
|
||||
res_chunk = np.array(ray.get(ref))
|
||||
for res in res_chunk:
|
||||
aggregate.update(res)
|
||||
if every is not None and (aggregate.n % every) == 0 or aggregate.n == N:
|
||||
|
|
Loading…
Add table
Reference in a new issue