mirror of
https://github.com/vale981/bandfit
synced 2025-03-05 09:31:42 -05:00
fix scaling of params
This commit is contained in:
parent
7908745b63
commit
3148c054a5
1 changed files with 28 additions and 13 deletions
|
@ -7,6 +7,7 @@ from skimage.transform import resize
|
|||
from skimage.filters import hessian
|
||||
from skimage.morphology import skeletonize
|
||||
from itertools import chain
|
||||
from scipy.interpolate import splrep, BSpline
|
||||
|
||||
|
||||
def load_data(data_file: str) -> np.ndarray:
|
||||
|
@ -236,14 +237,15 @@ def plot_data_with_bands(data, bands):
|
|||
# return sc.optimize.curve_fit(double_lorentzian, e_axis, col, (0, 10, 0, 3))
|
||||
|
||||
|
||||
def candidate(k, a, b, c, d, s, k_scale):
|
||||
k = np.asarray(k) * k_scale
|
||||
energies = energy(k, a, b, c, d)
|
||||
def candidate(k, b, c, d, k_scale, k_shift):
|
||||
k = np.asarray(k[: k.size // 2]) * k_scale + k_shift
|
||||
energies = energy(k, 1 - b - c - d, b, c, d)
|
||||
energies /= energies.max()
|
||||
return energies * s
|
||||
|
||||
return np.hstack([energies, energies])
|
||||
|
||||
|
||||
def fit_to_bands(bands, a=1, b=1, c=1, d=1):
|
||||
def fit_to_bands(bands, b=1, c=1, d=1):
|
||||
bands_normalized = bands.copy()
|
||||
|
||||
bands_normalized[:, :2] -= np.sum(bands_normalized[:, :2], axis=1).mean() / 2
|
||||
|
@ -254,15 +256,28 @@ def fit_to_bands(bands, a=1, b=1, c=1, d=1):
|
|||
|
||||
plt.plot(ks, bands_normalized[:, 0])
|
||||
plt.plot(ks, bands_normalized[:, 1])
|
||||
p, _ = sc.optimize.curve_fit(
|
||||
p, cov = sc.optimize.curve_fit(
|
||||
candidate,
|
||||
ks,
|
||||
bands_normalized[:, 0],
|
||||
(a, b, c, d, 1, 1),
|
||||
sigma=bands_normalized[:, 2],
|
||||
bounds=[(0.5, 0.5, -5, -5, 0.5, 0.5), (1.5, 1.5, 5, 5, 1.5, 1.5)],
|
||||
np.hstack([ks, ks]),
|
||||
np.hstack([bands_normalized[:, 0], bands_normalized[:, 1]]),
|
||||
(b, c, d, 1, 0),
|
||||
sigma=np.hstack([bands_normalized[:, 2], bands_normalized[:, 3]]),
|
||||
bounds=[(-10, -10, -10, 0.9, -0.5), (10, 10, 10, 1.1, 0.5)],
|
||||
)
|
||||
|
||||
plt.plot(ks, candidate(ks, *p))
|
||||
plt.plot(ks, candidate(np.hstack([ks, ks]), *p)[: bands.shape[0]])
|
||||
(b, c, d, k_scale, k_shift) = p
|
||||
a = 1 - b - c - d
|
||||
|
||||
return p, _
|
||||
scale = 1 / a
|
||||
|
||||
a *= scale
|
||||
b *= scale
|
||||
c *= scale
|
||||
|
||||
σ = np.sqrt(np.diag(cov))
|
||||
σ = np.array([np.sqrt(sum(σ[:3] ** 2)), *σ])
|
||||
|
||||
σ[:4] *= scale
|
||||
|
||||
return ((a, b, c, d, k_scale, k_shift), σ)
|
||||
|
|
Loading…
Add table
Reference in a new issue