master-thesis/python/energy_flow_proper/utilities.py

142 lines
4 KiB
Python
Raw Normal View History

2021-11-11 16:09:04 +01:00
import stg_helper
from types import ModuleType
2021-11-19 20:26:57 +01:00
from typing import Callable, Tuple, Union
from lmfit import minimize, Parameters
import matplotlib.pyplot as plt
2021-11-24 19:15:44 +01:00
import matplotlib.ticker as ticker
import matplotlib
2021-11-19 20:26:57 +01:00
import numpy as np
from numpy.polynomial import Polynomial
2021-11-24 19:15:44 +01:00
import functools
from contextlib import contextmanager
2021-11-19 20:26:57 +01:00
2021-11-11 16:09:04 +01:00
def get_n_samples(stg: ModuleType) -> int:
    """Return the sample count stored in the hierarchy data of ``stg``.

    The hierarchy data is opened read-only through
    ``stg_helper.get_hierarchy_data`` and queried with ``get_samples()``.
    """
    with stg_helper.get_hierarchy_data(stg, read_only=True) as hierarchy_data:
        return hierarchy_data.get_samples()
2021-11-19 20:26:57 +01:00
2021-11-11 16:09:04 +01:00
def has_all_samples(stg: ModuleType) -> bool:
    """Check whether ``stg`` already holds the requested number of samples.

    Compares ``stg.__HI_number_of_samples`` with the count returned by
    :func:`get_n_samples`.
    """
    requested = stg.__HI_number_of_samples
    return requested == get_n_samples(stg)
2021-11-19 20:26:57 +01:00
2021-11-11 16:09:04 +01:00
def has_all_samples_checker(stg: ModuleType) -> Tuple[str, Callable[..., bool]]:
    """Build a labelled check for whether ``stg`` has all its samples.

    :param stg: The settings module passed on to :func:`has_all_samples`.
    :returns: A pair ``(description, check)``.  ``description`` is the
        string ``"Has all samples?"`` and ``check`` is a callable that
        takes one (ignored) argument and returns
        ``has_all_samples(stg)``.  The check is evaluated lazily, i.e.
        only when ``check`` is called.
    """
    return "Has all samples?", lambda _: has_all_samples(stg)
2021-11-19 20:26:57 +01:00
def α_apprx(τ, g, w):
    r"""Evaluate :math:`\sum_k g_k e^{-w_k τ}` for every entry of ``τ``.

    ``τ``, ``g`` and ``w`` are indexed as one-dimensional arrays; the
    result has one entry per element of ``τ``.
    """
    # rows: entries of τ, columns: the individual exponential terms
    exponentials = np.exp(-w[np.newaxis, :] * (τ[:, np.newaxis]))
    weighted_terms = g[np.newaxis, :] * exponentials
    return np.sum(weighted_terms, axis=1)
def fit_α(
    α: Callable[[np.ndarray], np.ndarray],
    n: int,
    t_max: float,
    support_points: Union[int, np.ndarray] = 1000,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit the BCF ``α`` to a sum of ``n`` exponentials up to
    ``t_max`` using a number of ``support_points``.

    :param α: The function to fit.  It is called once with the array
        of support points.
    :param n: The number of exponential terms.
    :param t_max: Upper end of the fit interval.  Only used when the
        support points are generated from a count.
    :param support_points: Either the number of equidistant points on
        ``[0, t_max]`` or an array (with at least two entries) of
        points that is used as is.
    :returns: The tuple ``(w, g)`` of complex exponents and complex
        prefactors as consumed by :func:`α_apprx`.
    """

    def complex_vector(params, real_prefix, imag_prefix):
        """Assemble a complex vector from ``n`` real/imaginary parameter pairs."""
        real = np.array([params[f"{real_prefix}{i}"].value for i in range(n)])
        imag = np.array([params[f"{imag_prefix}{i}"].value for i in range(n)])
        return real + 1j * imag

    def residual(fit_params, x, data):
        w = complex_vector(fit_params, "w", "wi")
        g = complex_vector(fit_params, "g", "gi")

        # The minimizer works on real numbers, so the complex residual
        # is reinterpreted as an array of floats.
        return (data - α_apprx(x, g, w)).view(float)

    # NOTE(review): every parameter starts from the same value 0.1, so
    # all ``n`` terms are initially identical -- confirm this is intended.
    fit_params = Parameters()
    for i in range(n):
        fit_params.add(f"g{i}", value=0.1)
        fit_params.add(f"gi{i}", value=0.1)
        fit_params.add(f"w{i}", value=0.1)
        fit_params.add(f"wi{i}", value=0.1)

    # A scalar ``support_points`` has size one and is treated as a count.
    ts = np.asarray(support_points)
    if ts.size < 2:
        ts = np.linspace(0, t_max, support_points)

    out = minimize(residual, fit_params, args=(ts, α(ts)))

    w = complex_vector(out.params, "w", "wi")
    g = complex_vector(out.params, "g", "gi")

    return w, g
2021-11-24 19:15:44 +01:00
###############################################################################
# Plot Porn #
###############################################################################
2021-11-19 20:26:57 +01:00
def wrap_plot(f):
    """Decorate the plot function ``f`` so that it creates its axes on demand.

    The wrapped function accepts two extra keyword arguments:

    :param ax: The axes to draw on.  If ``None``, ``setup_function``
        is called to create them.
    :param setup_function: A callable without arguments that returns a
        ``(fig, ax)`` pair.  Defaults to ``plt.subplots``.

    ``f`` is called as ``f(*args, ax=ax, **kwargs)``.

    :returns: ``(fig, ax, ret_val)`` if the return value of ``f`` is
        truthy and ``(fig, ax)`` otherwise.  ``fig`` is ``None`` when
        the axes were supplied by the caller.
    """

    @functools.wraps(f)
    def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs):
        fig = None

        # Compare against ``None`` explicitly: relying on the truth
        # value would fail for objects with ambiguous truthiness, such
        # as an array of axes.
        if ax is None:
            fig, ax = setup_function()

        ret_val = f(*args, ax=ax, **kwargs)

        # Falsy return values (not only ``None``) are dropped.
        return (fig, ax, ret_val) if ret_val else (fig, ax)

    return wrapped
2021-11-24 19:15:44 +01:00
@contextmanager
def hiro_style(*args, **kwargs):
    """Context manager that activates the plot style of this module.

    Inside the context the ``ggplot`` style sheet is active together
    with a few rc overrides.  The arguments are accepted but not used;
    the context yields ``True``.
    """
    rc_overrides = {
        # "font.family": "serif",
        "text.usetex": False,
        "pgf.rcfonts": False,
        "lines.linewidth": 1,
    }

    with plt.style.context("ggplot"), matplotlib.rc_context(rc_overrides):
        yield True
2021-11-19 20:26:57 +01:00
@wrap_plot
def plot_complex(x, y, *args, ax=None, label="", **kwargs):
    """Plot the real and the imaginary part of ``y`` against ``x`` on ``ax``.

    Both curves get a legend entry; a non-empty ``label`` is prepended
    to it, separated by a comma.  Extra positional and keyword
    arguments are passed on to ``ax.plot``.
    """
    prefix = "" if len(label) == 0 else label + ", "

    for part in ("real", "imag"):
        ax.plot(x, getattr(y, part), *args, label=f"{prefix}{part} part", **kwargs)

    ax.legend()
2021-11-24 19:15:44 +01:00
###############################################################################
# Numpy Hacks #
###############################################################################
2021-11-19 20:26:57 +01:00
def e_i(i: int, size: int) -> np.ndarray:
    r"""Return the cartesian base vector :math:`e_i` with ``size`` entries.

    All entries are zero except for the one at index ``i``, which is one.
    """
    basis_vector = np.zeros(size)
    basis_vector[i] = 1
    return basis_vector
def except_element(array: np.ndarray, index: int) -> np.ndarray:
    """Return a copy of the one-dimensional ``array`` without the element at ``index``.

    :param array: The array to take the elements from.
    :param index: The position of the element to leave out.  Negative
        values count from the end, as in ordinary indexing.  An index
        outside of the array leaves out nothing.
    :returns: A new array; ``array`` itself is not modified.
    """
    size = array.size

    # Map a negative index to its non-negative counterpart so that it
    # matches one of the positions compared against below.
    if -size <= index < 0:
        index += size

    return array[np.arange(size) != index]
def poly_real(p: Polynomial) -> Polynomial:
    """Return the real part of ``p``.

    :param p: The polynomial whose coefficients are projected onto
        their real parts.
    :returns: A copy of ``p`` whose coefficients are the real parts of
        the coefficients of ``p``.
    """
    new = p.copy()

    # ``.real`` is a view into (or the very same array as) ``p.coef``;
    # copy it so that the result does not share memory with ``p``.
    new.coef = p.coef.real.copy()
    return new