master-thesis/python/energy_flow_proper/figsaver.py

477 lines
13 KiB
Python
Raw Normal View History

2022-03-15 10:57:58 +01:00
import os
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from contextlib import contextmanager
2022-04-07 16:36:40 +02:00
import hiro_models.model_auxiliary as aux
# Output directories, resolved against the working directory at import time.
fig_path = Path.cwd() / "figures"  # where export_fig writes figures
val_path = Path.cwd() / "values"  # where tex_value(..., save=...) writes .tex files
def cm2inch(*tupl):
    """Convert lengths from centimetres to inches.

    Accepts the lengths either as separate positional arguments
    (``cm2inch(8.5, 6)``) or as a single tuple/list as the first argument
    (``cm2inch((8.5, 6))``).

    :returns: A tuple with every length divided by 2.54 (empty if no
        lengths were given).
    """
    cm_per_inch = 2.54
    # Unpack a sequence passed as the first argument, so both call styles work.
    if tupl and isinstance(tupl[0], (tuple, list)):
        tupl = tupl[0]
    return tuple(length / cm_per_inch for length in tupl)
def export_fig(name, fig=None, formats=("pdf", "svg", "pgf")):
    """Save a figure below ``fig_path`` in several file formats.

    :param name: Base file name, without extension.
    :param fig: The figure to save; defaults to the current figure.
    :param formats: File extensions to write, one file per entry.  Defaults
        to ``("pdf", "svg", "pgf")``; pass a subset to skip formats, e.g.
        when the ``pgf`` backend is unavailable.
    :returns: The saved figure.
    """
    fig_path.mkdir(parents=True, exist_ok=True)
    if fig is None:
        fig = plt.gcf()
    fig.tight_layout()
    fig.canvas.draw()
    for extension in formats:
        fig.savefig(fig_path / f"{name}.{extension}")
    return fig
def scientific_round(val, *err, retprec=False):
    """Scientifically round values to the precision given by their errors.

    For every error component the precision keeps one significant digit,
    or two when the leading digit is <= 3.  When a value has several error
    components, the finest of their precisions is used.

    :param val: A scalar or 1-D array of values.
    :param err: One or more errors; scalars, or arrays matching ``val``.
    :param retprec: Whether to also return the chosen precision(s).
    :returns: If exactly one value results (after broadcasting against the
        errors): ``(rounded_val, *rounded_errs)``, plus the precision as the
        last element if ``retprec``.  Otherwise ``(rounded_vals,
        rounded_errs)``, plus the per-value precisions if ``retprec``, where
        ``rounded_errs`` holds one row of error components per value.
        Precisions <= 0 yield integer results.
    """
    val, err = np.asarray(val), np.asarray(err)
    if len(err.shape) == 1:
        # Scalar errors, shape (k,): wrap them into one row.  Together with
        # the transpose below this gives shape (1, k): one value, k errors.
        err = np.array([err])
        err = err.T
    # Array errors arrive as (k, m); transpose to (m, k) so that each row
    # holds the error components belonging to one value.
    err = err.T

    if err.size == 1 and val.size > 1:
        # One error shared by all values: expand it to one row per value.
        # (A flat ``ones_like(val) * err`` broadcasts to a single (1, m) row,
        # so only the first value would be rounded and the rest of the
        # ``empty_like`` result below would stay uninitialised.)
        err = np.ones_like(val)[..., np.newaxis] * err

    if len(err.shape) == 0:
        err = np.array([err])

    if val.size == 1 and err.shape[0] > 1:
        val = np.ones_like(err) * val

    # Decimal exponent and leading digit of every error component.
    i = np.floor(np.log10(err))
    first_digit = (err // 10 ** i).astype(int)
    # One significant digit, or two when the leading digit is <= 3.
    prec = (-i + np.ones_like(err) * (first_digit <= 3)).astype(int)
    # Per value, use the finest precision demanded by any error component.
    prec = np.max(prec, axis=1)

    def smart_round(value, precision):
        """Round to ``precision`` decimals; return ints for precision <= 0."""
        value = np.round(value, precision)
        if precision <= 0:
            value = value.astype(int)
        return value

    if val.size > 1:
        rounded = np.empty_like(val)
        rounded_err = np.empty_like(err)
        for n, (value, error, precision) in enumerate(zip(val, err, prec)):
            rounded[n] = smart_round(value, precision)
            rounded_err[n] = smart_round(error, precision)
        if retprec:
            return rounded, rounded_err, prec
        return rounded, rounded_err

    prec = prec[0]
    rounded_val = smart_round(val, prec)
    rounded_errs = smart_round(err, prec)[0]
    if retprec:
        return (rounded_val, *rounded_errs, prec)
    return (rounded_val, *rounded_errs)
def tex_value(val, err=None, unit=None, prefix="", suffix="", prec=0, save=None):
    r"""Generate LaTeX output of a value with optional error and unit.

    :param val: The value to format.
    :param err: Optional error.  If given and non-zero, value and error are
        rounded with :func:`scientific_round` and ``prec`` is ignored.
    :param unit: Optional siunitx unit; the number is wrapped in
        ``\SI{<number>}{<unit>}``.
    :param prefix: LaTeX inserted before the number, inside ``\( ... \)``.
    :param suffix: LaTeX inserted after the number, inside ``\( ... \)``.
    :param prec: Decimals to round to when no error is given (may be
        negative, e.g. ``-1`` rounds to tens).
    :param save: If not ``None``, the result is also written to
        ``val_path / f"{save}.tex"``.
    :returns: The LaTeX string wrapped in ``\( ... \)``.
    """
    has_err = bool(err)
    if has_err:
        val, err, prec = scientific_round(val, err, retprec=True)
    else:
        val = np.round(val, prec)

    # Non-positive precisions mean "no decimals": print integers so that e.g.
    # prec=-1 gives "1230" rather than "1230.0".
    if prec <= 0:
        val = int(val)
        if has_err:
            err = int(err)

    def fmt(number):
        return f"{number:.{prec}f}" if prec > 0 else str(number)

    val_string = fmt(val)
    if has_err:
        # The "\pm" separator must be emitted regardless of the precision.
        val_string += r"\pm " + fmt(err)

    ret_string = r"\(" + prefix
    if unit is None:
        ret_string += val_string
    else:
        ret_string += rf"\SI{{{val_string}}}{{{unit}}}"
    ret_string += suffix + r"\)"

    if save is not None:
        val_path.mkdir(parents=True, exist_ok=True)
        (val_path / f"{save}.tex").write_text(ret_string, encoding="utf-8")

    return ret_string
###############################################################################
# Plot Porn #
###############################################################################
def wrap_plot(f):
    """Decorate a plotting function so that it can create its own axes.

    The wrapped function accepts two extra keyword arguments:

    * ``ax`` -- axes to draw on.  If omitted (``None``), a new figure and
      axes are created by calling ``setup_function()``.
    * ``setup_function`` -- callable returning ``(fig, ax)``; defaults to
      :func:`matplotlib.pyplot.subplots`.

    ``f`` is called as ``f(*args, ax=ax, **kwargs)``.

    :returns: A function returning ``(fig, ax)``, or ``(fig, ax, ret_val)``
        when ``f`` returns something other than ``None``.  ``fig`` is
        ``None`` when the caller supplied ``ax``.
    """
    import functools

    @functools.wraps(f)
    def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs):
        fig = None
        # Compare against None explicitly: arrays of axes have no truth value.
        if ax is None:
            fig, ax = setup_function()
        ret_val = f(*args, ax=ax, **kwargs)
        # Only a missing (None) result is dropped; falsy results are kept.
        return (fig, ax) if ret_val is None else (fig, ax, ret_val)

    return wrapped
def get_figsize(
    columnwidth: float,
    wf: float = 0.5,
    hf: float = (5.0 ** 0.5 - 1.0) / 2.0,
) -> list[float]:
    r"""Compute a matplotlib figure size that fits a LaTeX column.

    :param columnwidth: Width of the column in LaTeX points.
        Get this from LaTeX using ``\showthe\columnwidth``.
    :param wf: Figure width as a fraction of ``columnwidth``.
    :param hf: Figure height as a fraction of the figure width.
        Set by default to the inverse golden ratio ``(sqrt(5) - 1) / 2``.
    :returns: The ``[fig_width, fig_height]`` in inches that should be
        given to matplotlib.
    """
    inches_per_pt = 1.0 / 72.27  # 1 inch = 72.27 TeX points
    fig_width = columnwidth * wf * inches_per_pt
    fig_height = fig_width * hf
    return [fig_width, fig_height]
@wrap_plot
def plot_complex(x, y, *args, ax=None, label="", **kwargs):
    """Plot the real and the imaginary part of ``y`` against ``x``.

    Extra positional and keyword arguments are forwarded to ``ax.plot``.
    A non-empty ``label`` prefixes both legend entries.
    """
    prefix = f"{label}, " if len(label) > 0 else ""
    real_label = f"{prefix}real part"
    imag_label = f"{prefix}imag part"
    ax.plot(x, y.real, *args, label=real_label, **kwargs)
    ax.plot(x, y.imag, *args, label=imag_label, **kwargs)
    ax.legend()
@wrap_plot
def plot_convergence(
    x,
    y,
    ax=None,
    label="",
    transform=lambda y: y,
    slice=None,
    linestyle="-",
    bath=None,
):
    """Plot how an ensemble estimate changes with the sample count ``N``.

    :param y: Sequence of ``(N, value, sigma)`` entries; the last entry is
        drawn as the final estimate (red line with a yellow ``±sigma`` band).
    :param label: Optional prefix for all legend entries.
    :param transform: Applied to the values (not to ``sigma``).
    :param slice: ``(start, stop)`` of the entries drawn as dotted lines;
        defaults to ``(0, -1)``, i.e. all but the last.  (The name shadows
        the builtin but is part of the public keyword interface.)
    :param linestyle: Line style of the final estimate.
    :param bath: If given, index into ``value`` and ``sigma``.
    :returns: ``None``.
    """

    def select(quantity):
        # Restrict to the requested bath, if any.
        return quantity if bath is None else quantity[bath]

    prefix = label + ", " if len(label) > 0 else ""
    bounds = slice if slice else (0, -1)
    final_n = y[-1][0]

    for n, val, _ in y[bounds[0] : bounds[1]]:
        ax.plot(
            x,
            transform(select(val)),
            label=f"{prefix}$N={n}$",
            alpha=n / final_n,
            linestyle=":",
        )

    err = select(y[-1][2]).real
    y_final = transform(select(y[-1][1]))
    ax.fill_between(
        x,
        y_final + err,
        y_final - err,
        color="yellow",
        alpha=0.5,
        label=f"{prefix}($σ$)",
    )
    ax.plot(
        x, y_final, label=f"{prefix}$N={final_n}$", linestyle=linestyle, color="red"
    )
    return None
def lighten_color(color, amount=0.5):
    """
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.

    :param color: Any color accepted by :func:`matplotlib.colors.to_rgb`.
    :param amount: Factor applied to ``1 - lightness``: 1 leaves the color
        unchanged, smaller values lighten it, 0 yields white.
    :returns: An ``(r, g, b)`` tuple of floats.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    import colorsys

    import matplotlib.colors as mc

    # mc.to_rgb itself resolves named colors (including the CSS names listed
    # in mc.cnames), hex strings and RGB(A) tuples, so no separate lookup or
    # catch-all exception handling is needed.
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(color))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)
def fancy_error(x, y, err, ax=None, **kwargs):
    """Plot ``y`` over ``x`` together with a shaded ``y ± err`` band.

    :param ax: Axes to draw on; defaults to the current axes
        (:func:`matplotlib.pyplot.gca`).
    :param kwargs: Forwarded to ``ax.plot``.
    :returns: ``(line, band)``: the list returned by ``ax.plot`` and the
        artist returned by ``ax.fill_between``.
    """
    if ax is None:
        ax = plt.gca()

    line = ax.plot(x, y, **kwargs)
    # Shade the band in a lighter variant of the line's color.
    band = ax.fill_between(
        x,
        y + err,
        y - err,
        color=lighten_color(line[0].get_color(), 0.5),
        alpha=0.5,
    )
    return line, band
@wrap_plot
def plot_with_σ(x, y, ax=None, transform=lambda y: y, bath=None, **kwargs):
    """Plot ``transform(y.value)`` with a shaded band of ``±Re(y.σ)``.

    :param y: Object with ``value`` and ``σ`` attributes.
    :param transform: Applied to the value only, not to ``σ``.
    :param bath: If given, index into ``y.value`` and ``y.σ``.
    :param kwargs: Forwarded to :func:`fancy_error` (and from there to
        ``ax.plot``).
    :returns: The ``(line, band)`` pair from :func:`fancy_error`.
    """
    if bath is None:
        sigma, value = y.σ, y.value
    else:
        sigma, value = y.σ[bath], y.value[bath]
    err = sigma.real
    return fancy_error(x, transform(value), err, ax=ax, **kwargs)
@wrap_plot
def plot_diff_vs_sigma(
    x,
    y,
    reference,
    ax=None,
    label="",
    transform=lambda y: y,
    ecolor="yellow",
    ylabel=None,
    bath=None,
):
    r"""Plot ``|reference - value|`` of every entry against the final ``σ``.

    :param y: Sequence of ``(N, value, sigma)`` entries; the ``sigma`` of the
        last entry is shaded as the reference error band.
    :param reference: Reference trajectory; ``transform`` is applied to it,
        ``bath`` is not.
    :param label: Optional prefix for all legend entries.
    :param transform: Applied to ``reference`` and to every ``value``.
    :param ecolor: Color of the ``σ`` band.
    :param ylabel: Symbol for the y-axis label.  A string starting with
        ``$`` is used as math (outer ``$`` stripped); anything else is
        wrapped in ``\text{...}``.
    :param bath: If given, index into every ``value`` and ``sigma``.
    :returns: ``None``.
    """
    label = label + ", " if (len(label) > 0) else ""
    y = y if bath is None else [(n, val[bath], err[bath]) for n, val, err in y]
    ref_traj = transform(reference)
    ref_err = np.abs(y[-1][2].real)

    ax.fill_between(
        x,
        0,
        ref_err,
        color=ecolor,
        label=rf"{label}$\sigma\, (N={y[-1][0]})$",
    )

    for n, val, _ in y:
        diff = np.abs(transform(val) - ref_traj)
        # Fraction of points where the deviation stays below the final σ.
        within = (diff < ref_err).sum() / y[-1][2].size
        not_last = n < y[-1][0]
        ax.plot(
            x,
            diff,
            label=rf"{label}$N={n}$ $\Delta<\sigma = {within * 100:.1f}\%$",
            alpha=within if not_last else 1,
            linestyle=":" if not_last else "-",
            color=None if not_last else "red",
        )

    if ylabel is not None:
        # startswith also copes with an empty label (ylabel[0] would raise).
        if ylabel.startswith("$"):
            ylabel = ylabel[1:-1]
        else:
            ylabel = rf"\text{{ {ylabel} }}"

        ax.set_ylabel(rf"$|{{{ylabel}}}_{{\mathrm{{ref}}}}-{{{ylabel}}}_{{N_i}}|$")
def plot_interaction_consistency(
    models, reference=None, label_fn=lambda model: f"$ω_c={model.ω_c:.2f}$", **kwargs
):
    """Compare each model's interaction energy with the conservation-based one.

    For every model, ``interaction_energy`` (solid) and
    ``interaction_energy_from_conservation`` (dashed, lighter color) are
    plotted for bath 0.  Each legend entry is annotated with the percentage
    of points where the two agree within σ.  If ``reference`` is given, it
    also shows the percentage where the conservation-based value agrees with
    the reference model's interaction energy.

    :param kwargs: Forwarded to the models' energy methods.
    :returns: ``(fig, ax)``.
    """

    def percent_within_sigma(diff):
        # Share of entries with |Δ| < σ, normalised by the length of the first
        # row of ``diff.value`` (presumably the number of time points).
        return (diff.value < diff.σ).sum() / len(diff.value[0]) * 100

    fig, ax = plt.subplots()

    if reference:
        with aux.get_data(reference) as data:
            reference_energy = reference.interaction_energy(data, **kwargs)

    for model in models:
        with aux.get_data(model) as data:
            energy = model.interaction_energy(data, **kwargs)
            interaction_ref = model.interaction_energy_from_conservation(data, **kwargs)
            self_consistency = percent_within_sigma(abs(interaction_ref - energy))

            annotation = fr", (${self_consistency:.0f}\%$"
            if reference:
                final_consistency = percent_within_sigma(
                    abs(interaction_ref - reference_energy)
                )
                annotation += fr", ${final_consistency:.0f}\%$)"
            else:
                annotation += ")"

            _, _, (line, _) = plot_with_σ(
                data.time[:],
                energy,
                ax=ax,
                label=label_fn(model) + annotation,
                bath=0,
            )

            # Dashed companion curve in a lighter shade of the same color.
            plot_with_σ(
                data.time[:],
                interaction_ref,
                ax=ax,
                linestyle="--",
                bath=0,
                color=lighten_color(line[0].get_color(), 0.8),
            )

    ax.set_xlabel("$τ$")
    ax.set_ylabel(r"$\langle H_\mathrm{I}\rangle$")
    ax.legend()
    return fig, ax
def plot_interaction_consistency_development(
    models, reference=None, label_fn=lambda model: f"$ω_c={model.ω_c:.2f}$", **kwargs
):
    """Plot the consistency percentage as a function of the sample count N.

    The conservation-based interaction energy of each model is compared with
    the reference model's interaction energy if ``reference`` is given, and
    otherwise with the model's own ``interaction_energy``.  A grey line marks
    68 %.

    :param kwargs: Forwarded to the models' energy methods.
    :returns: ``(fig, ax)``.
    """
    fig, ax = plt.subplots()

    if reference:
        with aux.get_data(reference) as data:
            reference_energy = reference.interaction_energy(data, **kwargs)

    for model in models:
        with aux.get_data(model) as data:
            interaction_ref = model.interaction_energy_from_conservation(data, **kwargs)
            compare_to = (
                reference_energy
                if reference
                else model.interaction_energy(data, **kwargs)
            )
            diff = abs(interaction_ref - compare_to)

            # Iterating ``diff`` yields (N, value, σ) triples.
            points = [(N, (val < σ).sum() / len(val[0]) * 100) for N, val, σ in diff]
            sample_counts = [N for N, _ in points]
            percentages = [percentage for _, percentage in points]
            ax.plot(
                sample_counts,
                percentages,
                linestyle="--",
                marker=".",
                label=label_fn(model),
            )

    ax.axhline(68, linestyle="-.", color="grey", alpha=0.5)
    ax.set_xlabel("$N$")
    ax.set_ylabel(("" if reference else "Self-") + r"Consistency [$\%$]")
    ax.legend()
    return fig, ax
def plot_flow_bcf(models, label_fn=lambda model: f"$ω_c={model.ω_c:.2f}$", **kwargs):
    """Plot the negated bath energy flow of each model next to a BCF-based curve.

    The solid curve is ``-bath_energy_flow`` (bath 0) with its σ band.  The
    dashed curve in the same color is
    ``-L_expect * bcf_scale * Im(bcf(t))``.

    :param kwargs: Forwarded to ``model.bath_energy_flow``.
    :returns: ``(fig, ax)``.
    """
    fig, ax = plt.subplots()

    for model in models:
        with aux.get_data(model) as data:
            flow = model.bath_energy_flow(data, **kwargs)
            _, _, (line, _) = plot_with_σ(
                data.time[:],
                flow,
                ax=ax,
                label=label_fn(model),
                bath=0,
                transform=lambda y: -y,
            )

            bcf_curve = -model.L_expect * model.bcf_scale * model.bcf(data.time[:]).imag
            ax.plot(data.time[:], bcf_curve, linestyle="--", color=line[0].get_color())

    return fig, ax
def plot_energy_overview(model, bath=0, ensemble_args=None):
    """Plot flow, bath, system, interaction and total energy of ``model``.

    :param bath: Index of the bath whose flow, bath energy and interaction
        energy are shown and enter the total.  Previously the interaction and
        total curves always used bath 0; the default keeps that behavior.
    :param ensemble_args: Keyword arguments forwarded to the model's
        energy methods.
    :returns: ``(fig, ax)``.
    """
    if not ensemble_args:
        ensemble_args = {}

    fig, ax = plt.subplots()
    with aux.get_data(model) as data:
        bath_energy = model.bath_energy(data, **ensemble_args)
        interaction_energy = model.interaction_energy(data, **ensemble_args)
        system_energy = model.system_energy(data, **ensemble_args)
        flow = model.bath_energy_flow(data, **ensemble_args)

        plot_with_σ(model.t, flow, bath=bath, ax=ax, label="Flow")
        plot_with_σ(model.t, bath_energy, bath=bath, ax=ax, label="Bath")
        plot_with_σ(model.t, system_energy, ax=ax, label="System")
        plot_with_σ(
            model.t,
            interaction_energy,
            ax=ax,
            bath=bath,
            label="Interaction",
        )
        plot_with_σ(
            model.t,
            bath_energy + interaction_energy + system_energy,
            bath=bath,
            ax=ax,
            label="Total",
        )
        # Unlabeled curve: the operator norm of model.L over time.
        ax.plot(model.t, model.L.operator_norm(model.t))

    return fig, ax
###############################################################################
# SIDE EFFECTS #
###############################################################################
# Matplotlib rc settings used both by ``hiro_style`` and by the module-level
# defaults applied at import time.  (A duplicated "font.size" entry with the
# same value has been removed.)
MPL_RC = {
    "font.family": "serif",
    "text.usetex": False,
    "pgf.rcfonts": False,
    "lines.linewidth": 1,
    "font.size": 10,
    "axes.labelsize": 10,
    "axes.titlesize": 8,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
}
@contextmanager
def hiro_style():
    """Temporarily apply the ggplot style plus ``MPL_RC``.

    Usage::

        with hiro_style():
            ...

    Yields ``True``; both matplotlib contexts are exited when the block ends.
    """
    with plt.style.context("ggplot"), matplotlib.rc_context(MPL_RC):
        yield True
# Import-time side effect: make ggplot plus MPL_RC the global defaults for
# every figure created after this module is imported.
matplotlib.style.use("ggplot")  # same function as plt.style.use
plt.rcParams.update(MPL_RC)  # plt.rcParams is matplotlib.rcParams