mirror of
https://github.com/vale981/hopsflow
synced 2025-03-05 08:51:37 -05:00
add tracking to welford
This commit is contained in:
parent
d75ca67126
commit
f9e13e0e1b
1 changed files with 25 additions and 8 deletions
|
@ -674,20 +674,39 @@ def integrate_array(
|
||||||
|
|
||||||
|
|
||||||
class WelfordAggregator:
|
class WelfordAggregator:
|
||||||
__slots__ = ["n", "mean", "_m_2"]
|
__slots__ = ["n", "mean", "_m_2", "_tracker"]
|
||||||
|
|
||||||
def __init__(self, first_value: np.ndarray):
|
def __init__(self, first_value: np.ndarray, track=False):
|
||||||
self.n = 1
|
self.n = 1
|
||||||
self.mean = first_value
|
self.mean = first_value
|
||||||
self._m_2 = np.zeros_like(first_value)
|
self._m_2 = np.zeros_like(first_value)
|
||||||
|
|
||||||
def update(self, new_value: np.ndarray):
|
self._tracker: Optional[SortedList] = None
|
||||||
|
if track:
|
||||||
|
self._tracker = SortedList()
|
||||||
|
|
||||||
|
def update(self, new_value: np.ndarray, i: Optional[int] = None):
|
||||||
|
if self._tracker is not None:
|
||||||
|
if i is None:
|
||||||
|
raise ValueError("Tracking is enabled but no index was supplied.")
|
||||||
|
|
||||||
|
if self.has_sample(i):
|
||||||
|
return
|
||||||
|
|
||||||
|
self._tracker.add(i)
|
||||||
|
|
||||||
self.n += 1
|
self.n += 1
|
||||||
delta = new_value - self.mean
|
delta = new_value - self.mean
|
||||||
self.mean += delta / self.n
|
self.mean += delta / self.n
|
||||||
delta2 = new_value - self.mean
|
delta2 = new_value - self.mean
|
||||||
self._m_2 += np.abs(delta) * np.abs(delta2)
|
self._m_2 += np.abs(delta) * np.abs(delta2)
|
||||||
|
|
||||||
|
def has_sample(self, i: int) -> bool:
|
||||||
|
if self._tracker is None:
|
||||||
|
return False # don't know
|
||||||
|
|
||||||
|
return i in self._tracker
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sample_variance(self) -> np.ndarray:
|
def sample_variance(self) -> np.ndarray:
|
||||||
if self.n == 1:
|
if self.n == 1:
|
||||||
|
@ -778,9 +797,7 @@ def load_online_cache(save: str):
|
||||||
|
|
||||||
|
|
||||||
def ensemble_mean_online(
|
def ensemble_mean_online(
|
||||||
args: Any,
|
args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None
|
||||||
save: str,
|
|
||||||
function: Callable[..., np.ndarray],
|
|
||||||
) -> Optional[EnsembleValue]:
|
) -> Optional[EnsembleValue]:
|
||||||
path = get_online_data_path(save)
|
path = get_online_data_path(save)
|
||||||
|
|
||||||
|
@ -796,13 +813,13 @@ def ensemble_mean_online(
|
||||||
with path.open("rb") as agg_file:
|
with path.open("rb") as agg_file:
|
||||||
aggregate = pickle.load(agg_file)
|
aggregate = pickle.load(agg_file)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
aggregate.update(result)
|
aggregate.update(result, i)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if result is None:
|
if result is None:
|
||||||
raise RuntimeError("No cache and no result.")
|
raise RuntimeError("No cache and no result.")
|
||||||
|
|
||||||
aggregate = WelfordAggregator(result)
|
aggregate = WelfordAggregator(result, i is not None)
|
||||||
|
|
||||||
with path.open("wb") as agg_file:
|
with path.open("wb") as agg_file:
|
||||||
pickle.dump(aggregate, agg_file)
|
pickle.dump(aggregate, agg_file)
|
||||||
|
|
Loading…
Add table
Reference in a new issue