from contextlib import contextmanager
from functools import wraps
from pathlib import Path
from types import ModuleType
from typing import Callable, Iterator, Tuple, Union

import h5py
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from lmfit import Parameters, minimize
from numpy.polynomial import Polynomial

from hopsflow import hopsflow

# def get_n_samples(config: ModuleType) -> int:
#     """Get the number of samples from ``stg``."""
#     with stg_helper.get_hierarchy_data(stg, read_only=True) as hd:
#         samp = hd.get_samples()
#     return samp if isinstance(samp, int) else 0


# def has_all_samples(stg: ModuleType) -> bool:
#     return stg.__HI_number_of_samples == get_n_samples(stg)


# def has_all_samples_checker(stg: ModuleType) -> Tuple[str, Callable[..., bool]]:
#     return "Has all samples?", lambda _: has_all_samples(stg)


# def hopsflow_systemparams(stg: ModuleType):
#     system_params = stg_helper.get_system_param(stg)
#     return hopsflow.SystemParams(
#         system_params.L.todense(), stg.__g, stg.__w, stg.__bcf_scale, stg.__HI_nonlinear
#     )


# def hopsflow_thermparams(stg: ModuleType, τ: np.ndarray):
#     ξ = stg_helper.get_eta_therm(stg)
#     ξ.calc_deriv = True
#     return hopsflow.ThermalParams(
#         ξ=ξ,
#         τ=τ,
#         num_deriv=False,
#         rand_skip=stg.__HI_rand_skip if hasattr(stg, "__HI_rand_skip") else 0,
#     )


def peruse_hierarchy_files(base: str) -> Iterator[h5py.File]:
    """Yield every ``*.h5`` file found one directory level below ``base``.

    Each file is opened read-only and is closed as soon as the consumer advances
    to the next file, finishes iterating, or abandons the generator early (the
    ``with`` block also runs on ``GeneratorExit``). Do not keep a yielded handle
    around after advancing the iterator.

    :param base: Directory whose immediate subdirectories are searched.
    """
    for path in Path(base).glob("*/*.h5"):
        with h5py.File(path, "r") as f:
            yield f


def α_apprx(τ, g, w):
    r"""Evaluate the multi-exponential approximation of a BCF.

    Computes :math:`\sum_k g_k e^{-w_k τ}` for every entry of ``τ``.

    :param τ: 1-D array-like of time arguments.
    :param g: 1-D array-like of (complex) weights, one per exponential.
    :param w: 1-D array-like of (complex) rates, same length as ``g``.
    :returns: Array of the same length as ``τ``.
    """
    τ = np.asarray(τ)
    g = np.asarray(g)
    w = np.asarray(w)

    # Broadcast to shape (len(τ), len(g)) and sum over the exponentials.
    return np.sum(
        g[np.newaxis, :] * np.exp(-w[np.newaxis, :] * (τ[:, np.newaxis])), axis=1
    )


def _unpack_exponential_params(params: Parameters, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """Assemble complex rates ``w`` and weights ``g`` from split real/imag fit parameters."""
    w = np.array(
        [params[f"w{i}"].value + 1j * params[f"wi{i}"].value for i in range(n)],
        dtype=complex,
    )
    g = np.array(
        [params[f"g{i}"].value + 1j * params[f"gi{i}"].value for i in range(n)],
        dtype=complex,
    )
    return w, g


def fit_α(
    α: Callable[[np.ndarray], np.ndarray],
    n: int,
    t_max: float,
    support_points: Union[int, np.ndarray] = 1000,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit the BCF ``α`` to a sum of ``n`` exponentials up to ``t_max`` using a
    number of ``support_points``.

    The model is :func:`α_apprx`, i.e. ``α(τ) ≈ Σ_k g_k exp(-w_k τ)`` with
    complex ``g_k`` and ``w_k``. Real and imaginary parts are fitted as separate
    real ``lmfit`` parameters.

    :param α: Callable evaluating the BCF on an array of times.
    :param n: Number of exponential terms.
    :param t_max: Upper end of the fit interval (used only when
        ``support_points`` is a scalar).
    :param support_points: Either the number of equidistant points in
        ``[0, t_max]`` or an explicit array of time points.
    :returns: ``(w, g)``, two complex arrays of length ``n``. Note the order
        differs from the ``(g, w)`` argument order of :func:`α_apprx`.
    """

    def residual(fit_params, x, data):
        w, g = _unpack_exponential_params(fit_params, n)
        resid = data - α_apprx(x, g, w)

        # lmfit needs real residuals: view the complex array as interleaved
        # (real, imag) floats.
        return resid.view(float)

    fit_params = Parameters()
    # NOTE(review): all terms start from identical guesses, which makes them
    # interchangeable for the optimizer; consider spreading the initial values
    # if fits with n > 1 degenerate.
    for i in range(n):
        fit_params.add(f"g{i}", value=0.1)
        fit_params.add(f"gi{i}", value=0.1)
        fit_params.add(f"w{i}", value=0.1)
        fit_params.add(f"wi{i}", value=0.1)

    if np.ndim(support_points) == 0:
        # A scalar is interpreted as the number of equidistant support points.
        ts = np.linspace(0, t_max, int(support_points))
    else:
        ts = np.asarray(support_points)

    out = minimize(residual, fit_params, args=(ts, α(ts)))

    return _unpack_exponential_params(out.params, n)


###############################################################################
#                                  Plotting                                   #
###############################################################################


def _label_prefix(label: str) -> str:
    """Return ``label`` followed by ``", "``, or ``""`` if ``label`` is empty."""
    return f"{label}, " if label else ""


def wrap_plot(f):
    """Decorate a plotting function so it can create its own figure.

    The decorated function gains two keyword arguments:

    * ``ax``: axes to draw on. If omitted (``None``), ``setup_function`` is
      called to create a fresh ``(fig, ax)`` pair.
    * ``setup_function``: factory returning ``(fig, ax)``; defaults to
      :func:`matplotlib.pyplot.subplots`.

    :returns: ``(fig, ax)``, or ``(fig, ax, ret_val)`` if ``f`` returned
        something other than ``None``. ``fig`` is ``None`` when ``ax`` was
        supplied by the caller.
    """

    @wraps(f)
    def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs):
        fig = None
        if ax is None:
            fig, ax = setup_function()

        ret_val = f(*args, ax=ax, **kwargs)
        # Compare against None explicitly: truth-testing e.g. a NumPy array
        # would raise.
        return (fig, ax) if ret_val is None else (fig, ax, ret_val)

    return wrapped


@contextmanager
def hiro_style():
    """Context manager applying the ``ggplot`` style plus a few rc overrides.

    Inside the context, TeX rendering is disabled, ``pgf.rcfonts`` is off and
    the default line width is 1.
    """
    with plt.style.context("ggplot"):
        with matplotlib.rc_context(
            {
                # "font.family": "serif",
                "text.usetex": False,
                "pgf.rcfonts": False,
                "lines.linewidth": 1,
            }
        ):
            yield True


@wrap_plot
def plot_complex(x, y, *args, ax=None, label="", **kwargs):
    """Plot the real and imaginary parts of the complex array ``y`` against ``x``.

    Extra positional/keyword arguments are forwarded to :meth:`Axes.plot`.
    """
    label = _label_prefix(label)
    ax.plot(x, y.real, *args, label=f"{label}real part", **kwargs)
    ax.plot(x, y.imag, *args, label=f"{label}imag part", **kwargs)
    ax.legend()


@wrap_plot
def plot_convergence(x, y, ax=None, label="", transform=lambda y: y, slice=None):
    """Plot how a sampled quantity converges with the number of samples.

    :param x: Abscissa shared by all curves.
    :param y: Sequence of ``(n, value, sigma)`` triples. The last entry is drawn
        with error bars ``sigma``; the entries selected by ``slice`` are drawn
        dashed, with opacity ``n / y[-1][0]``.
    :param transform: Applied to each ``value`` before plotting (not to
        ``sigma``).
    :param slice: ``(start, stop)`` bounds into ``y`` for the dashed curves;
        defaults to every entry but the last.
    """
    label = _label_prefix(label)
    slice = (0, -1) if not slice else slice
    n_max = y[-1][0]

    for n, val, _ in y[slice[0] : slice[1]]:
        ax.plot(
            x, transform(val), label=f"{label}n={n}", alpha=n / n_max, linestyle="--"
        )

    ax.errorbar(
        x,
        transform(y[-1][1]),
        yerr=y[-1][2],
        ecolor="yellow",
        label=f"{label}n={n_max}",
        color="red",
    )

    return None


@wrap_plot
def plot_diff_vs_sigma(
    x,
    y,
    reference,
    ax=None,
    label="",
    transform=lambda y: y,
    ecolor="yellow",
):
    r"""Compare ``|transform(value) - reference|`` with the final ``sigma``.

    The band ``[0, sigma]`` of the last entry of ``y`` is shaded. Each entry's
    absolute deviation from ``reference`` is then plotted, labelled with the
    percentage of points where it lies below that ``sigma``.

    :param y: Sequence of ``(n, value, sigma)`` triples, as for
        :func:`plot_convergence`.
    :param reference: Reference values compared against ``transform(value)``.
    :param ecolor: Fill colour of the ``sigma`` band.
    """
    label = _label_prefix(label)
    sigma = y[-1][2]

    ax.fill_between(
        x,
        0,
        sigma,
        color=ecolor,
        label=fr"{label}$\sigma$",
    )

    for n, val, _ in y:
        diff = np.abs(transform(val) - reference)
        # Fraction of points where the deviation is within one final sigma.
        within = (diff < sigma).sum() / sigma.size
        ax.plot(
            x,
            diff,
            label=fr"{label}n={n} $\Delta<\sigma = {within * 100}\%$",
            alpha=n / y[-1][0],
        )


###############################################################################
#                                Numpy Hacks                                  #
###############################################################################


def e_i(i: int, size: int) -> np.ndarray:
    r"""Cartesian base vector :math:`e_i`."""
    vec = np.zeros(size)
    vec[i] = 1
    return vec


def except_element(array: np.ndarray, index: int) -> np.ndarray:
    """Return a copy of ``array`` without the element at position ``index``.

    Intended for 1-D arrays. Only non-negative ``index`` values are matched; a
    negative or out-of-range index returns an unmodified copy.
    """
    mask = np.arange(array.size) != index
    return array[mask]


def poly_real(p: Polynomial) -> Polynomial:
    """Return the real part of ``p`` (a copy with real coefficients)."""
    new = p.copy()
    new.coef = p.coef.real
    return new