fix plot numbering

This commit is contained in:
Valentin Boettcher 2025-01-30 10:52:12 -05:00
parent 5a6231e2e6
commit 14969bcef2
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE

View file

@ -10,7 +10,7 @@ import sys
import numpy as np import numpy as np
from joblib import Parallel, delayed from joblib import Parallel, delayed
import logging import logging
import itertools
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
@ -286,12 +286,16 @@ class PlotContainer:
self._default_ratio: float = default_ratio self._default_ratio: float = default_ratio
self._pyplot_config = pyplot_config or (lambda: None) self._pyplot_config = pyplot_config or (lambda: None)
def _save_fig(self, fig, size: str | tuple[str, float], num: int): def _save_fig(self, fig, size: str | tuple[str, float], num: tuple[int, int]):
horizontal = self._sizes[size[0] if isinstance(size, tuple) else size] horizontal = self._sizes[size[0] if isinstance(size, tuple) else size]
vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio
fig.set_size_inches(horizontal, vertical * horizontal) fig.set_size_inches(horizontal, vertical * horizontal)
save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir) save_figure(
fig,
f"{num[0]:03}_{num[1]:03}_" + fig.get_label(),
directory=str(self._figdir),
)
def register(self, size: str | tuple[str, float], *args: dict): def register(self, size: str | tuple[str, float], *args: dict):
"""Registers a plot to be executed.""" """Registers a plot to be executed."""
@ -307,19 +311,27 @@ class PlotContainer:
args = ({},) args = ({},)
def decorator(f): def decorator(f):
for keywords in args: plots = []
plot_index = len(self._plots) + 1
for sub_index, keywords in enumerate(args):
logging.info("Registered plot", f, keywords) logging.info("Registered plot", f, keywords)
self._plots.append( plots.append(
( (
lambda i: [ lambda: [
self._pyplot_config(), self._pyplot_config(),
self._save_fig(f(**keywords), size, i), self._save_fig(
], f(**keywords),
size,
(len(self._plots) + 1, sub_index + 1),
),
]
and None,
f, f,
keywords, keywords,
) )
) )
self._plots.append(plots)
return f return f
return decorator return decorator
@ -330,6 +342,7 @@ class PlotContainer:
if hasattr(sys, "ps1"): if hasattr(sys, "ps1"):
return return
flattened = list(itertools.chain.from_iterable(self._plots))
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -342,31 +355,36 @@ class PlotContainer:
type=int, type=int,
help="Only execute the plot with the given index.", help="Only execute the plot with the given index.",
default=None, default=None,
choices=range(len(self._plots)), choices=[i + 1 for i in range(len(flattened))],
nargs="+", nargs="+",
) )
cmd_args = parser.parse_args() cmd_args = parser.parse_args()
if cmd_args.list: if cmd_args.list:
format = "{:2d} {:<30} " print("hi")
for i, plot in enumerate(self._plots): format = "{:2d} [{:2d} {:2d}] {:<30} "
print(format.format(i + 1, plot[1].__name__), end="") total = 1
print(" ".join(f"{k}={v}" for k, v in plot[2].items())) for i, plot_group in enumerate(self._plots):
for sub_index, plot in enumerate(plot_group):
print(
format.format(total, i + 1, sub_index + 1, plot[1].__name__),
end="",
)
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
total += 1
return return
only = ( only = (
range(len(self._plots)) range(len(flattened))
if cmd_args.only is None if cmd_args.only is None
else [i - 1 for i in cmd_args.only] else [i - 1 for i in cmd_args.only]
) )
selected_plots = [(i + 1, self._plots[i]) for i in only]
if "n_jobs" not in kwargs or cmd_args.n_jobs != -1: if "n_jobs" not in kwargs or cmd_args.n_jobs != -1:
kwargs["n_jobs"] = cmd_args.n_jobs kwargs["n_jobs"] = cmd_args.n_jobs
if "backend" not in kwargs: if "backend" not in kwargs:
kwargs["backend"] = "loky" kwargs["backend"] = "loky"
Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots) Parallel(*args, **kwargs)(delayed(flattened[i][0])() for i in only)