fix plot numbering

This commit is contained in:
Valentin Boettcher 2025-01-30 10:52:12 -05:00
parent 5a6231e2e6
commit 14969bcef2
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE

View file

@ -10,7 +10,7 @@ import sys
import numpy as np
from joblib import Parallel, delayed
import logging
import itertools
P = ParamSpec("P")
R = TypeVar("R")
@ -286,12 +286,16 @@ class PlotContainer:
self._default_ratio: float = default_ratio
self._pyplot_config = pyplot_config or (lambda: None)
def _save_fig(self, fig, size: str | tuple[str, float], num: int):
def _save_fig(self, fig, size: str | tuple[str, float], num: tuple[int, int]):
horizontal = self._sizes[size[0] if isinstance(size, tuple) else size]
vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio
fig.set_size_inches(horizontal, vertical * horizontal)
save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir)
save_figure(
fig,
f"{num[0]:03}_{num[1]:03}_" + fig.get_label(),
directory=str(self._figdir),
)
def register(self, size: str | tuple[str, float], *args: dict):
"""Registers a plot to be executed."""
@ -307,19 +311,27 @@ class PlotContainer:
args = ({},)
def decorator(f):
for keywords in args:
plots = []
plot_index = len(self._plots) + 1
for sub_index, keywords in enumerate(args):
logging.info("Registered plot", f, keywords)
self._plots.append(
plots.append(
(
lambda i: [
lambda: [
self._pyplot_config(),
self._save_fig(f(**keywords), size, i),
],
self._save_fig(
f(**keywords),
size,
(len(self._plots) + 1, sub_index + 1),
),
]
and None,
f,
keywords,
)
)
self._plots.append(plots)
return f
return decorator
@ -330,6 +342,7 @@ class PlotContainer:
if hasattr(sys, "ps1"):
return
flattened = list(itertools.chain.from_iterable(self._plots))
import argparse
parser = argparse.ArgumentParser()
@ -342,31 +355,36 @@ class PlotContainer:
type=int,
help="Only execute the plot with the given index.",
default=None,
choices=range(len(self._plots)),
choices=[i + 1 for i in range(len(flattened))],
nargs="+",
)
cmd_args = parser.parse_args()
if cmd_args.list:
format = "{:2d} {:<30} "
for i, plot in enumerate(self._plots):
print(format.format(i + 1, plot[1].__name__), end="")
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
print("hi")
format = "{:2d} [{:2d} {:2d}] {:<30} "
total = 1
for i, plot_group in enumerate(self._plots):
for sub_index, plot in enumerate(plot_group):
print(
format.format(total, i + 1, sub_index + 1, plot[1].__name__),
end="",
)
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
total += 1
return
only = (
range(len(self._plots))
range(len(flattened))
if cmd_args.only is None
else [i - 1 for i in cmd_args.only]
)
selected_plots = [(i + 1, self._plots[i]) for i in only]
if "n_jobs" not in kwargs or cmd_args.n_jobs != -1:
kwargs["n_jobs"] = cmd_args.n_jobs
if "backend" not in kwargs:
kwargs["backend"] = "loky"
Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots)
Parallel(*args, **kwargs)(delayed(flattened[i][0])() for i in only)