fix scaling of params

This commit is contained in:
Valentin Boettcher 2023-03-20 21:16:57 -04:00
parent 7908745b63
commit 3148c054a5

View file

@ -7,6 +7,7 @@ from skimage.transform import resize
from skimage.filters import hessian from skimage.filters import hessian
from skimage.morphology import skeletonize from skimage.morphology import skeletonize
from itertools import chain from itertools import chain
from scipy.interpolate import splrep, BSpline
def load_data(data_file: str) -> np.ndarray: def load_data(data_file: str) -> np.ndarray:
@ -236,14 +237,15 @@ def plot_data_with_bands(data, bands):
# return sc.optimize.curve_fit(double_lorentzian, e_axis, col, (0, 10, 0, 3)) # return sc.optimize.curve_fit(double_lorentzian, e_axis, col, (0, 10, 0, 3))
def candidate(k, a, b, c, d, s, k_scale): def candidate(k, b, c, d, k_scale, k_shift):
k = np.asarray(k) * k_scale k = np.asarray(k[: k.size // 2]) * k_scale + k_shift
energies = energy(k, a, b, c, d) energies = energy(k, 1 - b - c - d, b, c, d)
energies /= energies.max() energies /= energies.max()
return energies * s
return np.hstack([energies, energies])
def fit_to_bands(bands, a=1, b=1, c=1, d=1): def fit_to_bands(bands, b=1, c=1, d=1):
bands_normalized = bands.copy() bands_normalized = bands.copy()
bands_normalized[:, :2] -= np.sum(bands_normalized[:, :2], axis=1).mean() / 2 bands_normalized[:, :2] -= np.sum(bands_normalized[:, :2], axis=1).mean() / 2
@ -254,15 +256,28 @@ def fit_to_bands(bands, a=1, b=1, c=1, d=1):
plt.plot(ks, bands_normalized[:, 0]) plt.plot(ks, bands_normalized[:, 0])
plt.plot(ks, bands_normalized[:, 1]) plt.plot(ks, bands_normalized[:, 1])
p, _ = sc.optimize.curve_fit( p, cov = sc.optimize.curve_fit(
candidate, candidate,
ks, np.hstack([ks, ks]),
bands_normalized[:, 0], np.hstack([bands_normalized[:, 0], bands_normalized[:, 1]]),
(a, b, c, d, 1, 1), (b, c, d, 1, 0),
sigma=bands_normalized[:, 2], sigma=np.hstack([bands_normalized[:, 2], bands_normalized[:, 3]]),
bounds=[(0.5, 0.5, -5, -5, 0.5, 0.5), (1.5, 1.5, 5, 5, 1.5, 1.5)], bounds=[(-10, -10, -10, 0.9, -0.5), (10, 10, 10, 1.1, 0.5)],
) )
plt.plot(ks, candidate(ks, *p)) plt.plot(ks, candidate(np.hstack([ks, ks]), *p)[: bands.shape[0]])
(b, c, d, k_scale, k_shift) = p
a = 1 - b - c - d
return p, _ scale = 1 / a
a *= scale
b *= scale
c *= scale
σ = np.sqrt(np.diag(cov))
σ = np.array([np.sqrt(sum(σ[:3] ** 2)), *σ])
σ[:4] *= scale
return ((a, b, c, d, k_scale, k_shift), σ)