mirror of
https://github.com/vale981/hopsflow
synced 2025-03-05 16:51:39 -05:00
mean value with welford for stability
This commit is contained in:
parent
8264b97fee
commit
3e8bbac9c3
1 changed files with 38 additions and 44 deletions
|
@ -7,6 +7,7 @@ import scipy
|
|||
from typing import Iterator, Optional, Any, Callable, Union
|
||||
from lmfit import minimize, Parameters
|
||||
from numpy.polynomial import Polynomial
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def apply_operator(ψ: np.ndarray, op: np.ndarray) -> np.ndarray:
|
||||
|
@ -149,7 +150,34 @@ def _ensemble_mean_init(func: Callable, args: tuple, kwargs: dict):
|
|||
|
||||
|
||||
# TODO: Use paramspec
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class WelfordAggregator:
|
||||
__slots__ = ["n", "mean", "_m_2"]
|
||||
|
||||
def __init__(self, first_value: np.ndarray):
|
||||
self.n = 1
|
||||
self.mean = first_value
|
||||
self._m_2 = np.zeros_like(first_value)
|
||||
|
||||
def update(self, new_value: np.ndarray):
|
||||
self.n += 1
|
||||
delta = new_value - self.mean
|
||||
self.mean += delta / self.n
|
||||
delta2 = new_value - self.mean
|
||||
self._m_2 += delta * delta2
|
||||
|
||||
@property
|
||||
def sample_variance(self) -> np.ndarray:
|
||||
return self._m_2 / (self.n - 1)
|
||||
|
||||
@property
|
||||
def ensemble_variance(self) -> np.ndarray:
|
||||
return self.sample_variance / self.n
|
||||
|
||||
@property
|
||||
def ensemble_std(self) -> np.ndarray:
|
||||
return np.sqrt(self.ensemble_variance)
|
||||
|
||||
|
||||
def ensemble_mean(
|
||||
|
@ -160,18 +188,10 @@ def ensemble_mean(
|
|||
const_kwargs: dict = dict(),
|
||||
n_proc: Optional[int] = None,
|
||||
every: Optional[int] = None,
|
||||
calculate_variance: bool = False,
|
||||
):
|
||||
|
||||
first = function(next(arg_iter), *const_args)
|
||||
result = np.zeros(
|
||||
tuple([1] if every is None else [int(N / every) + 1]) + first.shape,
|
||||
dtype=first.dtype,
|
||||
)
|
||||
result[-1] = first
|
||||
|
||||
if calculate_variance:
|
||||
vars = np.zeros_like(result)
|
||||
results = []
|
||||
aggregate = WelfordAggregator(function(next(arg_iter), *const_args))
|
||||
|
||||
if not n_proc:
|
||||
n_proc = multiprocessing.cpu_count()
|
||||
|
@ -187,44 +207,18 @@ def ensemble_mean(
|
|||
10,
|
||||
)
|
||||
|
||||
n = 1
|
||||
ns = []
|
||||
for res in tqdm(result_iter, total=(N - 1)):
|
||||
result[-1] += res
|
||||
if calculate_variance:
|
||||
vars[-1] += res ** 2
|
||||
aggregate.update(res)
|
||||
|
||||
n += 1
|
||||
|
||||
if every is not None and (n % every) == 0:
|
||||
ns.append(n)
|
||||
index = int(n / every) - 1
|
||||
result[index] = result[-1].copy() / n
|
||||
|
||||
if calculate_variance:
|
||||
vars[index] = np.sqrt(
|
||||
((vars[-1].copy()) / n - result[index] ** 2) / (n - 1)
|
||||
)
|
||||
|
||||
ns.append(n)
|
||||
|
||||
result[-1] /= n
|
||||
if calculate_variance:
|
||||
vars[-1] = np.sqrt((vars[-1] / n - result[-1] ** 2) / (n - 1))
|
||||
if every is not None and (aggregate.n % every) == 0 or aggregate.n == N:
|
||||
results.append(
|
||||
(aggregate.n, aggregate.mean.copy(), aggregate.ensemble_std.copy())
|
||||
)
|
||||
|
||||
if not every:
|
||||
result = result[-1]
|
||||
if calculate_variance:
|
||||
vars = vars[-1]
|
||||
results = results[-1]
|
||||
|
||||
retval = [result]
|
||||
if calculate_variance:
|
||||
retval = retval + [vars]
|
||||
|
||||
if every:
|
||||
retval = retval + [ns]
|
||||
|
||||
return tuple(retval)
|
||||
return results
|
||||
|
||||
|
||||
def fit_α(
|
||||
|
|
Loading…
Add table
Reference in a new issue