master-thesis/python/energy_flow_proper/figsaver.py
2022-06-09 18:14:04 +02:00

586 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from contextlib import contextmanager
import utilities as ut
from hopsflow.util import EnsembleValue
from hops.util.utilities import (
relative_entropy,
relative_entropy_single,
entropy,
trace_distance,
)
try:
import hiro_models.model_auxiliary as aux
except:
aux = None
# Output directories, resolved against the current working directory at
# import time: exported figures go to ``figures/``, LaTeX values to ``values/``.
fig_path = Path.cwd() / "figures"
val_path = Path.cwd() / "values"
def cm2inch(*tupl):
    """Convert lengths given in centimetres to inches.

    Accepts either separate positional values (``cm2inch(10, 5)``) or a single
    tuple as the first argument (``cm2inch((10, 5))``); in the latter case only
    that first tuple is converted.

    :returns: A tuple with the converted values.
    """
    cm_per_inch = 2.54
    lengths = tupl[0] if isinstance(tupl[0], tuple) else tupl
    return tuple(length / cm_per_inch for length in lengths)
def export_fig(name, fig=None):
    """Save a figure as PDF, SVG and PGF into ``fig_path``.

    The directory is created if missing. The figure is laid out with
    ``tight_layout`` and drawn once before saving.

    :param name: Base file name (without extension).
    :param fig: The figure to export; defaults to the current figure.
    :returns: The exported figure.
    """
    fig_path.mkdir(parents=True, exist_ok=True)

    figure = plt.gcf() if fig is None else fig
    figure.tight_layout()
    figure.canvas.draw()

    for extension in (".pdf", ".svg", ".pgf"):
        figure.savefig(fig_path / f"{name}{extension}")

    return figure
def scientific_round(val, *err, retprec=False):
    """Scientifically rounds the values to the given errors.

    The precision (number of decimals, possibly negative) is derived from the
    leading digit of each error. The error keeps one significant digit, or two
    if that digit is at most 3. The values and errors are then rounded to that
    precision; results with a non-positive precision are converted to ints.

    :param val: A scalar value or an array of values.
    :param err: One or more errors. These are either scalars (several scalars
        give several errors for a scalar ``val``) or an array with one error
        per entry of ``val``. A single scalar error is applied to every value.
    :param retprec: Whether to also return the precision(s) used.
    :returns: For a scalar ``val``: ``(value, *errors)``, plus ``prec`` if
        ``retprec``. Otherwise ``(rounded_values, rounded_errors)``, plus the
        per-value precisions if ``retprec``. ``rounded_errors`` has one row
        per value.
    """
    val, err = np.asarray(val), np.asarray(err)

    # Normalise ``err`` to a 2-D array holding one row of errors per value.
    if len(err.shape) == 1:
        err = np.array([err])
        err = err.T
    err = err.T

    if err.size == 1 and val.size > 1:
        # Apply a single error to every value. For a 1-D ``val`` this must
        # be a column (one row per value). Plain ``np.ones_like(val) * err``
        # yields a single row, so only the first value got rounded and the
        # remaining results were left uninitialised (``np.empty_like``).
        err = np.ones_like(val) * err.item()
        if err.ndim == 1:
            err = err[:, np.newaxis]

    if len(err.shape) == 0:
        err = np.array([err])

    if val.size == 1 and err.shape[0] > 1:
        val = np.ones_like(err) * val

    # Decade of the leading digit of each error, and that digit itself.
    i = np.floor(np.log10(err))
    first_digit = (err // 10 ** i).astype(int)
    # One significant digit, or two if the leading digit is <= 3.
    prec = (-i + np.ones_like(err) * (first_digit <= 3)).astype(int)
    # All errors belonging to one value share the finest precision.
    prec = np.max(prec, axis=1)

    def smart_round(value, precision):
        """Round ``value`` to ``precision`` decimals; use ints if <= 0."""
        value = np.round(value, precision)
        if precision <= 0:
            value = value.astype(int)
        return value

    if val.size > 1:
        rounded = np.empty_like(val)
        rounded_err = np.empty_like(err)
        for n, (value, error, precision) in enumerate(zip(val, err, prec)):
            rounded[n] = smart_round(value, precision)
            rounded_err[n] = smart_round(error, precision)

        if retprec:
            return rounded, rounded_err, prec
        else:
            return rounded, rounded_err

    else:
        prec = prec[0]
        if retprec:
            return (smart_round(val, prec), *smart_round(err, prec)[0], prec)
        else:
            return (smart_round(val, prec), *smart_round(err, prec)[0])
def tex_value(val, err=None, unit=None, prefix="", suffix="", prec=0, save=None):
    r"""Generates LaTeX output of a value with units and error.

    :param val: The value to format.
    :param err: Optional error. If truthy, ``val`` and ``err`` are rounded
        with :func:`scientific_round`, and its precision replaces ``prec``.
    :param unit: Optional siunitx unit; the number is then wrapped in
        ``\SI{...}{unit}``.
    :param prefix: LaTeX placed before the number inside ``\( ... \)``.
    :param suffix: LaTeX placed after the number inside ``\( ... \)``.
    :param prec: Number of decimals used when no error is given. It may be
        negative, e.g. ``-1`` rounds to tens.
    :param save: If given, the result is also written to
        ``val_path / f"{save}.tex"``.
    :returns: The LaTeX string, e.g. ``\(1.23\pm 0.05\)``.
    """
    if err:
        val, err, prec = scientific_round(val, err, retprec=True)
    else:
        val = np.round(val, prec)

    # A non-positive precision means there are no decimals: print integers
    # (``1230``), not floats (``1230.0``).
    if prec <= 0:
        val = int(val)
        if err:
            err = int(err)

    if prec > 0:
        val_string = rf"{val:.{prec}f}"
        if err:
            val_string += rf"\pm {err:.{prec}f}"
    else:
        val_string = str(val)
        if err:
            # The ``\pm`` is needed here as well; without it value and error
            # would be glued together (``12`` and ``4`` -> ``124``).
            val_string += rf"\pm {err}"

    ret_string = r"\(" + prefix
    if unit is None:
        ret_string += val_string
    else:
        ret_string += rf"\SI{{{val_string}}}{{{unit}}}"
    ret_string += suffix + r"\)"

    if save is not None:
        val_path.mkdir(parents=True, exist_ok=True)
        with open(val_path / f"{save}.tex", "w", encoding="utf-8") as f:
            f.write(ret_string)

    return ret_string
###############################################################################
#                              Plotting Helpers                               #
###############################################################################
def wrap_plot(f):
    """Decorator that supplies a figure/axis to a plotting function.

    The wrapped function gains two keyword arguments:

    :param ax: Axis to draw on. If ``None`` (the default), a new figure and
        axis are created with ``setup_function``. If an axis is passed, the
        returned ``fig`` is ``None``.
    :param setup_function: Callable returning ``(fig, ax)``. Defaults to
        :func:`matplotlib.pyplot.subplots`.

    :returns: ``(fig, ax, ret_val)`` if ``f`` returned anything other than
        ``None``, otherwise ``(fig, ax)``.
    """
    import functools

    @functools.wraps(f)
    def wrapped(*args, ax=None, setup_function=plt.subplots, **kwargs):
        fig = None
        # Compare against None explicitly: containers of axes (e.g. numpy
        # arrays) have no unambiguous truth value.
        if ax is None:
            fig, ax = setup_function()

        ret_val = f(*args, ax=ax, **kwargs)
        # Keep falsy-but-real return values (0, empty containers) in the result.
        return (fig, ax, ret_val) if ret_val is not None else (fig, ax)

    return wrapped
def get_figsize(
    columnwidth: float,
    wf: float = 0.5,
    hf: float = (5.0 ** 0.5 - 1.0) / 2.0,
) -> list[float]:
    r"""Compute a matplotlib figure size from a LaTeX column width.

    :param columnwidth: Width of the column in LaTeX points (pt).
        Get this from LaTeX using ``\showthe\columnwidth``.
    :param wf: Width fraction in columnwidth units.
    :param hf: Height as a fraction of the resulting figure width.
        Set by default to the (inverse) golden ratio, about 0.618.
    :returns: The ``[fig_width, fig_height]`` in inches that should be given
        to matplotlib.
    """
    fig_width_pt = columnwidth * wf
    inches_per_pt = 1.0 / 72.27  # TeX points per inch
    fig_width = fig_width_pt * inches_per_pt  # width in inches
    fig_height = fig_width * hf  # height in inches
    return [fig_width, fig_height]
@wrap_plot
def plot_complex(x, y, *args, ax=None, label="", **kwargs):
    """Plot the real and imaginary parts of ``y`` against ``x`` and add a legend.

    :param label: Optional prefix for both legend entries.
    Further positional and keyword arguments are forwarded to ``ax.plot``.
    """
    prefix = label + ", " if len(label) > 0 else ""

    for part_name, part in (("real", y.real), ("imag", y.imag)):
        ax.plot(x, part, *args, label=f"{prefix}{part_name} part", **kwargs)

    ax.legend()
@wrap_plot
def plot_convergence(
    x,
    y,
    ax=None,
    label="",
    transform=lambda y: y,
    slice=None,
    linestyle="-",
    bath=None,
):
    """Plot how an ensemble value converges with the number of samples.

    Every aggregate ``y[i]`` except the last one is drawn as a dotted line
    whose opacity is ``n / y.N``. The final value is drawn in red on top of a
    shaded ``± σ`` band.

    :param x: Abscissa values.
    :param y: Ensemble value supporting ``len``, indexing, ``consistency``,
        ``N``, ``σ`` and ``value``. Its items provide ``final_aggregate``.
    :param ax: The axis to draw on (supplied by :func:`wrap_plot`).
    :param label: Optional prefix for the legend entries.
    :param transform: Applied to the plotted values; the σ band width is the
        untransformed real part of σ.
    :param slice: Unused; kept for backwards compatibility.
    :param linestyle: Line style of the final value.
    :param bath: If not ``None``, the bath index selected from values and σ.
    """
    label = label + ", " if (len(label) > 0) else ""

    # Remains ``None`` when ``y`` has only a single aggregate.
    line = None
    for i in range(len(y) - 1):
        current_value = y[i]
        consistency = y.consistency(current_value)
        n, val, _ = current_value.final_aggregate
        line = ax.plot(
            x,
            transform(val[bath] if bath is not None else val),
            label=rf"{label}$N={n}, {consistency:.0f}\%$",
            alpha=n / y.N,
            linestyle=":",
        )

    err = (y.σ[bath] if bath is not None else y.σ).real
    y_final = transform(y.value[bath] if bath is not None else y.value)

    # Tint the band like the last intermediate line. If there is none, let
    # matplotlib choose a colour instead of failing on an unbound name.
    band_style = {}
    if line is not None:
        band_style["color"] = lighten_color(line[0].get_color(), 0.5)

    ax.fill_between(
        x,
        y_final + err,
        y_final - err,
        alpha=0.5,
        label=f"{label}($σ$)",
        **band_style,
    )
    ax.plot(x, y_final, label=f"{label}$N={y.N}$", linestyle=linestyle, color="red")
    return None
def lighten_color(color, amount=0.5):
    """
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)

    :returns: The adjusted color as an ``(r, g, b)`` tuple.
    """
    import matplotlib.colors as mc
    import colorsys

    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a named color (KeyError), or an unhashable value such as a list
        # (TypeError): let ``to_rgb`` interpret it as given.
        c = color
    c = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])
def fancy_error(x, y, err, ax, **kwargs):
    """Plot ``y`` over ``x`` with a shaded ``y ± err`` band.

    The band uses a lightened version of the line colour at half opacity.

    :param kwargs: Forwarded to ``ax.plot``.
    :returns: ``(line, band)``: the list returned by ``ax.plot`` and the
        object returned by ``ax.fill_between``.
    """
    line = ax.plot(x, y, **kwargs)

    upper, lower = y + err, y - err
    band = ax.fill_between(
        x,
        upper,
        lower,
        color=lighten_color(line[0].get_color(), 0.5),
        alpha=0.5,
    )

    return line, band
@wrap_plot
def plot_with_σ(
    x,
    y,
    ax=None,
    transform=lambda y: y,
    bath=None,
    strobe_frequency=None,
    strobe_tolerance=1e-3,
    **kwargs,
):
    """Plot an ensemble value together with its standard deviation.

    By default ``transform(value)`` is drawn as a line with a shaded ``± σ``
    band (see :func:`fancy_error`). If ``strobe_frequency`` is given, only the
    points selected by ``ut.strobe_times`` are drawn, as error bars with
    markers; ``kwargs`` can override that marker style.

    :param y: Ensemble value providing ``value`` and ``σ``.
    :param transform: Applied to the values only; the error is the
        untransformed real part of σ.
    :param bath: If not ``None``, the bath index selected from value and σ.
    :param strobe_frequency: Enables stroboscopic plotting when not ``None``.
    :param strobe_tolerance: Tolerance passed to ``ut.strobe_times``.
    :returns: ``(line, band)`` from :func:`fancy_error`, or the result of
        ``ax.errorbar`` in strobe mode.
    """
    σ_selected = y.σ if bath is None else y.σ[bath]
    err = σ_selected.real
    y_final = transform(y.value if bath is None else y.value[bath])

    if strobe_frequency is None:
        return fancy_error(x, y_final, err, ax=ax, **kwargs)

    strobe_style = dict(linestyle="none", marker="o", markersize=2) | kwargs
    strobe_times, strobe_indices = ut.strobe_times(
        x, strobe_frequency, strobe_tolerance
    )
    return ax.errorbar(
        strobe_times,
        y_final[strobe_indices],
        err[strobe_indices],
        **strobe_style,
    )
@wrap_plot
def plot_diff_vs_sigma(
    x,
    y,
    reference,
    ax=None,
    label="",
    transform=lambda y: y,
    ecolor="yellow",
    ealpha=0.5,
    ylabel=None,
    bath=None,
):
    """Compare ``|y[i] - reference|`` with the reference's σ.

    The area between 0 and ``reference.σ`` is shaded. Each aggregate ``y[i]``
    is then drawn as its absolute deviation from ``reference``. Aggregates
    with fewer samples than ``y[-1]`` are dotted, with opacity equal to their
    consistency percentage / 100; the rest are drawn solid in red.

    :param label: Optional prefix for the legend entries.
    :param transform: Unused.
    :param ecolor: Colour of the σ area.
    :param ealpha: Opacity of the σ area.
    :param ylabel: Quantity name for the y-axis label. A ``$...$`` string is
        used as math; anything else is wrapped in ``\\text{...}``.
    :param bath: If not ``None``, both ``y`` and ``reference`` are reduced with
        ``for_bath(bath)`` first.
    """
    prefix = label + ", " if len(label) > 0 else ""

    if bath is not None:
        y, reference = y.for_bath(bath), reference.for_bath(bath)

    ax.fill_between(
        x,
        0,
        reference.σ,
        color=ecolor,
        alpha=ealpha,
        label=rf"{prefix}$\sigma\, (N={reference.N})$",
    )

    for index in range(len(y)):
        current = y[index]
        not_last = current.N < y[-1].N
        consistency = current.consistency(reference)
        diff = abs(current - reference)

        if not_last:
            style = dict(alpha=consistency / 100, linestyle=":", color=None)
        else:
            style = dict(alpha=1, linestyle="-", color="red")

        ax.plot(
            x,
            diff.value,
            label=rf"{prefix}$N={current.N}$ $({consistency:.1f}\%)$",
            **style,
        )

    if ylabel is None:
        return None

    symbol = ylabel[1:-1] if ylabel[0] == "$" else rf"\text{{ {ylabel} }}"
    ax.set_ylabel(rf"$|{{{symbol}}}_{{\mathrm{{ref}}}}-{{{symbol}}}_{{N_i}}|$")
def plot_interaction_consistency(
    models, reference=None, label_fn=lambda model: f"$ω_c={model.ω_c:.2f}$", **kwargs
):
    """Plot ``interaction_energy`` vs. ``interaction_energy_from_conservation``.

    For each model, bath 0 of ``model.interaction_energy`` is drawn with its
    σ band. ``model.interaction_energy_from_conservation`` is drawn dashed in
    a lighter shade of the same colour. The legend shows the percentage of
    entries where the absolute difference is below its σ, normalised by
    ``len(diff.value[0])``. If ``reference`` is truthy, the same percentage
    against the reference's ``interaction_energy`` is shown too.

    :param kwargs: Forwarded to the energy methods.
    :returns: ``(fig, ax)``.
    """

    def percent_within_sigma(diff):
        # Share of entries with |difference| < σ, in percent.
        return (diff.value < diff.σ).sum() / len(diff.value[0]) * 100

    fig, ax = plt.subplots()

    if reference:
        with aux.get_data(reference) as data:
            reference_energy = reference.interaction_energy(data, **kwargs)

    for model in models:
        with aux.get_data(model) as data:
            energy = model.interaction_energy(data, **kwargs)
            interaction_ref = model.interaction_energy_from_conservation(data, **kwargs)
            self_consistency = percent_within_sigma(abs(interaction_ref - energy))

            label_tail = ")"
            if reference:
                final_consistency = percent_within_sigma(
                    abs(interaction_ref - reference_energy)
                )
                label_tail = fr", ${final_consistency:.0f}\%$)"

            _, _, (line, _) = plot_with_σ(
                data.time[:],
                energy,
                ax=ax,
                label=label_fn(model) + fr", (${self_consistency:.0f}\%$" + label_tail,
                bath=0,
            )
            plot_with_σ(
                data.time[:],
                interaction_ref,
                ax=ax,
                linestyle="--",
                bath=0,
                color=lighten_color(line[0].get_color(), 0.8),
            )

    ax.set_xlabel("$τ$")
    ax.set_ylabel(r"$\langle H_\mathrm{I}\rangle$")
    ax.legend()
    return fig, ax
def plot_interaction_consistency_development(
    models, reference=None, label_fn=lambda model: f"$ω_c={model.ω_c:.2f}$", **kwargs
):
    """Show how interaction-energy consistency develops with the sample count.

    For every model, ``diff = |interaction_energy_from_conservation - energy|``
    is formed. ``energy`` is the reference's ``interaction_energy`` if
    ``reference`` is given, otherwise the model's own. Three panels are drawn:

    1. the percentage of entries of each aggregate with ``val < σ``
       (normalised by ``len(val[0])``), with a grey line at 68;
    2. the maximal σ of ``diff`` per aggregate (log-log);
    3. the maximal σ of both energies per aggregate (log-log).

    :param kwargs: Forwarded to the energy methods.
    :returns: ``(fig, [ax, ax2, ax3])``.
    """
    fig, (ax, ax2, ax3) = plt.subplots(nrows=1, ncols=3)
    for log_ax in (ax2, ax3):
        log_ax.set_xscale("log")
        log_ax.set_yscale("log")

    if reference is not None:
        with aux.get_data(reference) as data:
            reference_energy = reference.interaction_energy(data, **kwargs)

    for model in models:
        with aux.get_data(model) as data:
            interaction_ref = model.interaction_energy_from_conservation(data, **kwargs)
            energy = (
                reference_energy
                if reference is not None
                else model.interaction_energy(data, **kwargs)
            )
            diff = abs(interaction_ref - energy)

            ns, values, σs = [], [], []
            for N, val, σ in diff.aggregate_iterator:
                ns.append(N)
                values.append((val < σ).sum() / len(val[0]) * 100)
                σs.append(σ.max())

            σ_ref = [σ.max() for _, _, σ in interaction_ref.aggregate_iterator]
            σ_int = [σ.max() for _, _, σ in energy.aggregate_iterator]

            ax.plot(
                ns,
                values,
                linestyle="--",
                marker=".",
                markersize=2,
                label=label_fn(model),
            )
            ax2.plot(ns, σs, label=label_fn(model))
            ax3.plot(
                ns,
                σ_ref,
                label=label_fn(model) + " (from conservation)",
                linestyle="--",
            )
            ax3.plot(ns, σ_int, label=label_fn(model) + " (direct)")

    ax.axhline(68, linestyle="-.", color="grey", alpha=0.5)

    for panel in (ax, ax2, ax3):
        panel.set_xlabel("$N$")
    ax2.set_ylabel("Maximum Allowed Deviation")
    ax3.set_ylabel("Maximum $σ$")
    ax.set_ylabel(("" if reference else "Self-") + r"Consistency [$\%$]")

    for panel in (ax, ax2, ax3):
        panel.legend()

    return fig, [ax, ax2, ax3]
def plot_flow_bcf(models, label_fn=lambda model: f"$ω_c={model.ω_c:.2f}$", **kwargs):
    """Plot the negated bath energy flow next to a BCF-based curve.

    For each model, ``-bath_energy_flow`` (bath 0) is drawn with its σ band.
    ``-L_expect * bcf_scale * Im(bcf(t))`` is drawn dashed in the same colour.

    :param kwargs: Forwarded to ``model.bath_energy_flow``.
    :returns: ``(fig, ax)``.
    """
    fig, ax = plt.subplots()

    for model in models:
        with aux.get_data(model) as data:
            flow = model.bath_energy_flow(data, **kwargs)
            times = data.time[:]

            _, _, (line, _) = plot_with_σ(
                times,
                flow,
                ax=ax,
                label=label_fn(model),
                bath=0,
                transform=lambda y: -y,
            )

            bcf_curve = -model.L_expect * model.bcf_scale * model.bcf(times).imag
            ax.plot(times, bcf_curve, linestyle="--", color=line[0].get_color())

    return fig, ax
def plot_energy_overview(model, ensemble_args=None, **kwargs):
    """Plot system, per-bath (flow, bath, interaction) and total energy.

    :param model: Model providing the energy methods and the time grid ``t``.
    :param ensemble_args: Keyword arguments forwarded to the energy methods.
    :param kwargs: Forwarded to every :func:`plot_with_σ` call.
    :returns: ``(fig, ax)``.
    """
    ensemble_args = ensemble_args or {}
    fig, ax = plt.subplots()

    with aux.get_data(model) as data:
        system_energy = model.system_energy(data, **ensemble_args)
        bath_energy = model.bath_energy(data, **ensemble_args)
        interaction_energy = model.interaction_energy(data, **ensemble_args)
        flow = model.bath_energy_flow(data, **ensemble_args)

        plot_with_σ(model.t, system_energy, ax=ax, label="System", **kwargs)

        per_bath = (
            ("Flow", flow),
            ("Bath", bath_energy),
            ("Interaction", interaction_energy),
        )
        for bath in range(flow.num_baths):
            for name, quantity in per_bath:
                plot_with_σ(
                    model.t,
                    quantity,
                    ax=ax,
                    bath=bath,
                    label=f"{name} {bath}",
                    **kwargs,
                )

        total = model.total_energy(data, **ensemble_args)
        plot_with_σ(model.t, total, ax=ax, label="Total", **kwargs)

    return (fig, ax)
@wrap_plot
def plot_coherences(model, ax=None):
    """Plot ``|mean[:, 0, 1]|`` of ``data.rho_t_accum`` over ``model.t``.

    ``ensemble_std[:, 0, 1]`` is used as the σ band. NOTE(review): this
    presumably is the off-diagonal element of the density matrix with time
    on the first axis; confirm.
    """
    with aux.get_data(model) as data:
        coherence = np.abs(np.array(data.rho_t_accum.mean)[:, 0, 1])
        coherence_σ = np.array(data.rho_t_accum.ensemble_std)[:, 0, 1]
        plot_with_σ(model.t, EnsembleValue((0, coherence, coherence_σ)), ax=ax)
@wrap_plot
def plot_distance_measures(model, strobe_indices, ax=None):
    """Plot the relative entropy and the trace distance over ``model.t``.

    Both measures are computed against the last entry of ``strobe_indices``.
    """
    with aux.get_data(model) as data:
        for measure in (relative_entropy, trace_distance):
            plot_with_σ(
                model.t, EnsembleValue(measure(data, strobe_indices[-1])), ax=ax
            )
###############################################################################
# SIDE EFFECTS #
###############################################################################
# Matplotlib rc settings layered on top of the "ggplot" style.
# (The duplicate "font.size" entry was removed; its value was identical.)
MPL_RC = {
    "font.family": "serif",
    "text.usetex": False,
    "pgf.rcfonts": False,
    "lines.linewidth": 1,
    "font.size": 10,
    "axes.labelsize": 10,
    "axes.titlesize": 8,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
}
@contextmanager
def hiro_style():
    """Temporarily apply the "ggplot" style with ``MPL_RC`` on top.

    Both settings are restored when the ``with`` block exits. Yields ``True``.
    """
    with plt.style.context("ggplot"), matplotlib.rc_context(MPL_RC):
        yield True
# Import-time side effect: switch the global matplotlib configuration to the
# "ggplot" style, then override it with MPL_RC (the order matters).
plt.style.use("ggplot")
plt.rcParams.update(MPL_RC)