mirror of
https://github.com/vale981/hiroplotutils
synced 2025-03-04 09:01:40 -05:00
fix plot numbering
This commit is contained in:
parent
5a6231e2e6
commit
14969bcef2
1 changed files with 35 additions and 17 deletions
|
@ -10,7 +10,7 @@ import sys
|
|||
import numpy as np
|
||||
from joblib import Parallel, delayed
|
||||
import logging
|
||||
|
||||
import itertools
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
@ -286,12 +286,16 @@ class PlotContainer:
|
|||
self._default_ratio: float = default_ratio
|
||||
self._pyplot_config = pyplot_config or (lambda: None)
|
||||
|
||||
def _save_fig(self, fig, size: str | tuple[str, float], num: int):
|
||||
def _save_fig(self, fig, size: str | tuple[str, float], num: tuple[int, int]):
|
||||
horizontal = self._sizes[size[0] if isinstance(size, tuple) else size]
|
||||
vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio
|
||||
|
||||
fig.set_size_inches(horizontal, vertical * horizontal)
|
||||
save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir)
|
||||
save_figure(
|
||||
fig,
|
||||
f"{num[0]:03}_{num[1]:03}_" + fig.get_label(),
|
||||
directory=str(self._figdir),
|
||||
)
|
||||
|
||||
def register(self, size: str | tuple[str, float], *args: dict):
|
||||
"""Registers a plot to be executed."""
|
||||
|
@ -307,19 +311,27 @@ class PlotContainer:
|
|||
args = ({},)
|
||||
|
||||
def decorator(f):
|
||||
for keywords in args:
|
||||
plots = []
|
||||
plot_index = len(self._plots) + 1
|
||||
for sub_index, keywords in enumerate(args):
|
||||
logging.info("Registered plot", f, keywords)
|
||||
self._plots.append(
|
||||
plots.append(
|
||||
(
|
||||
lambda i: [
|
||||
lambda: [
|
||||
self._pyplot_config(),
|
||||
self._save_fig(f(**keywords), size, i),
|
||||
],
|
||||
self._save_fig(
|
||||
f(**keywords),
|
||||
size,
|
||||
(len(self._plots) + 1, sub_index + 1),
|
||||
),
|
||||
]
|
||||
and None,
|
||||
f,
|
||||
keywords,
|
||||
)
|
||||
)
|
||||
|
||||
self._plots.append(plots)
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
@ -330,6 +342,7 @@ class PlotContainer:
|
|||
if hasattr(sys, "ps1"):
|
||||
return
|
||||
|
||||
flattened = list(itertools.chain.from_iterable(self._plots))
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -342,31 +355,36 @@ class PlotContainer:
|
|||
type=int,
|
||||
help="Only execute the plot with the given index.",
|
||||
default=None,
|
||||
choices=range(len(self._plots)),
|
||||
choices=[i + 1 for i in range(len(flattened))],
|
||||
nargs="+",
|
||||
)
|
||||
cmd_args = parser.parse_args()
|
||||
|
||||
if cmd_args.list:
|
||||
format = "{:2d} {:<30} "
|
||||
for i, plot in enumerate(self._plots):
|
||||
print(format.format(i + 1, plot[1].__name__), end="")
|
||||
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
|
||||
print("hi")
|
||||
format = "{:2d} [{:2d} {:2d}] {:<30} "
|
||||
total = 1
|
||||
for i, plot_group in enumerate(self._plots):
|
||||
for sub_index, plot in enumerate(plot_group):
|
||||
print(
|
||||
format.format(total, i + 1, sub_index + 1, plot[1].__name__),
|
||||
end="",
|
||||
)
|
||||
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
|
||||
total += 1
|
||||
|
||||
return
|
||||
|
||||
only = (
|
||||
range(len(self._plots))
|
||||
range(len(flattened))
|
||||
if cmd_args.only is None
|
||||
else [i - 1 for i in cmd_args.only]
|
||||
)
|
||||
|
||||
selected_plots = [(i + 1, self._plots[i]) for i in only]
|
||||
|
||||
if "n_jobs" not in kwargs or cmd_args.n_jobs != -1:
|
||||
kwargs["n_jobs"] = cmd_args.n_jobs
|
||||
|
||||
if "backend" not in kwargs:
|
||||
kwargs["backend"] = "loky"
|
||||
|
||||
Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots)
|
||||
Parallel(*args, **kwargs)(delayed(flattened[i][0])() for i in only)
|
||||
|
|
Loading…
Add table
Reference in a new issue