import functools
from contextlib import contextmanager
from types import ModuleType
from typing import Callable, Tuple, Union

from lmfit import minimize, Parameters
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from numpy.polynomial import Polynomial

import stg_helper


def get_n_samples(stg: ModuleType) -> int:
    """Get the number of samples from ``stg``.

    The hierarchy data belonging to ``stg`` is opened read-only through
    ``stg_helper.get_hierarchy_data`` and queried via ``get_samples()``.
    """
    with stg_helper.get_hierarchy_data(stg, read_only=True) as hd:
        return hd.get_samples()


def has_all_samples(stg: ModuleType) -> bool:
    """Return whether the stored sample count matches the configured one.

    Compares ``stg.__HI_number_of_samples`` with :func:`get_n_samples`.
    """
    return stg.__HI_number_of_samples == get_n_samples(stg)


def has_all_samples_checker(stg: ModuleType) -> Tuple[str, Callable[..., bool]]:
    """Return a ``(description, predicate)`` pair for the sample check.

    The predicate accepts (and ignores) a single argument and evaluates
    :func:`has_all_samples` for ``stg`` each time it is called.
    """
    return "Has all samples?", lambda _: has_all_samples(stg)


def α_apprx(τ: np.ndarray, g: np.ndarray, w: np.ndarray) -> np.ndarray:
    r"""Evaluate :math:`\sum_k g_k \exp(-w_k τ)` for every entry of ``τ``.

    ``τ``, ``g`` and ``w`` are treated as one-dimensional arrays; ``g`` and
    ``w`` must have the same length.  The result has the length of ``τ``.
    """
    # Broadcast to a (len(τ), len(g)) matrix and sum over the exponentials.
    return np.sum(
        g[np.newaxis, :] * np.exp(-w[np.newaxis, :] * (τ[:, np.newaxis])), axis=1
    )


def fit_α(
    α: Callable[[np.ndarray], np.ndarray],
    n: int,
    t_max: float,
    support_points: Union[int, np.ndarray] = 1000,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit the BCF ``α`` to a sum of ``n`` exponentials up to
    ``t_max`` using a number of ``support_points``.

    :param α: Callable that evaluates the function to fit on an array of
        times.
    :param n: Number of exponential terms.
    :param t_max: Upper end of the time grid when ``support_points`` is a
        count.
    :param support_points: Either the number of equidistant points on
        ``[0, t_max]``, or an array with at least two entries that is used
        directly as the time grid (``t_max`` is then ignored).
    :returns: The tuple ``(w, g)`` of complex exponents and complex
        prefactors, in the convention of :func:`α_apprx`.
    """

    def complex_params(params, name: str) -> np.ndarray:
        """Assemble the complex vector stored as ``<name>{i}`` / ``<name>i{i}``."""
        real = np.array([params[f"{name}{i}"].value for i in range(n)])
        imag = np.array([params[f"{name}i{i}"].value for i in range(n)])
        return real + 1j * imag

    def residual(fit_params, x, data):
        w = complex_params(fit_params, "w")
        g = complex_params(fit_params, "g")
        resid = data - α_apprx(x, g, w)

        # The minimizer works on real numbers: expose the complex residual
        # as interleaved real and imaginary parts.
        return resid.view(float)

    # Real and imaginary parts are independent real fit parameters.
    fit_params = Parameters()
    for i in range(n):
        fit_params.add(f"g{i}", value=0.1)
        fit_params.add(f"gi{i}", value=0.1)
        fit_params.add(f"w{i}", value=0.1)
        fit_params.add(f"wi{i}", value=0.1)

    ts = np.asarray(support_points)
    if ts.size < 2:
        # A scalar was given: interpret it as the number of grid points.
        ts = np.linspace(0, t_max, support_points)

    out = minimize(residual, fit_params, args=(ts, α(ts)))

    return complex_params(out.params, "w"), complex_params(out.params, "g")


###############################################################################
#                              Plotting Helpers                               #
###############################################################################


def wrap_plot(f):
    """Decorate the plot function ``f`` so that it creates axes on demand.

    The wrapped function takes two extra keyword arguments:

    :param ax: Axes to draw on.  If ``None``, ``setup_function()`` is called
        and has to return a ``(fig, ax)`` pair.
    :param setup_function: Factory for the figure and axes; defaults to
        ``plt.subplots``.

    It returns ``(fig, ax)``, or ``(fig, ax, ret_val)`` if ``f`` returned
    something other than ``None``.  ``fig`` is ``None`` when ``ax`` was
    supplied by the caller.
    """

    @functools.wraps(f)
    def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs):
        fig = None
        # Identity check: truthiness is not meaningful for axes objects and
        # raises for arrays of axes.
        if ax is None:
            fig, ax = setup_function()

        ret_val = f(*args, ax=ax, **kwargs)

        # Compare against ``None`` explicitly so that array-valued or falsy
        # results are neither rejected nor silently dropped.
        if ret_val is None:
            return fig, ax
        return fig, ax, ret_val

    return wrapped


@contextmanager
def hiro_style(*args, **kwargs):
    """Context manager applying the ``ggplot`` style plus local rc overrides.

    Positional and keyword arguments are accepted but not used.  The context
    yields ``True``.
    """
    rc_overrides = {
        # "font.family": "serif",
        "text.usetex": False,
        "pgf.rcfonts": False,
        "lines.linewidth": 1,
    }
    with plt.style.context("ggplot"), matplotlib.rc_context(rc_overrides):
        yield True


@wrap_plot
def plot_complex(x, y, *args, ax=None, label="", **kwargs):
    """Plot the real and imaginary part of ``y`` over ``x`` on ``ax``.

    A non-empty ``label`` is prepended to the legend entries
    (``"<label>, real part"`` / ``"<label>, imag part"``).  Remaining
    positional and keyword arguments are passed on to ``ax.plot``.
    """
    label = label + ", " if label else ""
    ax.plot(x, y.real, *args, label=f"{label}real part", **kwargs)
    ax.plot(x, y.imag, *args, label=f"{label}imag part", **kwargs)
    ax.legend()


###############################################################################
#                                 Numpy Hacks                                 #
###############################################################################


def e_i(i: int, size: int) -> np.ndarray:
    r"""Cartesian base vector :math:`e_i`.

    Returns a vector of length ``size`` that is zero everywhere except for a
    one at position ``i``.
    """
    vec = np.zeros(size)
    vec[i] = 1
    return vec


def except_element(array: np.ndarray, index: int) -> np.ndarray:
    """Return ``array`` without the element at position ``index``.

    Positions are compared against ``0 .. array.size - 1``, so a negative or
    out-of-range ``index`` removes nothing.  The result is a new array.
    """
    mask = np.arange(array.size) != index
    return array[mask]


def poly_real(p: Polynomial) -> Polynomial:
    """Return the real part of ``p``.

    ``p`` itself is left untouched; a copy with the real parts of the
    coefficients is returned.
    """
    new = p.copy()
    new.coef = p.coef.real
    return new