mirror of
https://github.com/vale981/hiroplotutils
synced 2025-03-04 09:01:40 -05:00
add parallel plot evaluation code
This commit is contained in:
parent
ec86900d0e
commit
92ff585873
1 changed files with 104 additions and 0 deletions
|
@ -8,6 +8,8 @@ import subprocess
|
|||
import pathlib
|
||||
import sys
|
||||
import numpy as np
|
||||
from joblib import Parallel, delayed
|
||||
import logging
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
@ -248,3 +250,105 @@ def scientific_round(val, *err, retprec=False):
|
|||
return (smart_round(val, prec), *smart_round(err, prec)[0], prec)
|
||||
else:
|
||||
return (smart_round(val, prec), *smart_round(err, prec)[0])
|
||||
|
||||
|
||||
class PlotContainer:
|
||||
"""A container for plots that can be executed in parallel."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
figdir: str | pathlib.Path,
|
||||
sizes: dict[str, float],
|
||||
default_ratio: float = 0.61803398876,
|
||||
pyplot_config: Callable | None = None,
|
||||
):
|
||||
self._plots = []
|
||||
self._figdir = figdir
|
||||
self._sizes = sizes
|
||||
self._default_ratio: float = default_ratio
|
||||
self._pyplot_config = pyplot_config or (lambda: None)
|
||||
|
||||
def _save_fig(self, fig, size: str | tuple[str, float], num: int):
|
||||
horizontal = self._sizes[size[0] if isinstance(size, tuple) else size]
|
||||
vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio
|
||||
|
||||
fig.set_size_inches(horizontal, vertical * horizontal)
|
||||
save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir)
|
||||
|
||||
def register(self, size: str | tuple[str, float], *args: dict):
|
||||
"""Registers a plot to be executed."""
|
||||
|
||||
if hasattr(sys, "ps1"):
|
||||
self._pyplot_config()
|
||||
return lambda f: f
|
||||
|
||||
if len(args) == 1 and isinstance(args[0], Callable):
|
||||
return self.register(next(iter(self._sizes.keys())))(args[0])
|
||||
|
||||
if len(args) == 0:
|
||||
args = ({},)
|
||||
|
||||
def decorator(f):
|
||||
for keywords in args:
|
||||
logging.info("Registered plot", f, keywords)
|
||||
self._plots.append(
|
||||
(
|
||||
lambda i: [
|
||||
self._pyplot_config(),
|
||||
self._save_fig(f(**keywords), size, i),
|
||||
],
|
||||
f,
|
||||
keywords,
|
||||
)
|
||||
)
|
||||
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
def execute_plots(self, *args, **kwargs):
|
||||
"""Executes the given list of plots in parallel."""
|
||||
|
||||
if hasattr(sys, "ps1"):
|
||||
return
|
||||
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--n_jobs", type=int, default=-1)
|
||||
parser.add_argument(
|
||||
"--list", help="List all available plots", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only",
|
||||
type=int,
|
||||
help="Only execute the plot with the given index.",
|
||||
default=None,
|
||||
choices=range(len(self._plots)),
|
||||
nargs="+",
|
||||
)
|
||||
cmd_args = parser.parse_args()
|
||||
|
||||
if cmd_args.list:
|
||||
format = "{:2d} {:<30} "
|
||||
for i, plot in enumerate(self._plots):
|
||||
print(format.format(i + 1, plot[1].__name__), end="")
|
||||
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
|
||||
|
||||
return
|
||||
|
||||
only = (
|
||||
range(len(self._plots))
|
||||
if cmd_args.only is None
|
||||
else [i - 1 for i in cmd_args.only]
|
||||
)
|
||||
|
||||
selected_plots = [(i + 1, self._plots[i]) for i in only]
|
||||
|
||||
if "n_jobs" not in kwargs or cmd_args.n_jobs != -1:
|
||||
kwargs["n_jobs"] = cmd_args.n_jobs
|
||||
|
||||
if "backend" not in kwargs:
|
||||
kwargs["backend"] = "loky"
|
||||
|
||||
Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots)
|
||||
|
|
Loading…
Add table
Reference in a new issue