mirror of
https://github.com/vale981/hopsflow
synced 2025-03-05 08:51:37 -05:00
implementing saving snapshots every N samples for online analysis
This commit is contained in:
parent
8ede222312
commit
cd5f82d044
1 changed files with 26 additions and 3 deletions
|
@ -117,7 +117,7 @@ class EnsembleValue:
|
|||
return EnsembleValue([(N, val[min_index].copy(), σ[min_index].copy())])
|
||||
|
||||
def __getitem__(self, index):
|
||||
return EnsembleValue(self._value[index])
|
||||
return EnsembleValue([self._value[index]])
|
||||
|
||||
def slice(self, slc: Union[np.ndarray, slice]):
|
||||
results = []
|
||||
|
@ -709,7 +709,6 @@ class WelfordAggregator:
|
|||
save["tracker"] = self._tracker
|
||||
|
||||
with open(path, "wb") as f:
|
||||
portalocker.lock(f, portalocker.LockFlags.EXCLUSIVE)
|
||||
portalocker.lock(f, portalocker.LockFlags.EXCLUSIVE)
|
||||
np.savez(f, **save)
|
||||
portalocker.unlock(f)
|
||||
|
@ -837,7 +836,11 @@ def _ensemble_remote_function(function, chunk: tuple, index: int):
|
|||
|
||||
|
||||
def ensemble_mean_online(
|
||||
args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None
|
||||
args: Any,
|
||||
save: str,
|
||||
function: Callable[..., np.ndarray],
|
||||
i: Optional[int] = None,
|
||||
every: Optional[int] = None,
|
||||
) -> Optional[EnsembleValue]:
|
||||
path = Path(save)
|
||||
|
||||
|
@ -860,9 +863,29 @@ def ensemble_mean_online(
|
|||
aggregate = WelfordAggregator(result, i)
|
||||
|
||||
aggregate.dump(str(path))
|
||||
if every is not None and aggregate.n % every == 0:
|
||||
path.with_stem(f"{path.stem}_{aggregate.n}")
|
||||
aggregate.dump(str(path))
|
||||
|
||||
return aggregate.ensemble_value
|
||||
|
||||
|
||||
def get_online_values_from_cache(path):
|
||||
path = Path(path)
|
||||
all_versions = list(path.parent.glob(path.stem + "*" + path.suffix))
|
||||
|
||||
final = all_versions[0]
|
||||
all_versions = all_versions[1:] + [final]
|
||||
|
||||
vals = []
|
||||
|
||||
for path in all_versions:
|
||||
agg = WelfordAggregator.from_dump(str(path))
|
||||
vals.append([agg.n, agg.mean, agg.ensemble_std])
|
||||
|
||||
return EnsembleValue(vals)
|
||||
|
||||
|
||||
def ensemble_mean(
|
||||
arg_iter: Iterator[Any],
|
||||
function: Callable[..., np.ndarray],
|
||||
|
|
Loading…
Add table
Reference in a new issue