mirror of
https://github.com/vale981/hiroplotutils
synced 2025-03-04 17:11:38 -05:00
add parallel plot evaluation code
This commit is contained in:
parent
ec86900d0e
commit
92ff585873
1 changed files with 104 additions and 0 deletions
|
@ -8,6 +8,8 @@ import subprocess
|
||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from joblib import Parallel, delayed
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
|
@ -248,3 +250,105 @@ def scientific_round(val, *err, retprec=False):
|
||||||
return (smart_round(val, prec), *smart_round(err, prec)[0], prec)
|
return (smart_round(val, prec), *smart_round(err, prec)[0], prec)
|
||||||
else:
|
else:
|
||||||
return (smart_round(val, prec), *smart_round(err, prec)[0])
|
return (smart_round(val, prec), *smart_round(err, prec)[0])
|
||||||
|
|
||||||
|
|
||||||
|
class PlotContainer:
|
||||||
|
"""A container for plots that can be executed in parallel."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
figdir: str | pathlib.Path,
|
||||||
|
sizes: dict[str, float],
|
||||||
|
default_ratio: float = 0.61803398876,
|
||||||
|
pyplot_config: Callable | None = None,
|
||||||
|
):
|
||||||
|
self._plots = []
|
||||||
|
self._figdir = figdir
|
||||||
|
self._sizes = sizes
|
||||||
|
self._default_ratio: float = default_ratio
|
||||||
|
self._pyplot_config = pyplot_config or (lambda: None)
|
||||||
|
|
||||||
|
def _save_fig(self, fig, size: str | tuple[str, float], num: int):
|
||||||
|
horizontal = self._sizes[size[0] if isinstance(size, tuple) else size]
|
||||||
|
vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio
|
||||||
|
|
||||||
|
fig.set_size_inches(horizontal, vertical * horizontal)
|
||||||
|
save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir)
|
||||||
|
|
||||||
|
def register(self, size: str | tuple[str, float], *args: dict):
|
||||||
|
"""Registers a plot to be executed."""
|
||||||
|
|
||||||
|
if hasattr(sys, "ps1"):
|
||||||
|
self._pyplot_config()
|
||||||
|
return lambda f: f
|
||||||
|
|
||||||
|
if len(args) == 1 and isinstance(args[0], Callable):
|
||||||
|
return self.register(next(iter(self._sizes.keys())))(args[0])
|
||||||
|
|
||||||
|
if len(args) == 0:
|
||||||
|
args = ({},)
|
||||||
|
|
||||||
|
def decorator(f):
|
||||||
|
for keywords in args:
|
||||||
|
logging.info("Registered plot", f, keywords)
|
||||||
|
self._plots.append(
|
||||||
|
(
|
||||||
|
lambda i: [
|
||||||
|
self._pyplot_config(),
|
||||||
|
self._save_fig(f(**keywords), size, i),
|
||||||
|
],
|
||||||
|
f,
|
||||||
|
keywords,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def execute_plots(self, *args, **kwargs):
|
||||||
|
"""Executes the given list of plots in parallel."""
|
||||||
|
|
||||||
|
if hasattr(sys, "ps1"):
|
||||||
|
return
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--n_jobs", type=int, default=-1)
|
||||||
|
parser.add_argument(
|
||||||
|
"--list", help="List all available plots", action="store_true"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--only",
|
||||||
|
type=int,
|
||||||
|
help="Only execute the plot with the given index.",
|
||||||
|
default=None,
|
||||||
|
choices=range(len(self._plots)),
|
||||||
|
nargs="+",
|
||||||
|
)
|
||||||
|
cmd_args = parser.parse_args()
|
||||||
|
|
||||||
|
if cmd_args.list:
|
||||||
|
format = "{:2d} {:<30} "
|
||||||
|
for i, plot in enumerate(self._plots):
|
||||||
|
print(format.format(i + 1, plot[1].__name__), end="")
|
||||||
|
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
only = (
|
||||||
|
range(len(self._plots))
|
||||||
|
if cmd_args.only is None
|
||||||
|
else [i - 1 for i in cmd_args.only]
|
||||||
|
)
|
||||||
|
|
||||||
|
selected_plots = [(i + 1, self._plots[i]) for i in only]
|
||||||
|
|
||||||
|
if "n_jobs" not in kwargs or cmd_args.n_jobs != -1:
|
||||||
|
kwargs["n_jobs"] = cmd_args.n_jobs
|
||||||
|
|
||||||
|
if "backend" not in kwargs:
|
||||||
|
kwargs["backend"] = "loky"
|
||||||
|
|
||||||
|
Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots)
|
||||||
|
|
Loading…
Add table
Reference in a new issue