refactor Plotter to use a common plot base

This commit is contained in:
2026-03-14 22:56:58 -07:00
parent 30fb6fa107
commit 69e8e88047
3 changed files with 477 additions and 198 deletions

View File

@@ -1,19 +1,134 @@
from typing import Self """
def plot_actual_output_residual(
self,
order_residuals: bool = False,
row_size: int | float = 2,
col_size: int | float = 4,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
subplot_kwargs: dict | None = None,
gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
Note: transform samples in dataloader definitions beforehand if you
want to change data
Parameters:
row_size:
col_size:
figure_kwargs:
subplot_kwargs:
gof_kwargs:
ndims = self.data_tuples[0][0].size(-1)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
fig, axes = self._create_subplots(
rows=len(self.dataloaders),
cols=1 if (norm_samples or combine_dims) else ndims,
row_size=row_size,
col_size=col_size,
figure_kwargs=figure_kwargs,
)
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, loader_label = data_tuple
if norm_samples:
actual = actual.norm(dim=1, keepdim=True)
output = output.norm(dim=1, keepdim=True)
for dim in range(ndims):
ax = axes_row[0 if combine_dims else dim]
dim_color = colors[dim % len(colors)]
group_name = "norm" if norm_samples else f"$d_{dim}$"
dim_actual = actual[:, dim]
dim_output = output[:, dim]
residuals = dim_actual - dim_output
X, Y = dim_actual, residuals
if order_residuals:
X = range(1, residuals.size(0)+1)
Y = residuals[residuals.argsort()]
ax.scatter(
X, Y,
color=dim_color,
label=group_name,
**subplot_kwargs,
)
# plot goodness of fit line
m, b = self._lstsq_dim(dim_actual, dim_output)
ax.plot(
dim_actual,
m * dim_actual + b,
color=dim_color,
label=f"GoF {group_name}",
**gof_kwargs,
)
add_plot_context = (
norm_samples # dim=0 implicit b/c we break
or not combine_dims # add to every subplot across grid
or combine_dims and dim == ndims-1 # wait for last dim
)
# always exec plot logic if not combining, o/w just once
if add_plot_context:
# compare residuals to y=0
ax.axhline(y=0, c="black", alpha=0.2)
ax.set_title(f"[{loader_label}] Prediction residuals")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
# transpose legend layout for more natural view
if norm_samples or not combine_dims:
ax.legend()
else:
handles, labels = self.get_transposed_handles_labels(ax)
ax.legend(handles, labels, ncols=2)
# break dimension loop if collapsed by norm
if norm_samples:
break
return fig, axes
"""
from functools import partial from functools import partial
from collections.abc import Callable, Generator from collections.abc import Callable
import numpy as np import numpy as np
import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torch import Tensor from torch import Tensor
from numpy.typing import NDArray
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from trainlib.trainer import Trainer from trainlib.trainer import Trainer
from trainlib.estimator import EstimatorKwargs from trainlib.estimator import EstimatorKwargs
from trainlib.utils.type import AxesArray, SubplotsKwargs from trainlib.utils.type import AxesArray, SubplotsKwargs
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
type ContextFn = Callable[[plt.Axes, str], None]
class Plotter[B, K: EstimatorKwargs]: class Plotter[B, K: EstimatorKwargs]:
"""
TODOs:
- best fit lines for plots and residuals (compare to ideal lines in each
case)
- show val options across columns; preview how val is changing across
natural training, and what the best will look (so plot like uniform
intervals broken over the training epochs at 0, 50, 100, 150, ... and
highlight the best one, even if that's not actually the single best
epoch)
"""
def __init__( def __init__(
self, self,
trainer: Trainer[..., K], trainer: Trainer[..., K],
@@ -35,11 +150,11 @@ class Plotter[B, K: EstimatorKwargs]:
self._batch_outputs_fn = partial( self._batch_outputs_fn = partial(
self.trainer.get_batch_outputs, self.trainer.get_batch_outputs,
batch_estimator_map=self.batch_estimator_map batch_estimator_map=batch_estimator_map
) )
self._batch_metrics_fn = partial( self._batch_metrics_fn = partial(
self.trainer.get_batch_metrics, self.trainer.get_batch_metrics,
batch_estimator_map=self.batch_estimator_map batch_estimator_map=batch_estimator_map
) )
self._data_tuples = None self._data_tuples = None
@@ -83,201 +198,369 @@ class Plotter[B, K: EstimatorKwargs]:
data_tuples.append((actual, output, label)) data_tuples.append((actual, output, label))
self._data_tuples = data_tuples self._data_tuples = data_tuples
return self._data_tuples return self._data_tuples
def _default_figure_kwargs(self, rows: int, cols: int) -> dict: def get_transposed_handles_labels(
return {
"sharex": True,
"figsize": (4*cols, 2*rows),
"constrained_layout": True,
}
def _default_subplot_kwargs(self) -> dict:
    """
    Baseline keyword arguments merged into every subplot draw call.

    There are currently no defaults; a fresh empty dict is returned on each
    call.
    """
    defaults: dict = dict()
    return defaults
def _create_subplots(
self, self,
**figure_kwargs: SubplotsKwargs, ax: plt.Axes
) -> tuple[plt.Figure, AxesArray]: ) -> tuple[list, list]:
"""
"""
rows, cols = len(self.dataloaders), 1
figure_kwargs = {
**self._default_figure_kwargs(rows, cols),
**figure_kwargs,
}
fig, axes = plt.subplots(
rows, cols,
squeeze=False,
**figure_kwargs
)
return fig, axes
def plot_actual_output_dim(
    self,
    figure_kwargs: dict | None = None,
    subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Scatter predicted outputs against actual labels, one series per
    dimension, with one subplot per entry of ``self.data_tuples``.

    Design note: a decorator-style wrapper works fine here, but it's smelly:
    we *don't* want @wraps, so this method wouldn't actually have this
    signature at runtime (it would have the decorator wrapper's signature).
    The cleaner option is probably to keep internal methods (_func) and have
    the main method entry just pass that internal method through to a shared
    skeleton.

    Parameters:
        figure_kwargs: extra keyword arguments forwarded to
            ``_create_subplots``
        subplot_kwargs: extra keyword arguments forwarded to each
            ``ax.scatter`` call (merged over ``_default_subplot_kwargs``)

    Returns:
        The created figure and its axes grid.
    """

    figure_kwargs = figure_kwargs or {}
    subplot_kwargs = {
        **self._default_subplot_kwargs(),
        **(subplot_kwargs or {}),
    }

    fig, axes = self._create_subplots(**figure_kwargs)
    # the grid is built with squeeze=False and a single column, so ``axes``
    # is 2D; iterate the Axes objects of that column rather than array rows
    for ax, data_tuple in zip(axes[:, 0], self.data_tuples, strict=True):
        actual, output, label = data_tuple

        for dim in range(actual.size(-1)):
            ax.scatter(
                actual[:, dim],
                output[:, dim],
                label=f"$d_{dim}$",
                **subplot_kwargs
            )

        # perfect-prediction reference y=x, drawn in data coordinates; the
        # anchor sits inside the current x-range so the view is not dragged
        # toward the origin
        anchor = sum(ax.get_xlim()) / 2
        ax.axline((anchor, anchor), slope=1, c="black", alpha=0.2)

        ax.set_title(f"[{label}] True labels vs Predictions (dim-wise)")
        ax.set_xlabel("actual")
        ax.set_ylabel("output")
        ax.legend()

    return fig, axes
def plot_actual_output_residual_dim(
    self,
    figure_kwargs: dict | None = None,
    subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Scatter residuals (``actual - output``) against the actual labels, one
    series per dimension, with one subplot per entry of ``self.data_tuples``.

    Parameters:
        figure_kwargs: extra keyword arguments forwarded to
            ``_create_subplots``
        subplot_kwargs: extra keyword arguments forwarded to each
            ``ax.scatter`` call (merged over ``_default_subplot_kwargs``)

    Returns:
        The created figure and its axes grid.
    """

    figure_kwargs = figure_kwargs or {}
    subplot_kwargs = {
        **self._default_subplot_kwargs(),
        **(subplot_kwargs or {}),
    }

    fig, axes = self._create_subplots(**figure_kwargs)
    # the grid is built with squeeze=False and a single column, so ``axes``
    # is 2D; iterate the Axes objects of that column rather than array rows
    for ax, data_tuple in zip(axes[:, 0], self.data_tuples, strict=True):
        actual, output, label = data_tuple

        # compare residuals to y=0
        ax.axhline(y=0, c="black", alpha=0.2)

        for dim in range(actual.size(-1)):
            dim_actual = actual[:, dim]
            ax.scatter(
                dim_actual,
                dim_actual - output[:, dim],
                label=f"$d_{dim}$",
                **subplot_kwargs
            )

        ax.set_title(f"[{label}] Residuals (dim-wise)")
        ax.set_xlabel("actual")
        ax.set_ylabel("residual")
        ax.legend()

    return fig, axes
def plot_actual_output_ordered_residual_dim(
    self,
    figure_kwargs: dict | None = None,
    subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Scatter residuals (``actual - output``) sorted in ascending order
    against their rank, one series per dimension, with one subplot per entry
    of ``self.data_tuples``.

    Parameters:
        figure_kwargs: extra keyword arguments forwarded to
            ``_create_subplots``
        subplot_kwargs: extra keyword arguments forwarded to each
            ``ax.scatter`` call (merged over ``_default_subplot_kwargs``)

    Returns:
        The created figure and its axes grid.
    """

    figure_kwargs = figure_kwargs or {}
    subplot_kwargs = {
        **self._default_subplot_kwargs(),
        **(subplot_kwargs or {}),
    }

    fig, axes = self._create_subplots(**figure_kwargs)
    # the grid is built with squeeze=False and a single column, so ``axes``
    # is 2D; iterate the Axes objects of that column rather than array rows
    for ax, data_tuple in zip(axes[:, 0], self.data_tuples, strict=True):
        actual, output, label = data_tuple

        # compare residuals to y=0
        ax.axhline(y=0, c="black", alpha=0.2)

        for dim in range(actual.size(-1)):
            residuals = actual[:, dim] - output[:, dim]
            # order the residuals; x is the 1-based rank, not the label
            ordered = residuals[residuals.argsort()]
            ax.scatter(
                range(1, ordered.size(0) + 1),
                ordered,
                label=f"$d_{dim}$",
                **subplot_kwargs
            )

        ax.set_title(f"[{label}] Ordered residuals (dim-wise)")
        ax.set_xlabel("rank")
        ax.set_ylabel("residual")
        ax.legend()

    return fig, axes
def plot_actual_output_residual_dist(
self,
figure_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
"""
figure_kwargs = figure_kwargs or {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**(subplot_kwargs or {}),
}
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
N = actual.size(0)
for dim in range(actual.size(-1)):
residuals = actual[:, dim] - output[:, dim]
_, _, patches = ax.hist(
residuals.abs(),
bins=int(np.sqrt(N)),
density=True,
alpha=0.2,
label=f"$d_{dim}$",
**subplot_kwargs
)
# grab color used for hist and mirror in the v-line
color = patches[0].get_facecolor()
mu = residuals.abs().mean().item()
ax.axvline(mu, linestyle=":", c=color, label=f"$\mu_{dim}$")
ax.set_title(f"[{label}] Residual distribution (dim-wise)")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
# transpose legend layout for more natural view # transpose legend layout for more natural view
handles, labels = ax.get_legend_handles_labels() handles, labels = ax.get_legend_handles_labels()
handles = handles[::2] + handles[1::2] handles = handles[::2] + handles[1::2]
labels = labels[::2] + labels[1::2] labels = labels[::2] + labels[1::2]
ax.legend(handles, labels, ncols=2) return handles, labels
def _lstsq_dim(
    self,
    dim_actual: Tensor,
    dim_output: Tensor
) -> tuple[Tensor, Tensor]:
    """
    Least-squares fit of ``dim_output`` as a line in ``dim_actual``.

    Parameters:
        dim_actual: regressor values; stacked with a column of ones to form
            the design matrix
        dim_output: regression targets handed to ``torch.linalg.lstsq``

    Returns:
        The two solution entries ``(m, b)``: the coefficient on
        ``dim_actual`` (slope) and the coefficient on the ones column
        (intercept).
    """

    intercept_col = torch.ones_like(dim_actual)
    design = torch.stack((dim_actual, intercept_col), dim=1)

    fit = torch.linalg.lstsq(design, dim_output)
    slope, intercept = fit.solution

    return slope, intercept
def _prepare_figure_kwargs(
    self,
    rows: int,
    cols: int,
    row_size: int | float = 2,
    col_size: int | float = 4,
    figure_kwargs: SubplotsKwargs | None = None,
) -> SubplotsKwargs:
    """
    Build the figure keyword arguments for a ``rows`` x ``cols`` grid.

    Defaults are a shared x-axis and a figure size of
    ``(col_size * cols, row_size * rows)``; entries in ``figure_kwargs``
    take precedence over these defaults.

    Parameters:
        rows: number of subplot rows
        cols: number of subplot columns
        row_size: figure height contributed by each row
        col_size: figure width contributed by each column
        figure_kwargs: caller overrides / additions, may be None

    Returns:
        A new dict holding the merged keyword arguments.
    """

    width = col_size * cols
    height = row_size * rows

    merged = {"sharex": True, "figsize": (width, height)}
    # caller-supplied entries win over the defaults
    merged.update(figure_kwargs or {})

    return merged
def _prepare_subplot_kwargs(
self,
subplot_kwargs: dict | None = None,
) -> dict:
"""
"""
default_subplot_kwargs = {}
subplot_kwargs = {
**default_subplot_kwargs,
**(subplot_kwargs or {}),
}
return subplot_kwargs
def _prepare_gof_kwargs(
self,
gof_kwargs: dict | None = None,
) -> dict:
"""
"""
default_gof_kwargs = {
"alpha": 0.5
}
gof_kwargs = {
**default_gof_kwargs,
**(gof_kwargs or {}),
}
return gof_kwargs
def _create_subplots(
    self,
    rows: int,
    cols: int,
    row_size: int | float = 2,
    col_size: int | float = 4,
    figure_kwargs: SubplotsKwargs | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Create a ``rows`` x ``cols`` grid of subplots.

    Figure keyword arguments are resolved through
    ``_prepare_figure_kwargs`` (defaults plus caller overrides). The axes
    are always returned unsqueezed, since ``_plot_base`` iterates the grid
    row by row and indexes each row by column.

    Parameters:
        rows: number of subplot rows
        cols: number of subplot columns
        row_size: figure height contributed by each row
        col_size: figure width contributed by each column
        figure_kwargs: extra keyword arguments for ``plt.subplots``; a
            ``squeeze`` entry is ignored

    Returns:
        The created figure and its axes grid.
    """

    prepared: SubplotsKwargs = self._prepare_figure_kwargs(
        rows,
        cols,
        row_size=row_size,
        col_size=col_size,
        figure_kwargs=figure_kwargs,
    )

    # ``squeeze`` is a declared SubplotsKwargs key; passing it next to an
    # explicit ``squeeze=False`` would raise "multiple values for keyword
    # argument", so force it through the merged dict instead
    grid_kwargs = {**prepared, "squeeze": False}

    fig, axes = plt.subplots(
        rows,
        cols,
        **grid_kwargs,
    )  # ty:ignore[no-matching-overload]

    return fig, axes
def _plot_base(
    self,
    subplot_fn: SubplotFn,
    context_fn: ContextFn,
    row_size: int | float = 2,
    col_size: int | float = 4,
    norm_samples: bool = False,
    combine_dims: bool = True,
    figure_kwargs: SubplotsKwargs | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Shared skeleton for the per-dataloader plots: builds the subplot grid
    (one row per dataloader), draws each dimension through ``subplot_fn``,
    then adds titles/labels through ``context_fn`` and a legend.

    Note: transform samples in dataloader definitions beforehand if you
    want to change data

    Parameters:
        subplot_fn: called as ``subplot_fn(ax, dim, actual, output)`` for
            every plotted dimension
        context_fn: called as ``context_fn(ax, loader_label)`` once per
            subplot, after the last series for that subplot is drawn
        row_size: figure height contributed by each row
        col_size: figure width contributed by each column
        norm_samples: replace ``actual``/``output`` by their norm along
            dim 1 (``keepdim=True``) and plot only index 0
        combine_dims: draw every dimension in a single column; otherwise
            use one column per dimension
        figure_kwargs: extra keyword arguments forwarded to
            ``_create_subplots``

    Returns:
        The created figure and its axes grid.
    """

    ndims = self.data_tuples[0][0].size(-1)

    fig, axes = self._create_subplots(
        rows=len(self.dataloaders),
        cols=1 if (norm_samples or combine_dims) else ndims,
        row_size=row_size,
        col_size=col_size,
        figure_kwargs=figure_kwargs,
    )

    # the norm collapses every sample to one value, so only the first
    # "dimension" is plotted in that mode
    dims = range(ndims)[:1] if norm_samples else range(ndims)

    for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
        actual, output, loader_label = data_tuple

        if norm_samples:
            actual = actual.norm(dim=1, keepdim=True)
            output = output.norm(dim=1, keepdim=True)

        for dim in dims:
            ax = axes_row[0 if combine_dims else dim]
            subplot_fn(ax, dim, actual, output)

            # a shared subplot gets its context only once, after the last
            # dim is drawn; separate subplots get it every time
            if combine_dims and dim != dims[-1]:
                continue

            context_fn(ax, loader_label)

            if norm_samples or not combine_dims:
                ax.legend()
            else:
                # transpose legend layout for more natural view
                handles, labels = self.get_transposed_handles_labels(ax)
                ax.legend(handles, labels, ncols=2)

    return fig, axes
def plot_actual_output(
    self,
    row_size: int | float = 2,
    col_size: int | float = 4,
    norm_samples: bool = False,
    combine_dims: bool = True,
    figure_kwargs: SubplotsKwargs | None = None,
    subplot_kwargs: dict | None = None,
    gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Plot predicted outputs against actual labels.

    Each dimension is drawn as a scatter of ``output`` over ``actual`` plus
    a least-squares fit line; every subplot also gets a y=x reference line
    marking perfect predictions. Grid layout and legends are handled by
    ``_plot_base``.

    Parameters:
        row_size: figure height contributed by each row
        col_size: figure width contributed by each column
        norm_samples: plot sample norms instead of individual dimensions
        combine_dims: draw every dimension in one subplot per dataloader
        figure_kwargs: extra keyword arguments for figure creation
        subplot_kwargs: extra keyword arguments for ``ax.scatter``
        gof_kwargs: extra keyword arguments for the fit line ``ax.plot``

    Returns:
        The created figure and its axes grid.
    """

    subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
    gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    def subplot_fn(
        ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
    ) -> None:
        dim_color = colors[dim % len(colors)]
        # braces keep multi-digit dims fully subscripted in mathtext
        group_name = "norm" if norm_samples else f"$d_{{{dim}}}$"

        dim_actual = actual[:, dim]
        dim_output = output[:, dim]

        ax.scatter(
            dim_actual, dim_output,
            color=dim_color,
            label=group_name,
            **subplot_kwargs,
        )

        # plot goodness of fit line
        m, b = self._lstsq_dim(dim_actual, dim_output)
        ax.plot(
            dim_actual, m * dim_actual + b,
            color=dim_color,
            label=f"GoF {group_name}",
            **gof_kwargs,
        )

    def context_fn(ax: plt.Axes, loader_label: str) -> None:
        # perfect prediction reference line, y=x, in *data* coordinates
        # (a (0,0)-(1,1) segment in axes coordinates is only the diagonal
        # of the box and matches y=x only when both axis limits coincide);
        # anchor inside the current x-range so the view is not dragged
        # toward the origin
        anchor = sum(ax.get_xlim()) / 2
        ax.axline((anchor, anchor), slope=1, c="black", alpha=0.2)

        ax.set_title(
            f"[{loader_label}] True labels vs Predictions"
        )
        ax.set_xlabel("actual")
        ax.set_ylabel("output")

    return self._plot_base(
        subplot_fn, context_fn,
        row_size=row_size,
        col_size=col_size,
        norm_samples=norm_samples,
        combine_dims=combine_dims,
        figure_kwargs=figure_kwargs,
    )
def plot_actual_output_residual(
    self,
    row_size: int | float = 2,
    col_size: int | float = 4,
    order_residuals: bool = False,
    norm_samples: bool = False,
    combine_dims: bool = True,
    figure_kwargs: SubplotsKwargs | None = None,
    subplot_kwargs: dict | None = None,
    gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Plot prediction residuals (``actual - output``).

    By default residuals are scattered against the actual labels together
    with a least-squares fit line. With ``order_residuals`` they are sorted
    in ascending order and scattered against their 1-based rank instead,
    and no fit line is drawn. Every subplot gets a y=0 reference line. Grid
    layout and legends are handled by ``_plot_base``.

    Parameters:
        row_size: figure height contributed by each row
        col_size: figure width contributed by each column
        order_residuals: plot sorted residuals against their rank
        norm_samples: plot sample norms instead of individual dimensions
        combine_dims: draw every dimension in one subplot per dataloader
        figure_kwargs: extra keyword arguments for figure creation
        subplot_kwargs: extra keyword arguments for ``ax.scatter``
        gof_kwargs: extra keyword arguments for the fit line ``ax.plot``

    Returns:
        The created figure and its axes grid.
    """

    subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
    gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    def subplot_fn(
        ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
    ) -> None:
        dim_color = colors[dim % len(colors)]
        # braces keep multi-digit dims fully subscripted in mathtext
        group_name = "norm" if norm_samples else f"$d_{{{dim}}}$"

        dim_actual = actual[:, dim]
        dim_output = output[:, dim]
        residuals = dim_actual - dim_output

        if order_residuals:
            X = range(1, residuals.size(0)+1)
            Y = residuals[residuals.argsort()]
        else:
            X, Y = dim_actual, residuals

        ax.scatter(
            X, Y,
            color=dim_color,
            label=group_name,
            **subplot_kwargs,
        )

        # plot goodness of fit line; a fit against the labels is
        # meaningless once the x-axis is the rank
        if not order_residuals:
            m, b = self._lstsq_dim(dim_actual, residuals)
            ax.plot(
                dim_actual, m * dim_actual + b,
                color=dim_color,
                label=f"GoF {group_name}",
                **gof_kwargs,
            )

    def context_fn(ax: plt.Axes, loader_label: str) -> None:
        # compare residuals to y=0
        ax.axhline(y=0, c="black", alpha=0.2)

        ax.set_title(f"[{loader_label}] Prediction residuals")
        # ordered residuals are plotted against their rank, not the label
        ax.set_xlabel("rank" if order_residuals else "actual")
        ax.set_ylabel("residual")

    return self._plot_base(
        subplot_fn, context_fn,
        row_size=row_size,
        col_size=col_size,
        norm_samples=norm_samples,
        combine_dims=combine_dims,
        figure_kwargs=figure_kwargs,
    )
def plot_actual_output_residual_dist(
    self,
    row_size: int | float = 2,
    col_size: int | float = 4,
    norm_samples: bool = False,
    combine_dims: bool = True,
    figure_kwargs: SubplotsKwargs | None = None,
    subplot_kwargs: dict | None = None,
    gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Plot the distribution of absolute residuals (``|actual - output|``).

    Each dimension is drawn as a density-normalized histogram with
    ``int(sqrt(N))`` bins (at least one), plus a dotted vertical line at the
    mean absolute residual. Grid layout and legends are handled by
    ``_plot_base``.

    Parameters:
        row_size: figure height contributed by each row
        col_size: figure width contributed by each column
        norm_samples: plot sample norms instead of individual dimensions
        combine_dims: draw every dimension in one subplot per dataloader
        figure_kwargs: extra keyword arguments for figure creation
        subplot_kwargs: extra keyword arguments for ``ax.hist``; overrides
            the default ``alpha=0.3``
        gof_kwargs: accepted for signature parity with the other plot
            methods; unused, since no fit line is drawn here

    Returns:
        The created figure and its axes grid.
    """

    # fold the default alpha into the kwargs so a caller-supplied ``alpha``
    # overrides it instead of colliding with an explicit keyword
    hist_kwargs = {
        "alpha": 0.3,
        **self._prepare_subplot_kwargs(subplot_kwargs),
    }
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    def subplot_fn(
        ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
    ) -> None:
        dim_color = colors[dim % len(colors)]
        # braces keep multi-digit dims fully subscripted in mathtext
        group_name = "norm" if norm_samples else f"$d_{{{dim}}}$"

        abs_residuals = (actual[:, dim] - output[:, dim]).abs()
        N = abs_residuals.size(0)

        ax.hist(
            abs_residuals,
            # the histogram needs at least one bin
            bins=max(1, int(np.sqrt(N))),
            density=True,
            color=dim_color,
            label=group_name,
            **hist_kwargs
        )

        # mark the mean absolute residual in the histogram's color
        mu = abs_residuals.mean().item()
        ax.axvline(
            mu,
            linestyle=":",
            c=dim_color,
            label=f"$\\mu_{{{dim}}}$",
        )

    def context_fn(ax: plt.Axes, loader_label: str) -> None:
        ax.set_title(f"[{loader_label}] Residual distribution")
        # x holds the binned absolute residuals, y their density
        ax.set_xlabel("|residual|")
        ax.set_ylabel("density")

    return self._plot_base(
        subplot_fn, context_fn,
        row_size=row_size,
        col_size=col_size,
        norm_samples=norm_samples,
        combine_dims=combine_dims,
        figure_kwargs=figure_kwargs,
    )

View File

@@ -1,6 +1,9 @@
text.usetex : False text.usetex : False
mathtext.default : regular mathtext.default : regular
# testing to prevent component overlap/clipping
figure.constrained_layout.use : True
font.family : sans-serif font.family : sans-serif
font.sans-serif : DejaVu Sans font.sans-serif : DejaVu Sans
font.serif : DejaVu Serif font.serif : DejaVu Serif
@@ -14,7 +17,7 @@ axes.labelsize : 9
xtick.labelsize : 9 xtick.labelsize : 9
ytick.labelsize : 9 ytick.labelsize : 9
#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb']) # monobiome -d 0.45 -l 22
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a']) axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
image.interpolation : nearest image.interpolation : nearest

View File

@@ -1,5 +1,5 @@
from typing import Any, TypedDict from typing import Any, TypedDict
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable, Sequence
import numpy as np import numpy as np
from torch import Tensor from torch import Tensor
@@ -59,18 +59,11 @@ class SubplotsKwargs(TypedDict, total=False):
sharex: bool | str sharex: bool | str
sharey: bool | str sharey: bool | str
squeeze: bool squeeze: bool
width_ratios: list[float] width_ratios: Sequence[float]
height_ratios: list[float] height_ratios: Sequence[float]
subplot_kw: dict subplot_kw: dict[str, ...]
gridspec_kw: dict gridspec_kw: dict[str, ...]
figsize: tuple[float, float] figsize: tuple[float, float]
dpi: float dpi: float
layout: str layout: str
constrained_layout: bool
sharex: bool | Literal["none", "all", "row", "col"] = False,
sharey: bool | Literal["none", "all", "row", "col"] = False,
squeeze: bool = True,
width_ratios: Sequence[float] | None = None,
height_ratios: Sequence[float] | None = None,
subplot_kw: dict[str, Any] | None = None,
gridspec_kw: dict[str, Any] | None = None,