From 9363e71898c3eb409032b7efb96503a13387f537 Mon Sep 17 00:00:00 2001 From: Valentin Boettcher Date: Tue, 21 Mar 2023 14:51:06 -0400 Subject: [PATCH] optimize bounds and variance filtering --- bandfit/bandfit.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/bandfit/bandfit.py b/bandfit/bandfit.py index 39bd81d..31e5e39 100644 --- a/bandfit/bandfit.py +++ b/bandfit/bandfit.py @@ -130,7 +130,7 @@ def candidate(k, c, d, a, δb, k_scale, k_shift): def fit_to_bands( bands, - bounds=[(-10, -10, 0.1, -0.5, 0.8, -0.5), (10, 10, 10, 0.5, 1.2, 0.5)], + bounds=[(-10, -10, 0.5, -0.5, 0.8, -0.5), (10, 10, 10, 0.5, 1.2, 0.5)], ic_scan_steps=5, c_d_order=0, debug_plots=False, @@ -158,6 +158,7 @@ def fit_to_bands( (c, d, a, δb, k_scale, k_shift) = np.zeros(6) cov = np.zeros(6) + σs = [] for ic in itertools.product(*ics): p, current_cov, _, _, success = sc.optimize.curve_fit( candidate, @@ -178,14 +179,17 @@ def fit_to_bands( if c_d_order == -1 and p[0] < p[1]: continue - if ( - abs(p[3]) < min_δb - and np.sqrt(np.sum(np.diag(current_cov))) / np.linalg.norm(p) < 0.1 + σ_δb_current = np.sqrt(current_cov[3, 3]) + σ_rel = np.sqrt(np.sum(np.diag(current_cov))) / np.linalg.norm(p) + if abs(p[3]) + σ_δb_current < min_δb and ( + len(σs) == 0 or σ_rel <= np.min(σs) * 2 ): (c, d, a, δb, k_scale, k_shift) = p - min_δb = abs(δb) + min_δb = abs(δb) + σ_δb_current cov = current_cov + σs.append(σ_rel) + if debug_plots: plt.plot(ks, candidate(np.hstack([ks, ks]), *p)[: bands.shape[0]])