Implement saving of snapshots every N samples for online analysis

This commit is contained in:
Valentin Boettcher 2022-12-09 15:04:51 -05:00
parent 8ede222312
commit cd5f82d044
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE

View file

@ -117,7 +117,7 @@ class EnsembleValue:
return EnsembleValue([(N, val[min_index].copy(), σ[min_index].copy())])
def __getitem__(self, index):
    """Return the selected entry as a new :class:`EnsembleValue`.

    The indexed ``(N, value, σ)`` entry is wrapped in a list so the
    result's ``_value`` is again a list of entries (a well-formed
    ``EnsembleValue``) rather than a bare tuple.
    """
    # NOTE: the rendered diff showed the old body
    # (`EnsembleValue(self._value[index])`) immediately before the new
    # one, leaving an unreachable second return; only the list-wrapped
    # form is kept.
    return EnsembleValue([self._value[index]])
def slice(self, slc: Union[np.ndarray, slice]):
results = []
@ -709,7 +709,6 @@ class WelfordAggregator:
save["tracker"] = self._tracker
with open(path, "wb") as f:
portalocker.lock(f, portalocker.LockFlags.EXCLUSIVE)
portalocker.lock(f, portalocker.LockFlags.EXCLUSIVE)
np.savez(f, **save)
portalocker.unlock(f)
@ -837,7 +836,11 @@ def _ensemble_remote_function(function, chunk: tuple, index: int):
def ensemble_mean_online(
args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None
args: Any,
save: str,
function: Callable[..., np.ndarray],
i: Optional[int] = None,
every: Optional[int] = None,
) -> Optional[EnsembleValue]:
path = Path(save)
@ -860,9 +863,29 @@ def ensemble_mean_online(
aggregate = WelfordAggregator(result, i)
aggregate.dump(str(path))
if every is not None and aggregate.n % every == 0:
path.with_stem(f"{path.stem}_{aggregate.n}")
aggregate.dump(str(path))
return aggregate.ensemble_value
def get_online_values_from_cache(path):
    """Load every snapshot dump belonging to ``path`` and return the history.

    The online aggregation writes intermediate dumps named
    ``<stem>_<n><suffix>`` next to the final dump ``<stem><suffix>``.
    The snapshots are ordered by ascending sample count ``n``, with the
    final (unnumbered) dump placed last, so the returned value traces
    the convergence of the running mean.

    :param path: path of the final dump file (string or ``Path``).
    :returns: an :class:`EnsembleValue` with one ``(n, mean, σ)`` entry
        per dump found.
    """
    path = Path(path)

    def _snapshot_count(dump: Path) -> int:
        # Snapshot stems end in "_<n>"; extract n for numeric ordering.
        # (Relying on glob order, as before, is filesystem-dependent and
        # sorts lexically, e.g. "_10" before "_2".)
        try:
            return int(dump.stem.rsplit("_", 1)[-1])
        except ValueError:
            return -1

    all_dumps = list(path.parent.glob(path.stem + "*" + path.suffix))

    # Separate the final dump from the numbered snapshots instead of
    # assuming it is the first glob result.
    snapshots = sorted((p for p in all_dumps if p != path), key=_snapshot_count)
    if path in all_dumps:
        snapshots.append(path)

    vals = []
    for dump in snapshots:
        agg = WelfordAggregator.from_dump(str(dump))
        vals.append([agg.n, agg.mean, agg.ensemble_std])

    return EnsembleValue(vals)
def ensemble_mean(
arg_iter: Iterator[Any],
function: Callable[..., np.ndarray],