import stg_helper

import functools
from types import ModuleType
from typing import Callable, Tuple, Union

from lmfit import minimize, Parameters
import matplotlib.pyplot as plt
import numpy as np
from numpy.polynomial import Polynomial


def get_n_samples(stg: ModuleType) -> int:
    """Get the number of samples from ``stg``."""
    with stg_helper.get_hierarchy_data(stg, read_only=True) as hd:
        return hd.get_samples()


def has_all_samples(stg: ModuleType) -> bool:
    """Return whether ``stg.__HI_number_of_samples`` equals the sample count
    reported by :func:`get_n_samples`.
    """
    return stg.__HI_number_of_samples == get_n_samples(stg)


def has_all_samples_checker(stg: ModuleType) -> Tuple[str, Callable[..., bool]]:
    """Build a labelled check for :func:`has_all_samples`.

    :returns: a ``(label, predicate)`` pair; the predicate ignores its single
        argument and evaluates ``has_all_samples(stg)`` when called.
    """
    return "Has all samples?", lambda _: has_all_samples(stg)


def α_apprx(τ: np.ndarray, g: np.ndarray, w: np.ndarray) -> np.ndarray:
    r"""Evaluate :math:`\sum_k g_k e^{-w_k τ}` for every entry of ``τ``.

    ``τ``, ``g`` and ``w`` are indexed as one-dimensional arrays; ``g`` and
    ``w`` are broadcast against each other along the summed axis.

    :returns: an array with one value per entry of ``τ``.
    """
    return np.sum(
        g[np.newaxis, :] * np.exp(-w[np.newaxis, :] * (τ[:, np.newaxis])), axis=1
    )


def _complex_params(params, name: str, n: int) -> np.ndarray:
    """Assemble ``n`` complex numbers from the parameters ``{name}{i}`` (real
    part) and ``{name}i{i}`` (imaginary part) of an lmfit parameter set.
    """
    real = np.array([params[f"{name}{i}"].value for i in range(n)])
    imag = np.array([params[f"{name}i{i}"].value for i in range(n)])
    return real + 1j * imag


def fit_α(
    α: Callable[[np.ndarray], np.ndarray],
    n: int,
    t_max: float,
    support_points: Union[int, np.ndarray] = 1000,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit the BCF ``α`` to a sum of ``n`` exponentials up to
    ``t_max`` using a number of ``support_points``.

    :param α: the function to fit; it is called once with the array of
        support points.
    :param n: the number of exponential terms.
    :param t_max: upper end of the support grid; only used when
        ``support_points`` holds fewer than two values.
    :param support_points: either the number of equally spaced points in
        ``[0, t_max]`` or an array (two or more entries) of the points
        themselves.
    :returns: the tuple ``(w, g)`` of complex exponents and prefactors as
        accepted by :func:`α_apprx`.
    """

    def residual(fit_params, x, data):
        w = _complex_params(fit_params, "w", n)
        g = _complex_params(fit_params, "g", n)
        resid = data - α_apprx(x, g, w)

        # The minimizer works on real numbers: reinterpret the complex
        # residual as interleaved real and imaginary parts.
        return resid.view(float)

    # Every real and imaginary part starts from the same initial guess.
    fit_params = Parameters()
    for i in range(n):
        fit_params.add(f"g{i}", value=0.1)
        fit_params.add(f"gi{i}", value=0.1)
        fit_params.add(f"w{i}", value=0.1)
        fit_params.add(f"wi{i}", value=0.1)

    ts = np.asarray(support_points)
    if ts.size < 2:
        # Fewer than two values: treat ``support_points`` as a point count.
        ts = np.linspace(0, t_max, support_points)

    out = minimize(residual, fit_params, args=(ts, α(ts)))

    w = _complex_params(out.params, "w", n)
    g = _complex_params(out.params, "g", n)

    return w, g


def wrap_plot(f):
    """Decorate the plotting function ``f`` so that it creates its own figure
    when no axis is passed.

    The wrapped function accepts an optional keyword argument ``ax``.  If it
    is ``None`` a new figure and axis are created with ``plt.subplots()``.

    :returns: a function returning ``(fig, ax)`` when ``f`` returns ``None``
        and ``(fig, ax, ret_val)`` otherwise; ``fig`` is ``None`` when the
        caller supplied ``ax``.
    """

    @functools.wraps(f)
    def wrapped(*args, ax=None, **kwargs):
        fig = None
        if ax is None:
            fig, ax = plt.subplots()

        ret_val = f(*args, ax=ax, **kwargs)

        # Compare against ``None`` instead of relying on truthiness: array
        # results have no unambiguous truth value and falsy results such as
        # ``0`` or ``[]`` must not be dropped.
        if ret_val is None:
            return fig, ax
        return fig, ax, ret_val

    return wrapped


@wrap_plot
def plot_complex(x, y, *args, ax=None, label="", **kwargs):
    """Plot the real and imaginary parts of ``y`` against ``x`` on ``ax`` and
    show a legend.

    A non-empty ``label`` is used as prefix of both legend entries.  All
    other positional and keyword arguments are passed on to ``ax.plot``.
    """
    label = f"{label}, " if label else ""
    ax.plot(x, y.real, *args, label=f"{label}real part", **kwargs)
    ax.plot(x, y.imag, *args, label=f"{label}imag part", **kwargs)
    ax.legend()


def e_i(i: int, size: int) -> np.ndarray:
    r"""Cartesian base vector :math:`e_i`."""
    vec = np.zeros(size)
    vec[i] = 1
    return vec


def except_element(array: np.ndarray, index: int) -> np.ndarray:
    """Return ``array`` without the element at position ``index``.

    An ``index`` outside ``range(array.size)`` (including negative values)
    matches no position, so all elements are returned.
    """
    mask = np.arange(array.size) != index
    return array[mask]


def poly_real(p: Polynomial) -> Polynomial:
    """Return the real part of ``p``."""
    new = p.copy()
    # ``.real`` of a complex array is a view; copy it so that the result does
    # not share its coefficient buffer with ``p``.
    new.coef = p.coef.real.copy()
    return new