diff --git a/hopsflow/util.py b/hopsflow/util.py index 4999f29..e751e69 100644 --- a/hopsflow/util.py +++ b/hopsflow/util.py @@ -117,7 +117,7 @@ class EnsembleValue: return EnsembleValue([(N, val[min_index].copy(), σ[min_index].copy())]) def __getitem__(self, index): - return EnsembleValue(self._value[index]) + return EnsembleValue([self._value[index]]) def slice(self, slc: Union[np.ndarray, slice]): results = [] @@ -709,7 +709,6 @@ class WelfordAggregator: save["tracker"] = self._tracker with open(path, "wb") as f: - portalocker.lock(f, portalocker.LockFlags.EXCLUSIVE) portalocker.lock(f, portalocker.LockFlags.EXCLUSIVE) np.savez(f, **save) portalocker.unlock(f) @@ -837,7 +836,11 @@ def _ensemble_remote_function(function, chunk: tuple, index: int): def ensemble_mean_online( - args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None + args: Any, + save: str, + function: Callable[..., np.ndarray], + i: Optional[int] = None, + every: Optional[int] = None, ) -> Optional[EnsembleValue]: path = Path(save) @@ -860,9 +863,29 @@ def ensemble_mean_online( aggregate = WelfordAggregator(result, i) aggregate.dump(str(path)) + if every is not None and aggregate.n % every == 0: + path.with_stem(f"{path.stem}_{aggregate.n}") + aggregate.dump(str(path)) + return aggregate.ensemble_value +def get_online_values_from_cache(path): + path = Path(path) + all_versions = list(path.parent.glob(path.stem + "*" + path.suffix)) + + final = all_versions[0] + all_versions = all_versions[1:] + [final] + + vals = [] + + for path in all_versions: + agg = WelfordAggregator.from_dump(str(path)) + vals.append([agg.n, agg.mean, agg.ensemble_std]) + + return EnsembleValue(vals) + + def ensemble_mean( arg_iter: Iterator[Any], function: Callable[..., np.ndarray],