add parallel plot evaluation code

This commit is contained in:
Valentin Boettcher 2025-01-29 16:07:25 -05:00
parent ec86900d0e
commit 92ff585873

View file

@ -8,6 +8,8 @@ import subprocess
import pathlib
import sys
import numpy as np
from joblib import Parallel, delayed
import logging
P = ParamSpec("P")
@ -248,3 +250,105 @@ def scientific_round(val, *err, retprec=False):
return (smart_round(val, prec), *smart_round(err, prec)[0], prec)
else:
return (smart_round(val, prec), *smart_round(err, prec)[0])
class PlotContainer:
"""A container for plots that can be executed in parallel."""
def __init__(
self,
figdir: str | pathlib.Path,
sizes: dict[str, float],
default_ratio: float = 0.61803398876,
pyplot_config: Callable | None = None,
):
self._plots = []
self._figdir = figdir
self._sizes = sizes
self._default_ratio: float = default_ratio
self._pyplot_config = pyplot_config or (lambda: None)
def _save_fig(self, fig, size: str | tuple[str, float], num: int):
horizontal = self._sizes[size[0] if isinstance(size, tuple) else size]
vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio
fig.set_size_inches(horizontal, vertical * horizontal)
save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir)
def register(self, size: str | tuple[str, float], *args: dict):
"""Registers a plot to be executed."""
if hasattr(sys, "ps1"):
self._pyplot_config()
return lambda f: f
if len(args) == 1 and isinstance(args[0], Callable):
return self.register(next(iter(self._sizes.keys())))(args[0])
if len(args) == 0:
args = ({},)
def decorator(f):
for keywords in args:
logging.info("Registered plot", f, keywords)
self._plots.append(
(
lambda i: [
self._pyplot_config(),
self._save_fig(f(**keywords), size, i),
],
f,
keywords,
)
)
return f
return decorator
def execute_plots(self, *args, **kwargs):
"""Executes the given list of plots in parallel."""
if hasattr(sys, "ps1"):
return
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--n_jobs", type=int, default=-1)
parser.add_argument(
"--list", help="List all available plots", action="store_true"
)
parser.add_argument(
"--only",
type=int,
help="Only execute the plot with the given index.",
default=None,
choices=range(len(self._plots)),
nargs="+",
)
cmd_args = parser.parse_args()
if cmd_args.list:
format = "{:2d} {:<30} "
for i, plot in enumerate(self._plots):
print(format.format(i + 1, plot[1].__name__), end="")
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
return
only = (
range(len(self._plots))
if cmd_args.only is None
else [i - 1 for i in cmd_args.only]
)
selected_plots = [(i + 1, self._plots[i]) for i in only]
if "n_jobs" not in kwargs or cmd_args.n_jobs != -1:
kwargs["n_jobs"] = cmd_args.n_jobs
if "backend" not in kwargs:
kwargs["backend"] = "loky"
Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots)