Compute the mean value with Welford's algorithm for numerical stability

This commit is contained in:
Valentin Boettcher 2021-11-30 17:41:13 +01:00
parent 7994cda636
commit 9354bac05d

View file

@ -7,6 +7,7 @@ import scipy
from typing import Iterator, Optional, Any, Callable, Union
from lmfit import minimize, Parameters
from numpy.polynomial import Polynomial
from tqdm import tqdm
def apply_operator(ψ: np.ndarray, op: np.ndarray) -> np.ndarray:
@ -149,7 +150,34 @@ def _ensemble_mean_init(func: Callable, args: tuple, kwargs: dict):
# TODO: Use paramspec
from tqdm import tqdm
class WelfordAggregator:
    """Running mean and variance of a stream of samples (Welford's method).

    The aggregator is seeded with one sample and then fed further samples
    through :meth:`update`.  Only the current count, the running mean and
    the running sum of products of deviations are stored, so memory use
    does not grow with the number of samples.

    Attributes are documented inline below.
    """

    # n:    number of samples seen so far (starts at 1)
    # mean: running mean, same shape as the first sample
    # _m_2: running sum of ``delta * delta2`` (see ``update``)
    __slots__ = ["n", "mean", "_m_2"]

    def __init__(self, first_value: np.ndarray):
        """Seed the aggregator with its first sample.

        :param first_value: the first sample.  It is copied, so later
            in-place updates of the running mean never write into the
            caller's array.  Integer and boolean input is promoted to
            ``float64`` because the running mean is updated in place with
            fractional increments.
        """
        mean = np.array(first_value, copy=True)
        if not np.issubdtype(mean.dtype, np.inexact):
            # ``mean += delta / n`` cannot be cast back into an integer
            # array, so store the running mean as floating point.
            mean = mean.astype(np.float64)

        self.n = 1
        self.mean = mean
        self._m_2 = np.zeros_like(mean)

    def update(self, new_value: np.ndarray):
        """Fold one more sample into the running mean and variance.

        :param new_value: the next sample; it must broadcast against the
            first sample.
        """
        self.n += 1

        # Deviation from the mean *before* it is updated ...
        delta = new_value - self.mean
        self.mean += delta / self.n

        # ... and from the mean *after* it is updated.  Accumulating the
        # product of the two is what avoids subtracting two large, nearly
        # equal numbers as in ``<x^2> - <x>^2``.
        delta2 = new_value - self.mean
        self._m_2 += delta * delta2
        # NOTE(review): for complex samples this accumulates the plain
        # product, not ``delta * conj(delta2)`` -- confirm this is intended.

    @property
    def sample_variance(self) -> np.ndarray:
        """Unbiased sample variance, ``_m_2 / (n - 1)``.

        With only the seed sample (``n == 1``) this divides by zero, so
        the result is not meaningful until :meth:`update` has been called
        at least once.
        """
        return self._m_2 / (self.n - 1)

    @property
    def ensemble_variance(self) -> np.ndarray:
        """Variance of the mean: the sample variance divided by ``n``."""
        return self.sample_variance / self.n

    @property
    def ensemble_std(self) -> np.ndarray:
        """Standard deviation of the mean, ``sqrt(ensemble_variance)``."""
        return np.sqrt(self.ensemble_variance)
def ensemble_mean(
@ -160,18 +188,10 @@ def ensemble_mean(
const_kwargs: dict = dict(),
n_proc: Optional[int] = None,
every: Optional[int] = None,
calculate_variance: bool = False,
):
first = function(next(arg_iter), *const_args)
result = np.zeros(
tuple([1] if every is None else [int(N / every) + 1]) + first.shape,
dtype=first.dtype,
)
result[-1] = first
if calculate_variance:
vars = np.zeros_like(result)
results = []
aggregate = WelfordAggregator(function(next(arg_iter), *const_args))
if not n_proc:
n_proc = multiprocessing.cpu_count()
@ -187,44 +207,18 @@ def ensemble_mean(
10,
)
n = 1
ns = []
for res in tqdm(result_iter, total=(N - 1)):
result[-1] += res
if calculate_variance:
vars[-1] += res ** 2
aggregate.update(res)
n += 1
if every is not None and (n % every) == 0:
ns.append(n)
index = int(n / every) - 1
result[index] = result[-1].copy() / n
if calculate_variance:
vars[index] = np.sqrt(
((vars[-1].copy()) / n - result[index] ** 2) / (n - 1)
)
ns.append(n)
result[-1] /= n
if calculate_variance:
vars[-1] = np.sqrt((vars[-1] / n - result[-1] ** 2) / (n - 1))
if every is not None and (aggregate.n % every) == 0 or aggregate.n == N:
results.append(
(aggregate.n, aggregate.mean.copy(), aggregate.ensemble_std.copy())
)
if not every:
result = result[-1]
if calculate_variance:
vars = vars[-1]
results = results[-1]
retval = [result]
if calculate_variance:
retval = retval + [vars]
if every:
retval = retval + [ns]
return tuple(retval)
return results
def fit_α(