diff --git a/hiroplotutils/__init__.py b/hiroplotutils/__init__.py index 4eee61e..94aa999 100644 --- a/hiroplotutils/__init__.py +++ b/hiroplotutils/__init__.py @@ -8,6 +8,8 @@ import subprocess import pathlib import sys import numpy as np +from joblib import Parallel, delayed +import logging P = ParamSpec("P") @@ -248,3 +250,105 @@ def scientific_round(val, *err, retprec=False): return (smart_round(val, prec), *smart_round(err, prec)[0], prec) else: return (smart_round(val, prec), *smart_round(err, prec)[0]) + + +class PlotContainer: + """A container for plots that can be executed in parallel.""" + + def __init__( + self, + figdir: str | pathlib.Path, + sizes: dict[str, float], + default_ratio: float = 0.61803398876, + pyplot_config: Callable | None = None, + ): + self._plots = [] + self._figdir = figdir + self._sizes = sizes + self._default_ratio: float = default_ratio + self._pyplot_config = pyplot_config or (lambda: None) + + def _save_fig(self, fig, size: str | tuple[str, float], num: int): + horizontal = self._sizes[size[0] if isinstance(size, tuple) else size] + vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio + + fig.set_size_inches(horizontal, vertical * horizontal) + save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir) + + def register(self, size: str | tuple[str, float], *args: dict): + """Registers a plot to be executed.""" + + if hasattr(sys, "ps1"): + self._pyplot_config() + return lambda f: f + + if len(args) == 1 and isinstance(args[0], Callable): + return self.register(next(iter(self._sizes.keys())))(args[0]) + + if len(args) == 0: + args = ({},) + + def decorator(f): + for keywords in args: + logging.info("Registered plot", f, keywords) + self._plots.append( + ( + lambda i: [ + self._pyplot_config(), + self._save_fig(f(**keywords), size, i), + ], + f, + keywords, + ) + ) + + return f + + return decorator + + def execute_plots(self, *args, **kwargs): + """Executes the given list of plots in parallel.""" + + if hasattr(sys, "ps1"): + return + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--n_jobs", type=int, default=-1) + parser.add_argument( + "--list", help="List all available plots", action="store_true" + ) + parser.add_argument( + "--only", + type=int, + help="Only execute the plot with the given index.", + default=None, + choices=range(len(self._plots)), + nargs="+", + ) + cmd_args = parser.parse_args() + + if cmd_args.list: + format = "{:2d} {:<30} " + for i, plot in enumerate(self._plots): + print(format.format(i + 1, plot[1].__name__), end="") + print(" ".join(f"{k}={v}" for k, v in plot[2].items())) + + return + + only = ( + range(len(self._plots)) + if cmd_args.only is None + else [i - 1 for i in cmd_args.only] + ) + + selected_plots = [(i + 1, self._plots[i]) for i in only] + + if "n_jobs" not in kwargs or cmd_args.n_jobs != -1: + kwargs["n_jobs"] = cmd_args.n_jobs + + if "backend" not in kwargs: + kwargs["backend"] = "loky" + + Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots)