From 60f1394f34e2835c74abdda12b9c1397a8303496 Mon Sep 17 00:00:00 2001 From: Valentin Boettcher Date: Tue, 21 Mar 2023 12:54:25 -0400 Subject: [PATCH] fix ordering and clean up --- bandfit/bandfit.py | 50 +++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/bandfit/bandfit.py b/bandfit/bandfit.py index 16cf64b..0747847 100644 --- a/bandfit/bandfit.py +++ b/bandfit/bandfit.py @@ -258,7 +258,13 @@ def candidate(k, c, d, a, δb, k_scale, k_shift): return np.hstack([energies, energies]) -def fit_to_bands(bands, a=1, δb=0, c=10, d=10, ic_scan_steps=5, c_d_order=0): +def fit_to_bands( + bands, + bounds=[(-10, -10, 0.1, -0.5, 0.9, -0.5), (10, 10, 10, 0.5, 1.1, 0.5)], + ic_scan_steps=5, + c_d_order=0, + debug_plots=False, +): bands_normalized = bands.copy() bands_normalized[:, :2] -= np.sum(bands_normalized[:, :2], axis=1).mean() / 2 @@ -267,10 +273,11 @@ def fit_to_bands(bands, a=1, δb=0, c=10, d=10, ic_scan_steps=5, c_d_order=0): ks = np.linspace(-np.pi, np.pi, bands_normalized.shape[0]) - plt.plot(ks, bands_normalized[:, 0]) - plt.plot(ks, bands_normalized[:, 1]) + if debug_plots: + plt.plot(ks, bands_normalized[:, 0]) + plt.plot(ks, bands_normalized[:, 1]) - bounds = np.array([(-10, -10, 0.1, -0.5, 0.9, -0.5), (10, 10, 10, 0.5, 1.1, 0.5)]) + bounds = np.array(bounds) Δ_bounds = bounds[1, :2] - bounds[0, :2] ics = np.tile(np.linspace(0, 1, ic_scan_steps), (2, 1)) @@ -278,14 +285,11 @@ def fit_to_bands(bands, a=1, δb=0, c=10, d=10, ic_scan_steps=5, c_d_order=0): ics += bounds[0, :2][:, None] min_δb = np.inf + (c, d, a, δb, k_scale, k_shift) = np.zeros(6) + cov = np.zeros(6) + for ic in itertools.product(*ics): - if c_d_order == 1 and ic[0] > ic[1]: - continue - - if c_d_order == -1 and ic[0] < ic[1]: - continue - - p, cov_, _, _, success = sc.optimize.curve_fit( + p, current_cov, _, _, success = sc.optimize.curve_fit( candidate, np.hstack([ks, ks]), np.hstack([bands_normalized[:, 0], bands_normalized[:, 1]]), @@ -298,23 +302,29 @@ def fit_to_bands(bands, a=1, δb=0, c=10, d=10, ic_scan_steps=5, c_d_order=0): if success < 1 or success > 4: continue + if c_d_order == 1 and p[0] > p[1]: + continue + + if c_d_order == -1 and p[0] < p[1]: + continue + if ( abs(p[3]) < min_δb - and np.sqrt(np.sum(np.diag(cov_))) / np.linalg.norm(p) < 0.1 + and np.sqrt(np.sum(np.diag(current_cov))) / np.linalg.norm(p) < 0.1 ): - print(ic) - print("hey", p, p[3], min_δb) - - (a, c, d, δb, k_scale, k_shift) = p + (c, d, a, δb, k_scale, k_shift) = p min_δb = abs(δb) - cov = cov_ + cov = current_cov - plt.plot(ks, candidate(np.hstack([ks, ks]), *p)[: bands.shape[0]]) + if debug_plots: + plt.plot(ks, candidate(np.hstack([ks, ks]), *p)[: bands.shape[0]]) b = a + δb * a - σ = np.sqrt(np.diag(cov)) - σ[1] = np.sqrt((σ[0] * (1 + δb)) ** 2 + (a * σ[1]) ** 2) + σ_c, σ_d, σ_a, σ_δb, σ_k_scale, σ_k_shift = np.sqrt(np.diag(cov)) + + σ_b = np.sqrt((σ_a * (1 + δb)) ** 2 + (a * σ_δb) ** 2) + σ = np.array((σ_a, σ_b, σ_c, σ_d, σ_k_scale, σ_k_shift)) scale = 1 / a