refactor Plotter to use a common plot base

This commit is contained in:
2026-03-14 22:56:58 -07:00
parent 30fb6fa107
commit 69e8e88047
3 changed files with 477 additions and 198 deletions

View File

@@ -1,19 +1,134 @@
from typing import Self
"""
def plot_actual_output_residual(
self,
order_residuals: bool = False,
row_size: int | float = 2,
col_size: int | float = 4,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
subplot_kwargs: dict | None = None,
gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
Note: transform samples in dataloader definitions beforehand if you
want to change data
Parameters:
row_size:
col_size:
figure_kwargs:
subplot_kwargs:
gof_kwargs:
ndims = self.data_tuples[0][0].size(-1)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
fig, axes = self._create_subplots(
rows=len(self.dataloaders),
cols=1 if (norm_samples or combine_dims) else ndims,
row_size=row_size,
col_size=col_size,
figure_kwargs=figure_kwargs,
)
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, loader_label = data_tuple
if norm_samples:
actual = actual.norm(dim=1, keepdim=True)
output = output.norm(dim=1, keepdim=True)
for dim in range(ndims):
ax = axes_row[0 if combine_dims else dim]
dim_color = colors[dim % len(colors)]
group_name = "norm" if norm_samples else f"$d_{dim}$"
dim_actual = actual[:, dim]
dim_output = output[:, dim]
residuals = dim_actual - dim_output
X, Y = dim_actual, residuals
if order_residuals:
X = range(1, residuals.size(0)+1)
Y = residuals[residuals.argsort()]
ax.scatter(
X, Y,
color=dim_color,
label=group_name,
**subplot_kwargs,
)
# plot goodness of fit line
m, b = self._lstsq_dim(dim_actual, dim_output)
ax.plot(
dim_actual,
m * dim_actual + b,
color=dim_color,
label=f"GoF {group_name}",
**gof_kwargs,
)
add_plot_context = (
norm_samples # dim=0 implicit b/c we break
or not combine_dims # add to every subplot across grid
or combine_dims and dim == ndims-1 # wait for last dim
)
# always exec plot logic if not combining, o/w just once
if add_plot_context:
# compare residuals to y=0
ax.axhline(y=0, c="black", alpha=0.2)
ax.set_title(f"[{loader_label}] Prediction residuals")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
# transpose legend layout for more natural view
if norm_samples or not combine_dims:
ax.legend()
else:
handles, labels = self.get_transposed_handles_labels(ax)
ax.legend(handles, labels, ncols=2)
# break dimension loop if collapsed by norm
if norm_samples:
break
return fig, axes
"""
from functools import partial
from collections.abc import Callable, Generator
from collections.abc import Callable
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import Tensor
from numpy.typing import NDArray
from torch.utils.data import DataLoader
from trainlib.trainer import Trainer
from trainlib.estimator import EstimatorKwargs
from trainlib.utils.type import AxesArray, SubplotsKwargs
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
type ContextFn = Callable[[plt.Axes, str], None]
class Plotter[B, K: EstimatorKwargs]:
"""
TODOs:
- best fit lines for plots and residuals (compare to ideal lines in each
case)
- show val options across columns; preview how val is changing across
natural training, and what the best will look (so plot like uniform
intervals broken over the training epochs at 0, 50, 100, 150, ... and
highlight the best one, even if that's not actually the single best
epoch)
"""
def __init__(
self,
trainer: Trainer[..., K],
@@ -35,11 +150,11 @@ class Plotter[B, K: EstimatorKwargs]:
self._batch_outputs_fn = partial(
self.trainer.get_batch_outputs,
batch_estimator_map=self.batch_estimator_map
batch_estimator_map=batch_estimator_map
)
self._batch_metrics_fn = partial(
self.trainer.get_batch_metrics,
batch_estimator_map=self.batch_estimator_map
batch_estimator_map=batch_estimator_map
)
self._data_tuples = None
@@ -83,201 +198,369 @@ class Plotter[B, K: EstimatorKwargs]:
data_tuples.append((actual, output, label))
self._data_tuples = data_tuples
return self._data_tuples
def _default_figure_kwargs(self, rows: int, cols: int) -> dict:
return {
"sharex": True,
"figsize": (4*cols, 2*rows),
"constrained_layout": True,
}
def _default_subplot_kwargs(self) -> dict:
return {}
def _create_subplots(
def get_transposed_handles_labels(
self,
**figure_kwargs: SubplotsKwargs,
) -> tuple[plt.Figure, AxesArray]:
"""
"""
ax: plt.Axes
) -> tuple[list, list]:
# transpose legend layout for more natural view
handles, labels = ax.get_legend_handles_labels()
handles = handles[::2] + handles[1::2]
labels = labels[::2] + labels[1::2]
rows, cols = len(self.dataloaders), 1
figure_kwargs = {
**self._default_figure_kwargs(rows, cols),
**figure_kwargs,
}
fig, axes = plt.subplots(
rows, cols,
squeeze=False,
**figure_kwargs
return handles, labels
def _lstsq_dim(
self,
dim_actual: Tensor,
dim_output: Tensor
) -> tuple[Tensor, Tensor]:
A = torch.stack(
[dim_actual, torch.ones_like(dim_actual)],
dim=1
)
m, b = torch.linalg.lstsq(A, dim_output).solution
return fig, axes
return m, b
def plot_actual_output_dim(
def _prepare_figure_kwargs(
self,
figure_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
rows: int,
cols: int,
row_size: int | float = 2,
col_size: int | float = 4,
figure_kwargs: SubplotsKwargs | None = None,
) -> SubplotsKwargs:
"""
Wrapper like this works fine, but it's smelly: we *don't* want @wraps,
so this method doesn't actually have this signature at runtime (it has
the dec wrapper's sig). I think the cleaner thing is to just have
internal methods (_func) like the one below, and then the main method
entry just passes that internal method through to the skeleton
"""
figure_kwargs = figure_kwargs or {}
default_figure_kwargs = {
"sharex": True,
"figsize": (col_size*cols, row_size*rows),
}
figure_kwargs = {
**default_figure_kwargs,
**(figure_kwargs or {}),
}
return figure_kwargs
def _prepare_subplot_kwargs(
self,
subplot_kwargs: dict | None = None,
) -> dict:
"""
"""
default_subplot_kwargs = {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**default_subplot_kwargs,
**(subplot_kwargs or {}),
}
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
return subplot_kwargs
def _prepare_gof_kwargs(
self,
gof_kwargs: dict | None = None,
) -> dict:
"""
"""
default_gof_kwargs = {
"alpha": 0.5
}
gof_kwargs = {
**default_gof_kwargs,
**(gof_kwargs or {}),
}
return gof_kwargs
def _create_subplots(
    self,
    rows: int,
    cols: int,
    row_size: int | float = 2,
    col_size: int | float = 4,
    figure_kwargs: SubplotsKwargs | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Create a ``rows x cols`` subplot grid sized per-axes.

    Parameters:
        rows: number of subplot rows.
        cols: number of subplot columns.
        row_size: height allotted to each row when computing ``figsize``.
        col_size: width allotted to each column when computing ``figsize``.
        figure_kwargs: overrides merged over the prepared defaults by
            ``_prepare_figure_kwargs``.

    Returns:
        The figure and its axes array; ``squeeze=False`` guarantees a
        2-D array even for a single row or column.
    """
    figure_kwargs: SubplotsKwargs = self._prepare_figure_kwargs(
        rows,
        cols,
        row_size=row_size,
        col_size=col_size,
        figure_kwargs=figure_kwargs,
    )
    fig, axes = plt.subplots(
        rows,
        cols,
        squeeze=False,
        **figure_kwargs,
    ) # ty:ignore[no-matching-overload]
    return fig, axes
def _plot_base(
    self,
    subplot_fn: SubplotFn,
    context_fn: ContextFn,
    row_size: int | float = 2,
    col_size: int | float = 4,
    norm_samples: bool = False,
    combine_dims: bool = True,
    figure_kwargs: SubplotsKwargs | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Shared plotting skeleton: one subplot row per dataloader, with
    per-dimension drawing delegated to ``subplot_fn`` and one-time axis
    decoration (titles, labels, reference lines) delegated to
    ``context_fn``.

    Note: transform samples in dataloader definitions beforehand if you
    want to change data.

    Parameters:
        subplot_fn: called as ``subplot_fn(ax, dim, actual, output)``
            for each plotted dimension.
        context_fn: called as ``context_fn(ax, loader_label)`` exactly
            once per decorated axes.
        row_size: per-row sizing forwarded to ``_create_subplots``.
        col_size: per-column sizing forwarded to ``_create_subplots``.
        norm_samples: collapse all dims to their vector norm (dim=1) so
            only a single series is plotted per dataloader.
        combine_dims: plot every dimension on one shared axes instead of
            one axes per dimension.
        figure_kwargs: extra kwargs merged into subplot creation.
    """
    # feature dimensionality taken from the first loader's actuals
    ndims = self.data_tuples[0][0].size(-1)
    fig, axes = self._create_subplots(
        rows=len(self.dataloaders),
        cols=1 if (norm_samples or combine_dims) else ndims,
        row_size=row_size,
        col_size=col_size,
        figure_kwargs=figure_kwargs,
    )
    for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
        actual, output, loader_label = data_tuple
        if norm_samples:
            # collapse the sample dimension axis to a single norm column
            actual = actual.norm(dim=1, keepdim=True)
            output = output.norm(dim=1, keepdim=True)
        for dim in range(ndims):
            # combined mode reuses column 0 for every dimension
            ax = axes_row[0 if combine_dims else dim]
            subplot_fn(ax, dim, actual, output)
            add_plot_context = (
                norm_samples # dim=0 implicit b/c we break
                or not combine_dims # add to every subplot across grid
                or combine_dims and dim == ndims-1 # wait for last dim
            )
            # always exec plot logic if not combining, o/w exec just once
            if add_plot_context:
                context_fn(ax, loader_label)
                # transpose legend layout for more natural view
                if norm_samples or not combine_dims:
                    ax.legend()
                else:
                    handles, labels = self.get_transposed_handles_labels(ax)
                    ax.legend(handles, labels, ncols=2)
            # break dimension loop if collapsed by norm
            if norm_samples:
                break
    return fig, axes
def plot_actual_output(
self,
row_size: int | float = 2,
col_size: int | float = 4,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
subplot_kwargs: dict | None = None,
gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
Plot residual distribution.
"""
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
def subplot_fn(
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
) -> None:
dim_color = colors[dim % len(colors)]
group_name = "norm" if norm_samples else f"$d_{dim}$"
dim_actual = actual[:, dim]
dim_output = output[:, dim]
ax.scatter(
dim_actual, dim_output,
color=dim_color,
label=group_name,
**subplot_kwargs,
)
# plot goodness of fit line
m, b = self._lstsq_dim(dim_actual, dim_output)
ax.plot(
dim_actual, m * dim_actual + b,
color=dim_color,
label=f"GoF {group_name}",
**gof_kwargs,
)
def context_fn(ax: plt.Axes, loader_label: str) -> None:
# plot perfect prediction reference line, y=x
ax.plot(
[0, 1], [0, 1],
transform=ax.transAxes,
c="black",
alpha=0.2
alpha=0.2,
)
for dim in range(actual.size(-1)):
ax.scatter(
actual[:, dim],
output[:, dim],
label=f"$d_{dim}$",
**subplot_kwargs
)
ax.set_title(f"[{label}] True labels vs Predictions (dim-wise)")
ax.set_title(
f"[{loader_label}] True labels vs Predictions"
)
ax.set_xlabel("actual")
ax.set_ylabel("output")
ax.legend()
return fig, axes
return self._plot_base(
subplot_fn, context_fn,
row_size=row_size,
col_size=col_size,
norm_samples=norm_samples,
combine_dims=combine_dims,
figure_kwargs=figure_kwargs,
)
def plot_actual_output_residual_dim(
def plot_actual_output_residual(
self,
figure_kwargs: dict | None = None,
row_size: int | float = 2,
col_size: int | float = 4,
order_residuals: bool = False,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
subplot_kwargs: dict | None = None,
gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
Plot prediction residuals.
"""
figure_kwargs = figure_kwargs or {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**(subplot_kwargs or {}),
}
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
def subplot_fn(
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
) -> None:
dim_color = colors[dim % len(colors)]
group_name = "norm" if norm_samples else f"$d_{dim}$"
dim_actual = actual[:, dim]
dim_output = output[:, dim]
residuals = dim_actual - dim_output
X, Y = dim_actual, residuals
if order_residuals:
X = range(1, residuals.size(0)+1)
Y = residuals[residuals.argsort()]
ax.scatter(
X, Y,
color=dim_color,
label=group_name,
**subplot_kwargs,
)
# plot goodness of fit line
if not order_residuals:
m, b = self._lstsq_dim(dim_actual, residuals)
ax.plot(
dim_actual, m * dim_actual + b,
color=dim_color,
label=f"GoF {group_name}",
**gof_kwargs,
)
def context_fn(ax: plt.Axes, loader_label: str) -> None:
# compare residuals to y=0
ax.axhline(y=0, c="black", alpha=0.2)
for dim in range(actual.size(-1)):
ax.scatter(
actual[:, dim],
actual[:, dim] - output[:, dim],
label=f"$d_{dim}$",
**subplot_kwargs
)
ax.set_title(f"[{label}] Residuals (dim-wise)")
ax.set_title(f"[{loader_label}] Prediction residuals")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
ax.legend()
return fig, axes
def plot_actual_output_ordered_residual_dim(
self,
figure_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
"""
figure_kwargs = figure_kwargs or {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**(subplot_kwargs or {}),
}
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
# compare residuals to y=0
ax.axhline(y=0, c="black", alpha=0.2)
for dim in range(actual.size(-1)):
ax.scatter(
actual[:, dim],
actual[:, dim] - output[:, dim],
label=f"$d_{dim}$",
**subplot_kwargs
)
ax.set_title(f"[{label}] Residuals (dim-wise)")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
ax.legend()
return fig, axes
return self._plot_base(
subplot_fn, context_fn,
row_size=row_size,
col_size=col_size,
norm_samples=norm_samples,
combine_dims=combine_dims,
figure_kwargs=figure_kwargs,
)
def plot_actual_output_residual_dist(
self,
figure_kwargs: dict | None = None,
row_size: int | float = 2,
col_size: int | float = 4,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
subplot_kwargs: dict | None = None,
gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
Plot residual distribution.
"""
figure_kwargs = figure_kwargs or {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**(subplot_kwargs or {}),
}
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
def subplot_fn(
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
) -> None:
dim_color = colors[dim % len(colors)]
group_name = "norm" if norm_samples else f"$d_{dim}$"
N = actual.size(0)
for dim in range(actual.size(-1)):
residuals = actual[:, dim] - output[:, dim]
dim_actual = actual[:, dim]
dim_output = output[:, dim]
_, _, patches = ax.hist(
residuals.abs(),
bins=int(np.sqrt(N)),
density=True,
alpha=0.2,
label=f"$d_{dim}$",
**subplot_kwargs
)
N = dim_actual.size(0)
residuals = dim_actual - dim_output
# grab color used for hist and mirror in the v-line
color = patches[0].get_facecolor()
_, _, patches = ax.hist(
residuals.abs(),
bins=int(np.sqrt(N)),
density=True,
alpha=0.3,
color=dim_color,
label=group_name,
**subplot_kwargs
)
mu = residuals.abs().mean().item()
ax.axvline(mu, linestyle=":", c=color, label=f"$\mu_{dim}$")
mu = residuals.abs().mean().item()
ax.axvline(mu, linestyle=":", c=dim_color, label=f"$\\mu_{dim}$")
ax.set_title(f"[{label}] Residual distribution (dim-wise)")
def context_fn(ax: plt.Axes, loader_label: str) -> None:
ax.set_title(f"[{loader_label}] Residual distribution")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
ax.set_ylabel("residual (density)")
# transpose legend layout for more natural view
handles, labels = ax.get_legend_handles_labels()
handles = handles[::2] + handles[1::2]
labels = labels[::2] + labels[1::2]
ax.legend(handles, labels, ncols=2)
return fig, axes
return self._plot_base(
subplot_fn, context_fn,
row_size=row_size,
col_size=col_size,
norm_samples=norm_samples,
combine_dims=combine_dims,
figure_kwargs=figure_kwargs,
)

View File

@@ -1,46 +1,49 @@
text.usetex : False
mathtext.default : regular
text.usetex : False
mathtext.default : regular
font.family : sans-serif
font.sans-serif : DejaVu Sans
font.serif : DejaVu Serif
font.cursive : DejaVu Sans
mathtext.fontset : dejavuserif
font.size : 9
figure.titlesize : 9
legend.fontsize : 9
axes.titlesize : 9
axes.labelsize : 9
xtick.labelsize : 9
ytick.labelsize : 9
# testing to prevent component overlap/clipping
figure.constrained_layout.use : True
#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb'])
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
font.family : sans-serif
font.sans-serif : DejaVu Sans
font.serif : DejaVu Serif
font.cursive : DejaVu Sans
mathtext.fontset : dejavuserif
font.size : 9
figure.titlesize : 9
legend.fontsize : 9
axes.titlesize : 9
axes.labelsize : 9
xtick.labelsize : 9
ytick.labelsize : 9
image.interpolation : nearest
image.resample : False
image.composite_image : True
# monobiome -d 0.45 -l 22
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
axes.spines.left : True
axes.spines.bottom : True
axes.spines.top : False
axes.spines.right : False
image.interpolation : nearest
image.resample : False
image.composite_image : True
axes.linewidth : 1
xtick.major.width : 1
xtick.minor.width : 1
ytick.major.width : 1
ytick.minor.width : 1
axes.spines.left : True
axes.spines.bottom : True
axes.spines.top : False
axes.spines.right : False
lines.linewidth : 1
lines.markersize : 1
axes.linewidth : 1
xtick.major.width : 1
xtick.minor.width : 1
ytick.major.width : 1
ytick.minor.width : 1
savefig.dpi : 300
savefig.format : svg
savefig.bbox : tight
savefig.pad_inches : 0.1
lines.linewidth : 1
lines.markersize : 1
svg.image_inline : True
svg.fonttype : none
savefig.dpi : 300
savefig.format : svg
savefig.bbox : tight
savefig.pad_inches : 0.1
legend.frameon : False
svg.image_inline : True
svg.fonttype : none
legend.frameon : False

View File

@@ -1,5 +1,5 @@
from typing import Any, TypedDict
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Sequence
import numpy as np
from torch import Tensor
@@ -59,18 +59,11 @@ class SubplotsKwargs(TypedDict, total=False):
sharex: bool | str
sharey: bool | str
squeeze: bool
width_ratios: list[float]
height_ratios: list[float]
subplot_kw: dict
gridspec_kw: dict
width_ratios: Sequence[float]
height_ratios: Sequence[float]
subplot_kw: dict[str, Any]
gridspec_kw: dict[str, Any]
figsize: tuple[float, float]
dpi: float
layout: str
sharex: bool | Literal["none", "all", "row", "col"] = False,
sharey: bool | Literal["none", "all", "row", "col"] = False,
squeeze: bool = True,
width_ratios: Sequence[float] | None = None,
height_ratios: Sequence[float] | None = None,
subplot_kw: dict[str, Any] | None = None,
gridspec_kw: dict[str, Any] | None = None,
constrained_layout: bool