import stg_helper
from types import ModuleType
from typing import Callable, Tuple, Union, Iterator
from lmfit import minimize, Parameters
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib
import numpy as np
from numpy.polynomial import Polynomial
import functools
from contextlib import contextmanager
from pathlib import Path
import h5py
import pickle


def get_n_samples(stg: ModuleType) -> int:
    """Get the number of samples from ``stg``.

    The hierarchy data is opened read-only via
    ``stg_helper.get_hierarchy_data``.
    """
    with stg_helper.get_hierarchy_data(stg, read_only=True) as hd:
        return hd.get_samples()


def has_all_samples(stg: ModuleType) -> bool:
    """Return whether the stored sample count equals the configured one.

    Compares ``stg.__HI_number_of_samples`` with :func:`get_n_samples`.
    """
    return stg.__HI_number_of_samples == get_n_samples(stg)


def has_all_samples_checker(stg: ModuleType) -> Tuple[str, Callable[..., bool]]:
    """Return a ``(description, predicate)`` pair for :func:`has_all_samples`.

    The predicate takes one (ignored) argument and evaluates
    ``has_all_samples(stg)`` each time it is called.
    """
    return "Has all samples?", lambda _: has_all_samples(stg)


def peruse_hierarchy_files(base: str) -> Iterator[h5py.File]:
    """Yield every ``*.h5`` file one directory below ``base``, opened read-only.

    Each file is closed when the consumer advances to the next one, and also
    when the generator is closed early (e.g. after a ``break``), so no file
    handle is leaked.  Do not keep a yielded file around after advancing.
    """
    for path in Path(base).glob("*/*.h5"):
        with h5py.File(path, "r") as f:
            yield f


def α_apprx(τ, g, w):
    r"""Evaluate the exponential-sum model :math:`\sum_i g_i e^{-w_i \tau}`.

    :param τ: 1-D array (or sequence) of time points.
    :param g: 1-D array (or sequence) of coefficients.
    :param w: 1-D array (or sequence) of exponents, same length as ``g``.
    :returns: array with one value per entry of ``τ``.
    """
    τ = np.asarray(τ)
    g = np.asarray(g)
    w = np.asarray(w)
    # Broadcast to shape (len(τ), len(g)) and sum over the terms.
    return np.sum(
        g[np.newaxis, :] * np.exp(-w[np.newaxis, :] * τ[:, np.newaxis]), axis=1
    )


def _params_to_w_g(params: Parameters, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """Assemble complex ``(w, g)`` arrays from the real/imag lmfit parameters."""

    def complex_array(real_prefix: str, imag_prefix: str) -> np.ndarray:
        real = np.array([params[f"{real_prefix}{i}"].value for i in range(n)])
        imag = np.array([params[f"{imag_prefix}{i}"].value for i in range(n)])
        return real + 1j * imag

    return complex_array("w", "wi"), complex_array("g", "gi")


def fit_α(
    α: Callable[[np.ndarray], np.ndarray],
    n: int,
    t_max: float,
    support_points: Union[int, np.ndarray] = 1000,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit the BCF ``α`` to a sum of ``n`` exponentials up to ``t_max`` using a
    number of ``support_points``.

    The model is :func:`α_apprx` with complex coefficients ``g`` and complex
    exponents ``w``.  Real and imaginary parts are separate lmfit parameters,
    all initialized to ``0.1``.

    :param α: callable evaluating the BCF on an array of times.
    :param n: number of exponential terms.
    :param t_max: end of the fit interval ``[0, t_max]``; only used when
        ``support_points`` is a count.
    :param support_points: either the number of equidistant points in
        ``[0, t_max]`` or an explicit array of at least two time points.
    :returns: the fitted ``(w, g)`` as complex arrays of length ``n``.
    """

    def residual(fit_params, x, data):
        w, g = _params_to_w_g(fit_params, n)
        # lmfit needs real residuals: view the complex difference as
        # interleaved (real, imag) floats so both parts enter the fit.
        return (data - α_apprx(x, g, w)).view(float)

    fit_params = Parameters()
    for i in range(n):
        fit_params.add(f"g{i}", value=0.1)
        fit_params.add(f"gi{i}", value=0.1)
        fit_params.add(f"w{i}", value=0.1)
        fit_params.add(f"wi{i}", value=0.1)

    ts = np.asarray(support_points)
    if ts.size < 2:
        # A scalar (or single value) is interpreted as a number of points.
        ts = np.linspace(0, t_max, support_points)

    out = minimize(residual, fit_params, args=(ts, α(ts)))
    return _params_to_w_g(out.params, n)


###############################################################################
#                                  Plotting                                   #
###############################################################################


def wrap_plot(f):
    """Decorator that provides a fresh axis to a plotting function on demand.

    The wrapped function accepts two extra keyword arguments:

    * ``ax`` -- the axis to draw on.  When ``None``, a new figure and axis are
      created by calling ``setup_function()``.
    * ``setup_function`` -- factory returning ``(fig, ax)``; defaults to
      :func:`matplotlib.pyplot.subplots`.

    It returns ``(fig, ax)``, or ``(fig, ax, ret_val)`` when ``f`` returned
    something other than ``None``.  ``fig`` is ``None`` if ``ax`` was given.
    """

    @functools.wraps(f)
    def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs):
        fig = None
        if ax is None:
            fig, ax = setup_function()

        ret_val = f(*args, ax=ax, **kwargs)
        return (fig, ax) if ret_val is None else (fig, ax, ret_val)

    return wrapped


@contextmanager
def hiro_style(*args, **kwargs):
    """Context manager applying the ``ggplot`` style plus some rc overrides.

    Overrides: no TeX rendering, ``pgf.rcfonts`` off, line width 1.
    Positional and keyword arguments are accepted but ignored.
    """
    with plt.style.context("ggplot"):
        with matplotlib.rc_context(
            {
                "text.usetex": False,
                "pgf.rcfonts": False,
                "lines.linewidth": 1,
            }
        ):
            yield True


@wrap_plot
def plot_complex(x, y, *args, ax=None, label="", **kwargs):
    """Plot real and imaginary part of ``y`` over ``x`` on ``ax``.

    Extra positional/keyword arguments are forwarded to both ``ax.plot`` calls.
    """
    label = label + ", " if (len(label) > 0) else ""
    ax.plot(x, y.real, *args, label=f"{label}real part", **kwargs)
    ax.plot(x, y.imag, *args, label=f"{label}imag part", **kwargs)
    ax.legend()


@wrap_plot
def plot_convergence(
    x, y, *args, ax=None, label="", transform=lambda y: y, slice=None, **kwargs
):
    """Plot a sequence of estimates together with the final one.

    :param x: abscissa values.
    :param y: sequence of ``(n, value, std)`` entries; the last entry is the
        final estimate (assumed to have the largest ``n`` -- it is used to
        normalize the line transparency).
    :param transform: applied to each ``value`` before plotting (not to ``std``).
    :param slice: ``(start, stop)`` selecting which entries of ``y`` are drawn
        as dashed lines; defaults to all but the last.
    Extra ``*args``/``**kwargs`` are accepted but ignored.
    """
    label = label + ", " if (len(label) > 0) else ""
    slice = (0, -1) if not slice else slice
    n_max = y[-1][0]

    for n, val, _ in y[slice[0] : slice[1]]:
        ax.plot(
            x, transform(val), label=f"{label}n={n}", alpha=n / n_max, linestyle="--"
        )

    ax.errorbar(
        x,
        transform(y[-1][1]),
        yerr=y[-1][2],
        ecolor="yellow",
        label=f"{label}n={y[-1][0]}",
        color="red",
    )

    return None


@wrap_plot
def plot_diff_vs_sigma(
    x,
    y,
    reference,
    *args,
    ax=None,
    label="",
    transform=lambda y: y,
    ecolor="yellow",
    **kwargs,
):
    """Plot ``|transform(value) - reference|`` for each estimate against σ.

    The band ``[0, std]`` of the last entry of ``y`` is filled in ``ecolor``.

    :param y: sequence of ``(n, value, std)`` entries; the last entry's ``n``
        normalizes the line transparency.
    :param reference: reference values the estimates are compared with.
    Extra ``*args``/``**kwargs`` are accepted but ignored.
    """
    label = label + ", " if (len(label) > 0) else ""
    ax.fill_between(
        x,
        0,
        y[-1][2],
        color=ecolor,
        label=rf"{label}$\sigma$",
    )

    n_max = y[-1][0]
    for n, val, _ in y:
        ax.plot(
            x,
            np.abs(transform(val) - reference),
            label=f"{label}n={n}",
            alpha=n / n_max,
        )


###############################################################################
#                                 Numpy Hacks                                 #
###############################################################################


def e_i(i: int, size: int) -> np.ndarray:
    r"""Cartesian base vector :math:`e_i` (float array of length ``size``)."""
    vec = np.zeros(size)
    vec[i] = 1
    return vec


def except_element(array: np.ndarray, index: int) -> np.ndarray:
    """Return a copy of the 1-D ``array`` without the element at ``index``.

    Negative indices are not normalized: they match no position, so the
    whole array is returned.
    """
    return array[np.arange(array.size) != index]


def poly_real(p: Polynomial) -> Polynomial:
    """Return the real part of ``p``."""
    new = p.copy()
    new.coef = p.coef.real
    return new