From f9e13e0e1b14985222280e13d090ada39fcddd54 Mon Sep 17 00:00:00 2001 From: Valentin Boettcher Date: Wed, 30 Nov 2022 12:17:01 -0500 Subject: [PATCH] add tracking to welford --- hopsflow/util.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/hopsflow/util.py b/hopsflow/util.py index bf29db3..f358cbd 100644 --- a/hopsflow/util.py +++ b/hopsflow/util.py @@ -674,20 +674,39 @@ def integrate_array( class WelfordAggregator: - __slots__ = ["n", "mean", "_m_2"] + __slots__ = ["n", "mean", "_m_2", "_tracker"] - def __init__(self, first_value: np.ndarray): + def __init__(self, first_value: np.ndarray, track=False): self.n = 1 self.mean = first_value self._m_2 = np.zeros_like(first_value) - def update(self, new_value: np.ndarray): + self._tracker: Optional[SortedList] = None + if track: + self._tracker = SortedList() + + def update(self, new_value: np.ndarray, i: Optional[int] = None): + if self._tracker is not None: + if i is None: + raise ValueError("Tracking is enabled but no index was supplied.") + + if self.has_sample(i): + return + + self._tracker.add(i) + self.n += 1 delta = new_value - self.mean self.mean += delta / self.n delta2 = new_value - self.mean self._m_2 += np.abs(delta) * np.abs(delta2) + def has_sample(self, i: int) -> bool: + if self._tracker is None: + return False # don't know + + return i in self._tracker + @property def sample_variance(self) -> np.ndarray: if self.n == 1: @@ -778,9 +797,7 @@ def load_online_cache(save: str): def ensemble_mean_online( - args: Any, - save: str, - function: Callable[..., np.ndarray], + args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None ) -> Optional[EnsembleValue]: path = get_online_data_path(save) @@ -796,13 +813,13 @@ def ensemble_mean_online( with path.open("rb") as agg_file: aggregate = pickle.load(agg_file) if result is not None: - aggregate.update(result) + aggregate.update(result, i) else: if result is None: raise RuntimeError("No cache and no result.") - aggregate = WelfordAggregator(result) + aggregate = WelfordAggregator(result, i is not None) with path.open("wb") as agg_file: pickle.dump(aggregate, agg_file)