mirror of
https://github.com/vale981/hopsflow
synced 2025-03-04 16:31:38 -05:00
add tracking to welford
This commit is contained in:
parent
d75ca67126
commit
f9e13e0e1b
1 changed files with 25 additions and 8 deletions
|
@ -674,20 +674,39 @@ def integrate_array(
|
|||
|
||||
|
||||
class WelfordAggregator:
|
||||
__slots__ = ["n", "mean", "_m_2"]
|
||||
__slots__ = ["n", "mean", "_m_2", "_tracker"]
|
||||
|
||||
def __init__(self, first_value: np.ndarray):
|
||||
def __init__(self, first_value: np.ndarray, track=False):
|
||||
self.n = 1
|
||||
self.mean = first_value
|
||||
self._m_2 = np.zeros_like(first_value)
|
||||
|
||||
def update(self, new_value: np.ndarray):
|
||||
self._tracker: Optional[SortedList] = None
|
||||
if track:
|
||||
self._tracker = SortedList()
|
||||
|
||||
def update(self, new_value: np.ndarray, i: Optional[int] = None):
|
||||
if self._tracker is not None:
|
||||
if i is None:
|
||||
raise ValueError("Tracking is enabled but no index was supplied.")
|
||||
|
||||
if self.has_sample(i):
|
||||
return
|
||||
|
||||
self._tracker.add(i)
|
||||
|
||||
self.n += 1
|
||||
delta = new_value - self.mean
|
||||
self.mean += delta / self.n
|
||||
delta2 = new_value - self.mean
|
||||
self._m_2 += np.abs(delta) * np.abs(delta2)
|
||||
|
||||
def has_sample(self, i: int) -> bool:
|
||||
if self._tracker is None:
|
||||
return False # don't know
|
||||
|
||||
return i in self._tracker
|
||||
|
||||
@property
|
||||
def sample_variance(self) -> np.ndarray:
|
||||
if self.n == 1:
|
||||
|
@ -778,9 +797,7 @@ def load_online_cache(save: str):
|
|||
|
||||
|
||||
def ensemble_mean_online(
|
||||
args: Any,
|
||||
save: str,
|
||||
function: Callable[..., np.ndarray],
|
||||
args: Any, save: str, function: Callable[..., np.ndarray], i: Optional[int] = None
|
||||
) -> Optional[EnsembleValue]:
|
||||
path = get_online_data_path(save)
|
||||
|
||||
|
@ -796,13 +813,13 @@ def ensemble_mean_online(
|
|||
with path.open("rb") as agg_file:
|
||||
aggregate = pickle.load(agg_file)
|
||||
if result is not None:
|
||||
aggregate.update(result)
|
||||
aggregate.update(result, i)
|
||||
|
||||
else:
|
||||
if result is None:
|
||||
raise RuntimeError("No cache and no result.")
|
||||
|
||||
aggregate = WelfordAggregator(result)
|
||||
aggregate = WelfordAggregator(result, i is not None)
|
||||
|
||||
with path.open("wb") as agg_file:
|
||||
pickle.dump(aggregate, agg_file)
|
||||
|
|
Loading…
Add table
Reference in a new issue