add tracking to welford

This commit is contained in:
Valentin Boettcher 2022-11-30 12:17:01 -05:00
parent d75ca67126
commit f9e13e0e1b
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE

View file

@ -674,20 +674,39 @@ def integrate_array(
class WelfordAggregator:
__slots__ = ["n", "mean", "_m_2"]
__slots__ = ["n", "mean", "_m_2", "_tracker"]
def __init__(self, first_value: np.ndarray):
def __init__(self, first_value: np.ndarray, track=False):
self.n = 1
self.mean = first_value
self._m_2 = np.zeros_like(first_value)
def update(self, new_value: np.ndarray):
self._tracker: Optional[SortedList] = None
if track:
self._tracker = SortedList()
def update(self, new_value: np.ndarray, i: Optional[int] = None):
if self._tracker is not None:
if i is None:
raise ValueError("Tracking is enabled but no index was supplied.")
if self.has_sample(i):
return
self._tracker.add(i)
self.n += 1
delta = new_value - self.mean
self.mean += delta / self.n
delta2 = new_value - self.mean
self._m_2 += np.abs(delta) * np.abs(delta2)
def has_sample(self, i: int) -> bool:
if self._tracker is None:
return False # don't know
return i in self._tracker
@property
def sample_variance(self) -> np.ndarray:
if self.n == 1:
@ -778,9 +797,7 @@ def load_online_cache(save: str):
def ensemble_mean_online(
args: Any,
save: str,
function: Callable[..., np.ndarray],
args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None
) -> Optional[EnsembleValue]:
path = get_online_data_path(save)
@ -796,13 +813,13 @@ def ensemble_mean_online(
with path.open("rb") as agg_file:
aggregate = pickle.load(agg_file)
if result is not None:
aggregate.update(result)
aggregate.update(result, i)
else:
if result is None:
raise RuntimeError("No cache and no result.")
aggregate = WelfordAggregator(result)
aggregate = WelfordAggregator(result, i is not None)
with path.open("wb") as agg_file:
pickle.dump(aggregate, agg_file)