From 14969bcef2e621785ab75733354faabc14f5b4de Mon Sep 17 00:00:00 2001 From: Valentin Boettcher Date: Thu, 30 Jan 2025 10:52:12 -0500 Subject: [PATCH] fix plot numbering --- hiroplotutils/__init__.py | 52 ++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/hiroplotutils/__init__.py b/hiroplotutils/__init__.py index 372ffaf..b0805c8 100644 --- a/hiroplotutils/__init__.py +++ b/hiroplotutils/__init__.py @@ -10,7 +10,7 @@ import sys import numpy as np from joblib import Parallel, delayed import logging - +import itertools P = ParamSpec("P") R = TypeVar("R") @@ -286,12 +286,16 @@ class PlotContainer: self._default_ratio: float = default_ratio self._pyplot_config = pyplot_config or (lambda: None) - def _save_fig(self, fig, size: str | tuple[str, float], num: int): + def _save_fig(self, fig, size: str | tuple[str, float], num: tuple[int, int]): horizontal = self._sizes[size[0] if isinstance(size, tuple) else size] vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio fig.set_size_inches(horizontal, vertical * horizontal) - save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir) + save_figure( + fig, + f"{num[0]:03}_{num[1]:03}_" + fig.get_label(), + directory=str(self._figdir), + ) def register(self, size: str | tuple[str, float], *args: dict): """Registers a plot to be executed.""" @@ -307,19 +311,27 @@ class PlotContainer: args = ({},) def decorator(f): - for keywords in args: + plots = [] + plot_index = len(self._plots) + 1 + for sub_index, keywords in enumerate(args): logging.info("Registered plot", f, keywords) - self._plots.append( + plots.append( ( - lambda i: [ + lambda: [ self._pyplot_config(), - self._save_fig(f(**keywords), size, i), - ], + self._save_fig( + f(**keywords), + size, + (len(self._plots) + 1, sub_index + 1), + ), + ] + and None, f, keywords, ) ) + self._plots.append(plots) return f return decorator @@ -330,6 +342,7 @@ class PlotContainer: if hasattr(sys, "ps1"): return + flattened = list(itertools.chain.from_iterable(self._plots)) import argparse parser = argparse.ArgumentParser() @@ -342,31 +355,36 @@ class PlotContainer: type=int, help="Only execute the plot with the given index.", default=None, - choices=range(len(self._plots)), + choices=[i + 1 for i in range(len(flattened))], nargs="+", ) cmd_args = parser.parse_args() if cmd_args.list: - format = "{:2d} {:<30} " - for i, plot in enumerate(self._plots): - print(format.format(i + 1, plot[1].__name__), end="") - print(" ".join(f"{k}={v}" for k, v in plot[2].items())) + print("hi") + format = "{:2d} [{:2d} {:2d}] {:<30} " + total = 1 + for i, plot_group in enumerate(self._plots): + for sub_index, plot in enumerate(plot_group): + print( + format.format(total, i + 1, sub_index + 1, plot[1].__name__), + end="", + ) + print(" ".join(f"{k}={v}" for k, v in plot[2].items())) + total += 1 return only = ( - range(len(self._plots)) + range(len(flattened)) if cmd_args.only is None else [i - 1 for i in cmd_args.only] ) - selected_plots = [(i + 1, self._plots[i]) for i in only] - if "n_jobs" not in kwargs or cmd_args.n_jobs != -1: kwargs["n_jobs"] = cmd_args.n_jobs if "backend" not in kwargs: kwargs["backend"] = "loky" - Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots) + Parallel(*args, **kwargs)(delayed(flattened[i][0])() for i in only)