mirror of
https://github.com/vale981/hiroplotutils
synced 2025-03-04 17:11:38 -05:00
fix plot numbering
This commit is contained in:
parent
5a6231e2e6
commit
14969bcef2
1 changed files with 35 additions and 17 deletions
|
@ -10,7 +10,7 @@ import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from joblib import Parallel, delayed
|
from joblib import Parallel, delayed
|
||||||
import logging
|
import logging
|
||||||
|
import itertools
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
@ -286,12 +286,16 @@ class PlotContainer:
|
||||||
self._default_ratio: float = default_ratio
|
self._default_ratio: float = default_ratio
|
||||||
self._pyplot_config = pyplot_config or (lambda: None)
|
self._pyplot_config = pyplot_config or (lambda: None)
|
||||||
|
|
||||||
def _save_fig(self, fig, size: str | tuple[str, float], num: int):
|
def _save_fig(self, fig, size: str | tuple[str, float], num: tuple[int, int]):
|
||||||
horizontal = self._sizes[size[0] if isinstance(size, tuple) else size]
|
horizontal = self._sizes[size[0] if isinstance(size, tuple) else size]
|
||||||
vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio
|
vertical: float = size[1] if isinstance(size, tuple) else self._default_ratio
|
||||||
|
|
||||||
fig.set_size_inches(horizontal, vertical * horizontal)
|
fig.set_size_inches(horizontal, vertical * horizontal)
|
||||||
save_figure(fig, f"{num:03}_" + fig.get_label(), directory=self._figdir)
|
save_figure(
|
||||||
|
fig,
|
||||||
|
f"{num[0]:03}_{num[1]:03}_" + fig.get_label(),
|
||||||
|
directory=str(self._figdir),
|
||||||
|
)
|
||||||
|
|
||||||
def register(self, size: str | tuple[str, float], *args: dict):
|
def register(self, size: str | tuple[str, float], *args: dict):
|
||||||
"""Registers a plot to be executed."""
|
"""Registers a plot to be executed."""
|
||||||
|
@ -307,19 +311,27 @@ class PlotContainer:
|
||||||
args = ({},)
|
args = ({},)
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f):
|
||||||
for keywords in args:
|
plots = []
|
||||||
|
plot_index = len(self._plots) + 1
|
||||||
|
for sub_index, keywords in enumerate(args):
|
||||||
logging.info("Registered plot", f, keywords)
|
logging.info("Registered plot", f, keywords)
|
||||||
self._plots.append(
|
plots.append(
|
||||||
(
|
(
|
||||||
lambda i: [
|
lambda: [
|
||||||
self._pyplot_config(),
|
self._pyplot_config(),
|
||||||
self._save_fig(f(**keywords), size, i),
|
self._save_fig(
|
||||||
],
|
f(**keywords),
|
||||||
|
size,
|
||||||
|
(len(self._plots) + 1, sub_index + 1),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
and None,
|
||||||
f,
|
f,
|
||||||
keywords,
|
keywords,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._plots.append(plots)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
@ -330,6 +342,7 @@ class PlotContainer:
|
||||||
if hasattr(sys, "ps1"):
|
if hasattr(sys, "ps1"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
flattened = list(itertools.chain.from_iterable(self._plots))
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -342,31 +355,36 @@ class PlotContainer:
|
||||||
type=int,
|
type=int,
|
||||||
help="Only execute the plot with the given index.",
|
help="Only execute the plot with the given index.",
|
||||||
default=None,
|
default=None,
|
||||||
choices=range(len(self._plots)),
|
choices=[i + 1 for i in range(len(flattened))],
|
||||||
nargs="+",
|
nargs="+",
|
||||||
)
|
)
|
||||||
cmd_args = parser.parse_args()
|
cmd_args = parser.parse_args()
|
||||||
|
|
||||||
if cmd_args.list:
|
if cmd_args.list:
|
||||||
format = "{:2d} {:<30} "
|
print("hi")
|
||||||
for i, plot in enumerate(self._plots):
|
format = "{:2d} [{:2d} {:2d}] {:<30} "
|
||||||
print(format.format(i + 1, plot[1].__name__), end="")
|
total = 1
|
||||||
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
|
for i, plot_group in enumerate(self._plots):
|
||||||
|
for sub_index, plot in enumerate(plot_group):
|
||||||
|
print(
|
||||||
|
format.format(total, i + 1, sub_index + 1, plot[1].__name__),
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
print(" ".join(f"{k}={v}" for k, v in plot[2].items()))
|
||||||
|
total += 1
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
only = (
|
only = (
|
||||||
range(len(self._plots))
|
range(len(flattened))
|
||||||
if cmd_args.only is None
|
if cmd_args.only is None
|
||||||
else [i - 1 for i in cmd_args.only]
|
else [i - 1 for i in cmd_args.only]
|
||||||
)
|
)
|
||||||
|
|
||||||
selected_plots = [(i + 1, self._plots[i]) for i in only]
|
|
||||||
|
|
||||||
if "n_jobs" not in kwargs or cmd_args.n_jobs != -1:
|
if "n_jobs" not in kwargs or cmd_args.n_jobs != -1:
|
||||||
kwargs["n_jobs"] = cmd_args.n_jobs
|
kwargs["n_jobs"] = cmd_args.n_jobs
|
||||||
|
|
||||||
if "backend" not in kwargs:
|
if "backend" not in kwargs:
|
||||||
kwargs["backend"] = "loky"
|
kwargs["backend"] = "loky"
|
||||||
|
|
||||||
Parallel(*args, **kwargs)(delayed(plot[0])(i) for i, plot in selected_plots)
|
Parallel(*args, **kwargs)(delayed(flattened[i][0])() for i in only)
|
||||||
|
|
Loading…
Add table
Reference in a new issue